Mirror of https://github.com/elastic/eland.git (synced 2025-07-11)
Add override option to specify the model's max input size (#674)

If the max input size cannot be found in the configuration, the user can specify it as a parameter to the eland_import_hub_model script.
This commit is contained in: commit 5d34dc3cc4 (parent 9b335315bb)
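As a quick usage sketch (not part of the commit itself), an import run with the new flag might look like the following; the cluster URL, credentials, and model id are placeholders, and the --url, --hub-model-id, and --task-type flags are assumed from the importer's existing interface:

    eland_import_hub_model \
        --url https://elastic:changeme@localhost:9200 \
        --hub-model-id sentence-transformers/all-distilroberta-v1 \
        --task-type text_embedding \
        --max-model-input-length 512

If the flag is omitted and the size cannot be derived from the Hugging Face configuration, the script now logs an error pointing at --max-model-input-length and exits with status 1, as the diff below shows.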
@@ -141,6 +141,20 @@ def get_arg_parser():
         help="String to prepend to model input at search",
     )
+
+    parser.add_argument(
+        "--max-model-input-length",
+        required=False,
+        default=None,
+        help="""Set the model's max input length.
+        Usually the max input length is derived from the Hugging Face
+        model configuration. Use this option to explicitly set the model's
+        max input length if the value cannot be found in the Hugging
+        Face configuration. Max input length should never exceed the
+        model's true max length; setting a smaller max length is valid.
+        """,
+        type=int,
+    )
     return parser
@@ -220,6 +234,7 @@ def main():
             SUPPORTED_TASK_TYPES,
             TaskTypeError,
             TransformerModel,
+            UnknownModelInputSizeError,
         )
     except ModuleNotFoundError as e:
         logger.error(
@@ -259,6 +274,7 @@ def main():
             quantize=args.quantize,
             ingest_prefix=args.ingest_prefix,
             search_prefix=args.search_prefix,
+            max_model_input_size=args.max_model_input_length,
         )
         model_path, config, vocab_path = tm.save(tmp_dir)
     except TaskTypeError as err:
@@ -266,6 +282,12 @@ def main():
             f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}"
         )
         exit(1)
+    except UnknownModelInputSizeError as err:
+        logger.error(
+            f"""Could not automatically determine the model's max input size from the model configuration.
+            Please provide the max input size via the --max-model-input-length parameter. Caused by {err}"""
+        )
+        exit(1)
 
     ptm = PyTorchModel(
         es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()
@@ -31,7 +31,10 @@ from eland.ml.pytorch.nlp_ml_model import (
     ZeroShotClassificationInferenceOptions,
 )
 from eland.ml.pytorch.traceable_model import TraceableModel  # noqa: F401
-from eland.ml.pytorch.transformers import task_type_from_model_config
+from eland.ml.pytorch.transformers import (
+    UnknownModelInputSizeError,
+    task_type_from_model_config,
+)
 
 __all__ = [
     "PyTorchModel",
@@ -49,4 +52,5 @@ __all__ = [
     "TextSimilarityInferenceOptions",
     "ZeroShotClassificationInferenceOptions",
     "task_type_from_model_config",
+    "UnknownModelInputSizeError",
 ]
@@ -130,6 +130,10 @@ class TaskTypeError(Exception):
     pass
 
 
+class UnknownModelInputSizeError(Exception):
+    pass
+
+
 def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]:
     if model_config.architectures is None:
         if model_config.name_or_path.startswith("sentence-transformers/"):
@@ -598,6 +602,7 @@ class TransformerModel:
         access_token: Optional[str] = None,
         ingest_prefix: Optional[str] = None,
         search_prefix: Optional[str] = None,
+        max_model_input_size: Optional[int] = None,
     ):
         """
         Loads a model from the Hugging Face repository or local file and creates
@@ -629,6 +634,12 @@ class TransformerModel:
 
         search_prefix: Optional[str]
             Prefix string to prepend to input at search
+
+        max_model_input_size: Optional[int]
+            The max model input size counted in tokens.
+            Usually this value should be extracted from the model configuration,
+            but if that is not possible or the data is missing, it can be
+            explicitly set with this parameter.
         """
 
         self._model_id = model_id
@@ -636,6 +647,7 @@ class TransformerModel:
         self._task_type = task_type.replace("-", "_")
         self._ingest_prefix = ingest_prefix
         self._search_prefix = search_prefix
+        self._max_model_input_size = max_model_input_size
 
         # load Hugging Face model and tokenizer
         # use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
@@ -685,7 +697,10 @@ class TransformerModel:
         return vocab_obj
 
     def _create_tokenization_config(self) -> NlpTokenizationConfig:
-        _max_sequence_length = self._find_max_sequence_length()
+        if self._max_model_input_size:
+            _max_sequence_length = self._max_model_input_size
+        else:
+            _max_sequence_length = self._find_max_sequence_length()
 
         if isinstance(self._tokenizer, transformers.MPNetTokenizer):
             return NlpMPNetTokenizationConfig(
@@ -724,25 +739,25 @@ class TransformerModel:
         # Sometimes the max_... values are present but contain
         # a random or very large value.
         REASONABLE_MAX_LENGTH = 8192
-        max_len = getattr(self._tokenizer, "max_model_input_sizes", dict()).get(
-            self._model_id
-        )
-        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
-            return int(max_len)
-
         max_len = getattr(self._tokenizer, "model_max_length", None)
         if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
             return int(max_len)
 
-        model_config = getattr(self._traceable_model._model, "config", None)
-        if model_config is None:
-            raise ValueError("Cannot determine model max input length")
-
-        max_len = getattr(model_config, "max_position_embeddings", None)
+        max_sizes = getattr(self._tokenizer, "max_model_input_sizes", dict())
+        max_len = max_sizes.get(self._model_id)
         if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
             return int(max_len)
 
-        raise ValueError("Cannot determine model max input length")
+        if max_sizes:
+            # The model id wasn't found in the max sizes dict but
+            # if all the values correspond then take that value
+            sizes = {size for size in max_sizes.values()}
+            if len(sizes) == 1:
+                max_len = sizes.pop()
+                if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+                    return int(max_len)
+
+        raise UnknownModelInputSizeError("Cannot determine model max input length")
 
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
@@ -149,6 +149,14 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             1024,
             None,
         ),
+        (
+            "cardiffnlp/twitter-roberta-base-sentiment",
+            "text_classification",
+            TextClassificationInferenceOptions,
+            NlpRobertaTokenizationConfig,
+            512,
+            None,
+        ),
     ]
 else:
     MODEL_CONFIGURATIONS = []
@@ -235,3 +243,16 @@ class TestModelConfguration:
                 ingest_prefix="INGEST:",
                 search_prefix="SEARCH:",
             )
+
+    def test_model_config_with_user_specified_input_length(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tm = TransformerModel(
+                model_id="sentence-transformers/all-distilroberta-v1",
+                task_type="text_embedding",
+                es_version=(8, 13, 0),
+                quantize=False,
+                max_model_input_size=213,
+            )
+            _, config, _ = tm.save(tmp_dir)
+            tokenization = config.inference_config.tokenization
+            assert tokenization.max_sequence_length == 213