Mirror of https://github.com/elastic/eland.git, synced 2025-07-11 00:02:14 +08:00
Add override option to specify the model's max input size (#674)

If the max input size cannot be found in the configuration, the user can specify it as a parameter to the eland_import_hub_model script.
Parent: 9b335315bb
Commit: 5d34dc3cc4
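The override is exposed on the command line as --max-model-input-length and on TransformerModel as the max_model_input_size keyword. A minimal sketch of the Python path, mirroring the test added in this commit (the model id and the 512-token limit are illustrative values):

    import tempfile

    from eland.ml.pytorch import TransformerModel

    # Explicitly cap the max input length when it cannot be derived from
    # the Hugging Face configuration. The value must not exceed the model's
    # true maximum, but a smaller limit is valid.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tm = TransformerModel(
            model_id="sentence-transformers/all-distilroberta-v1",
            task_type="text_embedding",
            es_version=(8, 13, 0),
            quantize=False,
            max_model_input_size=512,
        )
        model_path, config, vocab_path = tm.save(tmp_dir)
        assert config.inference_config.tokenization.max_sequence_length == 512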
@@ -141,6 +141,20 @@ def get_arg_parser():
         help="String to prepend to model input at search",
     )
 
+    parser.add_argument(
+        "--max-model-input-length",
+        required=False,
+        default=None,
+        help="""Set the model's max input length.
+        Usually the max input length is derived from the Hugging Face
+        model configuration. Use this option to explicitly set the model's
+        max input length if the value cannot be found in the Hugging
+        Face configuration. The max input length should never exceed the
+        model's true max length; setting a smaller max length is valid.
+        """,
+        type=int,
+    )
+
     return parser
 
 
@@ -220,6 +234,7 @@ def main():
             SUPPORTED_TASK_TYPES,
             TaskTypeError,
             TransformerModel,
+            UnknownModelInputSizeError,
         )
     except ModuleNotFoundError as e:
         logger.error(
@@ -259,6 +274,7 @@ def main():
             quantize=args.quantize,
             ingest_prefix=args.ingest_prefix,
             search_prefix=args.search_prefix,
+            max_model_input_size=args.max_model_input_length,
         )
         model_path, config, vocab_path = tm.save(tmp_dir)
     except TaskTypeError as err:
@@ -266,6 +282,12 @@ def main():
             f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}"
         )
         exit(1)
+    except UnknownModelInputSizeError as err:
+        logger.error(
+            f"""Could not automatically determine the model's max input size from the model configuration.
+            Please provide the max input size via the --max-model-input-length parameter. Caused by {err}"""
+        )
+        exit(1)
 
     ptm = PyTorchModel(
         es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()
@@ -31,7 +31,10 @@ from eland.ml.pytorch.nlp_ml_model import (
     ZeroShotClassificationInferenceOptions,
 )
 from eland.ml.pytorch.traceable_model import TraceableModel  # noqa: F401
-from eland.ml.pytorch.transformers import task_type_from_model_config
+from eland.ml.pytorch.transformers import (
+    UnknownModelInputSizeError,
+    task_type_from_model_config,
+)
 
 __all__ = [
     "PyTorchModel",
@@ -49,4 +52,5 @@ __all__ = [
     "TextSimilarityInferenceOptions",
     "ZeroShotClassificationInferenceOptions",
     "task_type_from_model_config",
+    "UnknownModelInputSizeError",
 ]
@@ -130,6 +130,10 @@ class TaskTypeError(Exception):
     pass
 
 
+class UnknownModelInputSizeError(Exception):
+    pass
+
+
 def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]:
     if model_config.architectures is None:
         if model_config.name_or_path.startswith("sentence-transformers/"):
@@ -598,6 +602,7 @@ class TransformerModel:
         access_token: Optional[str] = None,
         ingest_prefix: Optional[str] = None,
         search_prefix: Optional[str] = None,
+        max_model_input_size: Optional[int] = None,
     ):
         """
         Loads a model from the Hugging Face repository or local file and creates
@@ -629,6 +634,12 @@ class TransformerModel:
 
         search_prefix: Optional[str]
            Prefix string to prepend to input at search
+
+        max_model_input_size: Optional[int]
+            The max model input size counted in tokens.
+            Usually this value should be extracted from the model configuration,
+            but if that is not possible or the data is missing, it can be
+            explicitly set with this parameter.
        """

        self._model_id = model_id
@@ -636,6 +647,7 @@ class TransformerModel:
         self._task_type = task_type.replace("-", "_")
         self._ingest_prefix = ingest_prefix
         self._search_prefix = search_prefix
+        self._max_model_input_size = max_model_input_size
 
         # load Hugging Face model and tokenizer
         # use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
@@ -685,7 +697,10 @@ class TransformerModel:
         return vocab_obj
 
     def _create_tokenization_config(self) -> NlpTokenizationConfig:
-        _max_sequence_length = self._find_max_sequence_length()
+        if self._max_model_input_size:
+            _max_sequence_length = self._max_model_input_size
+        else:
+            _max_sequence_length = self._find_max_sequence_length()
 
         if isinstance(self._tokenizer, transformers.MPNetTokenizer):
             return NlpMPNetTokenizationConfig(
@@ -724,25 +739,25 @@ class TransformerModel:
         # Sometimes the max_... values are present but contain
         # a random or very large value.
         REASONABLE_MAX_LENGTH = 8192
-        max_len = getattr(self._tokenizer, "max_model_input_sizes", dict()).get(
-            self._model_id
-        )
-        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
-            return int(max_len)
-
         max_len = getattr(self._tokenizer, "model_max_length", None)
         if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
             return int(max_len)
 
-        model_config = getattr(self._traceable_model._model, "config", None)
-        if model_config is None:
-            raise ValueError("Cannot determine model max input length")
-
-        max_len = getattr(model_config, "max_position_embeddings", None)
+        max_sizes = getattr(self._tokenizer, "max_model_input_sizes", dict())
+        max_len = max_sizes.get(self._model_id)
         if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
             return int(max_len)
 
-        raise ValueError("Cannot determine model max input length")
+        if max_sizes:
+            # The model id wasn't found in the max sizes dict but
+            # if all the values correspond then take that value
+            sizes = {size for size in max_sizes.values()}
+            if len(sizes) == 1:
+                max_len = sizes.pop()
+                if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+                    return int(max_len)
+
+        raise UnknownModelInputSizeError("Cannot determine model max input length")
 
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
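In short, the reworked _find_max_sequence_length checks the tokenizer's model_max_length first, then looks the model id up in max_model_input_sizes, then falls back to a consensus value when every entry in that dict agrees, and finally raises UnknownModelInputSizeError instead of a bare ValueError; the old fallback to the model config's max_position_embeddings is dropped. A minimal sketch of the consensus fallback in isolation (the dict contents are illustrative):

    # The model id is absent from max_model_input_sizes, but every
    # registered model shares the same limit, so that shared value is used.
    max_sizes = {"model-a": 512, "model-b": 512}
    sizes = {size for size in max_sizes.values()}
    if len(sizes) == 1:
        max_len = sizes.pop()  # 512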
@@ -149,6 +149,14 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             1024,
             None,
         ),
+        (
+            "cardiffnlp/twitter-roberta-base-sentiment",
+            "text_classification",
+            TextClassificationInferenceOptions,
+            NlpRobertaTokenizationConfig,
+            512,
+            None,
+        ),
     ]
 else:
     MODEL_CONFIGURATIONS = []
@@ -235,3 +243,16 @@ class TestModelConfguration:
             ingest_prefix="INGEST:",
             search_prefix="SEARCH:",
         )
+
+    def test_model_config_with_user_specified_input_length(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tm = TransformerModel(
+                model_id="sentence-transformers/all-distilroberta-v1",
+                task_type="text_embedding",
+                es_version=(8, 13, 0),
+                quantize=False,
+                max_model_input_size=213,
+            )
+            _, config, _ = tm.save(tmp_dir)
+            tokenization = config.inference_config.tokenization
+            assert tokenization.max_sequence_length == 213