diff --git a/eland/common.py b/eland/common.py
index bd83c84..5f50625 100644
--- a/eland/common.py
+++ b/eland/common.py
@@ -311,7 +311,7 @@ def ensure_es_client(
     if isinstance(es_client, tuple):
         es_client = list(es_client)
     if not isinstance(es_client, Elasticsearch):
-        es_client = Elasticsearch(es_client)  # type: ignore[arg-type]
+        es_client = Elasticsearch(es_client)
     return es_client
 
 
diff --git a/eland/ml/pytorch/__init__.py b/eland/ml/pytorch/__init__.py
index 4862028..465e80f 100644
--- a/eland/ml/pytorch/__init__.py
+++ b/eland/ml/pytorch/__init__.py
@@ -17,10 +17,17 @@
 from eland.ml.pytorch._pytorch_model import PyTorchModel  # noqa: F401
 from eland.ml.pytorch.nlp_ml_model import (
+    FillMaskInferenceOptions,
+    NerInferenceOptions,
     NlpBertTokenizationConfig,
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
     NlpTrainedModelConfig,
+    QuestionAnsweringInferenceOptions,
+    TextClassificationInferenceOptions,
+    TextEmbeddingInferenceOptions,
+    TextSimilarityInferenceOptions,
+    ZeroShotClassificationInferenceOptions,
 )
 from eland.ml.pytorch.traceable_model import TraceableModel  # noqa: F401
 from eland.ml.pytorch.transformers import task_type_from_model_config
 
@@ -28,10 +35,17 @@ from eland.ml.pytorch.transformers import task_type_from_model_config
 __all__ = [
     "PyTorchModel",
     "TraceableModel",
+    "FillMaskInferenceOptions",
+    "NerInferenceOptions",
     "NlpTrainedModelConfig",
     "NlpBertTokenizationConfig",
     "NlpRobertaTokenizationConfig",
     "NlpXLMRobertaTokenizationConfig",
     "NlpMPNetTokenizationConfig",
+    "QuestionAnsweringInferenceOptions",
+    "TextClassificationInferenceOptions",
+    "TextEmbeddingInferenceOptions",
+    "TextSimilarityInferenceOptions",
+    "ZeroShotClassificationInferenceOptions",
     "task_type_from_model_config",
 ]
diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index dc5e055..4a7284d 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -317,11 +317,9 @@ class NlpTrainedModelConfig:
         input: TrainedModelInput = TrainedModelInput(field_names=["text_field"]),
         metadata: t.Optional[dict] = None,
         model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
-        default_field_map: t.Optional[t.Mapping[str, str]] = None,
         tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
     ):
         self.tags = tags
-        self.default_field_map = default_field_map
         self.description = description
         self.inference_config = inference_config
         self.input = input
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 22bdf52..fb15bec 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -664,27 +664,23 @@ class TransformerModel:
         return vocab_obj
 
     def _create_tokenization_config(self) -> NlpTokenizationConfig:
+        _max_sequence_length = self._find_max_sequence_length()
+
         if isinstance(self._tokenizer, transformers.MPNetTokenizer):
             return NlpMPNetTokenizationConfig(
                 do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length,
             )
         elif isinstance(
             self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
         ):
             return NlpRobertaTokenizationConfig(
                 add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length,
             )
         elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
             return NlpXLMRobertaTokenizationConfig(
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length
             )
         else:
             japanese_morphological_tokenizers = ["mecab"]
@@ -695,18 +691,38 @@ class TransformerModel:
             ):
                 return NlpBertJapaneseTokenizationConfig(
                     do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                    max_sequence_length=getattr(
-                        self._tokenizer, "max_model_input_sizes", dict()
-                    ).get(self._model_id),
+                    max_sequence_length=_max_sequence_length,
                 )
             else:
                 return NlpBertTokenizationConfig(
                     do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                    max_sequence_length=getattr(
-                        self._tokenizer, "max_model_input_sizes", dict()
-                    ).get(self._model_id),
+                    max_sequence_length=_max_sequence_length,
                 )
 
+    def _find_max_sequence_length(self) -> int:
+        # Sometimes the max_... values are present but contain
+        # a random or very large value.
+        REASONABLE_MAX_LENGTH = 8192
+        max_len = getattr(self._tokenizer, "max_model_input_sizes", dict()).get(
+            self._model_id
+        )
+        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+            return int(max_len)
+
+        max_len = getattr(self._tokenizer, "model_max_length", None)
+        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+            return int(max_len)
+
+        model_config = getattr(self._traceable_model._model, "config", None)
+        if model_config is None:
+            raise ValueError("Cannot determine model max input length")
+
+        max_len = getattr(model_config, "max_position_embeddings", None)
+        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+            return int(max_len)
+
+        raise ValueError("Cannot determine model max input length")
+
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
     ) -> NlpTrainedModelConfig:
@@ -756,7 +772,7 @@ class TransformerModel:
                 ),
             )
 
-    def _create_traceable_model(self) -> TraceableModel:
+    def _create_traceable_model(self) -> _TransformerTraceableModel:
         if self._task_type == "auto":
             model = transformers.AutoModel.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
diff --git a/tests/ml/pytorch/test_pytorch_model_config.py b/tests/ml/pytorch/test_pytorch_model_config.py
new file mode 100644
index 0000000..beac170
--- /dev/null
+++ b/tests/ml/pytorch/test_pytorch_model_config.py
@@ -0,0 +1,216 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tempfile
+
+import pytest
+
+try:
+    import sklearn  # noqa: F401
+
+    HAS_SKLEARN = True
+except ImportError:
+    HAS_SKLEARN = False
+
+try:
+    from eland.ml.pytorch.transformers import TransformerModel
+
+    HAS_TRANSFORMERS = True
+except ImportError:
+    HAS_TRANSFORMERS = False
+
+try:
+    import torch  # noqa: F401
+
+    from eland.ml.pytorch import (
+        FillMaskInferenceOptions,
+        NerInferenceOptions,
+        NlpBertTokenizationConfig,
+        NlpMPNetTokenizationConfig,
+        NlpRobertaTokenizationConfig,
+        QuestionAnsweringInferenceOptions,
+        TextClassificationInferenceOptions,
+        TextEmbeddingInferenceOptions,
+        TextSimilarityInferenceOptions,
+        ZeroShotClassificationInferenceOptions,
+    )
+
+    HAS_PYTORCH = True
+except ImportError:
+    HAS_PYTORCH = False
+
+
+from tests import ES_VERSION
+
+pytestmark = [
+    pytest.mark.skipif(
+        ES_VERSION < (8, 7, 0),
+        reason="Eland uses PyTorch 1.13.1; versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
+    ),
+    pytest.mark.skipif(
+        not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
+    ),
+    pytest.mark.skipif(
+        not HAS_TRANSFORMERS, reason="This test requires 'transformers' package to run"
+    ),
+    pytest.mark.skipif(
+        not HAS_PYTORCH, reason="This test requires 'torch' package to run"
+    ),
+]
+
+# If the required imports are missing, the test will be skipped.
+# Only define the test configurations if the referenced classes
+# have been imported.
+if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
+    MODEL_CONFIGURATIONS = [
+        (
+            "intfloat/e5-small-v2",
+            "text_embedding",
+            TextEmbeddingInferenceOptions,
+            NlpBertTokenizationConfig,
+            512,
+            384,
+        ),
+        (
+            "sentence-transformers/all-mpnet-base-v2",
+            "text_embedding",
+            TextEmbeddingInferenceOptions,
+            NlpMPNetTokenizationConfig,
+            512,
+            768,
+        ),
+        (
+            "sentence-transformers/all-MiniLM-L12-v2",
+            "text_embedding",
+            TextEmbeddingInferenceOptions,
+            NlpBertTokenizationConfig,
+            512,
+            384,
+        ),
+        (
+            "facebook/dpr-ctx_encoder-multiset-base",
+            "text_embedding",
+            TextEmbeddingInferenceOptions,
+            NlpBertTokenizationConfig,
+            512,
+            768,
+        ),
+        (
+            "distilbert-base-uncased",
+            "fill_mask",
+            FillMaskInferenceOptions,
+            NlpBertTokenizationConfig,
+            512,
+            None,
+        ),
+        (
+            "bert-base-uncased",
+            "fill_mask",
+            FillMaskInferenceOptions,
+            NlpBertTokenizationConfig,
+            512,
+            None,
+        ),
+        (
+            "elastic/distilbert-base-uncased-finetuned-conll03-english",
+            "ner",
+            NerInferenceOptions,
+            NlpBertTokenizationConfig,
+            512,
+            None,
+        ),
+        (
+            "SamLowe/roberta-base-go_emotions",
+            "text_classification",
+            TextClassificationInferenceOptions,
+            NlpRobertaTokenizationConfig,
+            512,
+            None,
+        ),
+        (
+            "distilbert-base-cased-distilled-squad",
+            "question_answering",
+            QuestionAnsweringInferenceOptions,
+            NlpBertTokenizationConfig,
+            386,
+            None,
+        ),
+        (
+            "cross-encoder/ms-marco-TinyBERT-L-2-v2",
+            "text_similarity",
+            TextSimilarityInferenceOptions,
+            NlpBertTokenizationConfig,
+            512,
+            None,
+        ),
+        (
+            "valhalla/distilbart-mnli-12-6",
+            "zero_shot_classification",
+            ZeroShotClassificationInferenceOptions,
+            NlpRobertaTokenizationConfig,
+            1024,
+            None,
+        ),
+    ]
+else:
+    MODEL_CONFIGURATIONS = []
+
+
+class TestModelConfiguration:
+    @pytest.mark.parametrize(
+        "model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
+        MODEL_CONFIGURATIONS,
+    )
+    def test_model_config(
+        self,
+        model_id,
+        task_type,
+        config_type,
+        tokenizer_type,
+        max_sequence_len,
+        embedding_size,
+    ):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            print("loading model " + model_id)
+            tm = TransformerModel(
+                model_id=model_id,
+                task_type=task_type,
+                es_version=ES_VERSION,
+                quantize=False,
+            )
+            _, config, _ = tm.save(tmp_dir)
+            assert "pytorch" == config.model_type
+            assert ["text_field"] == config.input.field_names
+            assert isinstance(config.inference_config, config_type)
+            tokenization = config.inference_config.tokenization
+            assert isinstance(tokenization, tokenizer_type)
+            assert max_sequence_len == tokenization.max_sequence_length
+
+            if task_type == "text_classification":
+                assert isinstance(config.inference_config.classification_labels, list)
+                assert len(config.inference_config.classification_labels) > 0
+
+            if task_type == "text_embedding":
+                assert embedding_size == config.inference_config.embedding_size
+
+            if task_type == "question_answering":
+                assert tokenization.truncate == "none"
+                assert tokenization.span > 0
+
+            if task_type == "zero_shot_classification":
+                assert isinstance(config.inference_config.classification_labels, list)
+                assert len(config.inference_config.classification_labels) > 0
diff --git a/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py b/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
index c394ae7..ac23b05 100644
--- a/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
+++ b/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
@@ -24,13 +24,6 @@ import numpy as np
 import pytest
 from elasticsearch import NotFoundError
 
-try:
-    import sklearn  # noqa: F401
-
-    HAS_SKLEARN = True
-except ImportError:
-    HAS_SKLEARN = False
-
 try:
     import torch  # noqa: F401
     from torch import Tensor, nn  # noqa: F401
@@ -67,9 +60,6 @@ pytestmark = [
         ES_VERSION < (8, 0, 0),
         reason="This test requires at least Elasticsearch version 8.0.0",
     ),
-    pytest.mark.skipif(
-        not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
-    ),
     pytest.mark.skipif(
         not HAS_PYTORCH, reason="This test requires 'pytorch' package to run"
     ),
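
A minimal standalone sketch of the fallback chain that the new _find_max_sequence_length method applies, useful for checking what a given checkpoint will report before pushing it. It mirrors the method above except that it loads the config via AutoConfig.from_pretrained instead of reading it from the already-loaded traceable model; the model id below is only an example, and the sketch assumes the transformers package is installed and the Hugging Face Hub is reachable. The 8192 cap exists because tokenizers often store a huge sentinel (such as int(1e30)) in model_max_length when no real limit was configured.

from transformers import AutoConfig, AutoTokenizer

REASONABLE_MAX_LENGTH = 8192  # values at or above this are treated as sentinels
model_id = "bert-base-uncased"  # example checkpoint, not taken from the diff

tokenizer = AutoTokenizer.from_pretrained(model_id)

# 1) Per-model registry that some tokenizer classes carry.
max_len = getattr(tokenizer, "max_model_input_sizes", {}).get(model_id)
if max_len is None or max_len >= REASONABLE_MAX_LENGTH:
    # 2) The tokenizer's own limit; unset values default to a huge sentinel.
    max_len = getattr(tokenizer, "model_max_length", None)
if max_len is None or max_len >= REASONABLE_MAX_LENGTH:
    # 3) Positional-embedding count from the model config (the diff reads this
    #    from the traced model's config rather than downloading it again).
    config = AutoConfig.from_pretrained(model_id)
    max_len = getattr(config, "max_position_embeddings", None)
if max_len is None or max_len >= REASONABLE_MAX_LENGTH:
    raise ValueError("Cannot determine model max input length")

print(int(max_len))  # 512 for bert-base-uncased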