mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[NLP] Tests for NLP model configurations (#623)
Add tests for generated Elasticsearch model configurations
This commit is contained in:
parent
0c0a8ab19f
commit
ab6e44f430
@ -311,7 +311,7 @@ def ensure_es_client(
|
|||||||
if isinstance(es_client, tuple):
|
if isinstance(es_client, tuple):
|
||||||
es_client = list(es_client)
|
es_client = list(es_client)
|
||||||
if not isinstance(es_client, Elasticsearch):
|
if not isinstance(es_client, Elasticsearch):
|
||||||
es_client = Elasticsearch(es_client) # type: ignore[arg-type]
|
es_client = Elasticsearch(es_client)
|
||||||
return es_client
|
return es_client
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,10 +17,17 @@
|
|||||||
|
|
||||||
from eland.ml.pytorch._pytorch_model import PyTorchModel # noqa: F401
|
from eland.ml.pytorch._pytorch_model import PyTorchModel # noqa: F401
|
||||||
from eland.ml.pytorch.nlp_ml_model import (
|
from eland.ml.pytorch.nlp_ml_model import (
|
||||||
|
FillMaskInferenceOptions,
|
||||||
|
NerInferenceOptions,
|
||||||
NlpBertTokenizationConfig,
|
NlpBertTokenizationConfig,
|
||||||
NlpMPNetTokenizationConfig,
|
NlpMPNetTokenizationConfig,
|
||||||
NlpRobertaTokenizationConfig,
|
NlpRobertaTokenizationConfig,
|
||||||
NlpTrainedModelConfig,
|
NlpTrainedModelConfig,
|
||||||
|
QuestionAnsweringInferenceOptions,
|
||||||
|
TextClassificationInferenceOptions,
|
||||||
|
TextEmbeddingInferenceOptions,
|
||||||
|
TextSimilarityInferenceOptions,
|
||||||
|
ZeroShotClassificationInferenceOptions,
|
||||||
)
|
)
|
||||||
from eland.ml.pytorch.traceable_model import TraceableModel # noqa: F401
|
from eland.ml.pytorch.traceable_model import TraceableModel # noqa: F401
|
||||||
from eland.ml.pytorch.transformers import task_type_from_model_config
|
from eland.ml.pytorch.transformers import task_type_from_model_config
|
||||||
@ -28,10 +35,17 @@ from eland.ml.pytorch.transformers import task_type_from_model_config
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"PyTorchModel",
|
"PyTorchModel",
|
||||||
"TraceableModel",
|
"TraceableModel",
|
||||||
|
"FillMaskInferenceOptions",
|
||||||
|
"NerInferenceOptions",
|
||||||
"NlpTrainedModelConfig",
|
"NlpTrainedModelConfig",
|
||||||
"NlpBertTokenizationConfig",
|
"NlpBertTokenizationConfig",
|
||||||
"NlpRobertaTokenizationConfig",
|
"NlpRobertaTokenizationConfig",
|
||||||
"NlpXLMRobertaTokenizationConfig",
|
"NlpXLMRobertaTokenizationConfig",
|
||||||
"NlpMPNetTokenizationConfig",
|
"NlpMPNetTokenizationConfig",
|
||||||
|
"QuestionAnsweringInferenceOptions",
|
||||||
|
"TextClassificationInferenceOptions",
|
||||||
|
"TextEmbeddingInferenceOptions",
|
||||||
|
"TextSimilarityInferenceOptions",
|
||||||
|
"ZeroShotClassificationInferenceOptions",
|
||||||
"task_type_from_model_config",
|
"task_type_from_model_config",
|
||||||
]
|
]
|
||||||
|
@ -317,11 +317,9 @@ class NlpTrainedModelConfig:
|
|||||||
input: TrainedModelInput = TrainedModelInput(field_names=["text_field"]),
|
input: TrainedModelInput = TrainedModelInput(field_names=["text_field"]),
|
||||||
metadata: t.Optional[dict] = None,
|
metadata: t.Optional[dict] = None,
|
||||||
model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
|
model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
|
||||||
default_field_map: t.Optional[t.Mapping[str, str]] = None,
|
|
||||||
tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
|
tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
|
||||||
):
|
):
|
||||||
self.tags = tags
|
self.tags = tags
|
||||||
self.default_field_map = default_field_map
|
|
||||||
self.description = description
|
self.description = description
|
||||||
self.inference_config = inference_config
|
self.inference_config = inference_config
|
||||||
self.input = input
|
self.input = input
|
||||||
|
@ -664,27 +664,23 @@ class TransformerModel:
|
|||||||
return vocab_obj
|
return vocab_obj
|
||||||
|
|
||||||
def _create_tokenization_config(self) -> NlpTokenizationConfig:
|
def _create_tokenization_config(self) -> NlpTokenizationConfig:
|
||||||
|
_max_sequence_length = self._find_max_sequence_length()
|
||||||
|
|
||||||
if isinstance(self._tokenizer, transformers.MPNetTokenizer):
|
if isinstance(self._tokenizer, transformers.MPNetTokenizer):
|
||||||
return NlpMPNetTokenizationConfig(
|
return NlpMPNetTokenizationConfig(
|
||||||
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
|
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
|
||||||
max_sequence_length=getattr(
|
max_sequence_length=_max_sequence_length,
|
||||||
self._tokenizer, "max_model_input_sizes", dict()
|
|
||||||
).get(self._model_id),
|
|
||||||
)
|
)
|
||||||
elif isinstance(
|
elif isinstance(
|
||||||
self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
|
self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
|
||||||
):
|
):
|
||||||
return NlpRobertaTokenizationConfig(
|
return NlpRobertaTokenizationConfig(
|
||||||
add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
|
add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
|
||||||
max_sequence_length=getattr(
|
max_sequence_length=_max_sequence_length,
|
||||||
self._tokenizer, "max_model_input_sizes", dict()
|
|
||||||
).get(self._model_id),
|
|
||||||
)
|
)
|
||||||
elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
|
elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
|
||||||
return NlpXLMRobertaTokenizationConfig(
|
return NlpXLMRobertaTokenizationConfig(
|
||||||
max_sequence_length=getattr(
|
max_sequence_length=_max_sequence_length
|
||||||
self._tokenizer, "max_model_input_sizes", dict()
|
|
||||||
).get(self._model_id),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
japanese_morphological_tokenizers = ["mecab"]
|
japanese_morphological_tokenizers = ["mecab"]
|
||||||
@ -695,18 +691,38 @@ class TransformerModel:
|
|||||||
):
|
):
|
||||||
return NlpBertJapaneseTokenizationConfig(
|
return NlpBertJapaneseTokenizationConfig(
|
||||||
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
|
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
|
||||||
max_sequence_length=getattr(
|
max_sequence_length=_max_sequence_length,
|
||||||
self._tokenizer, "max_model_input_sizes", dict()
|
|
||||||
).get(self._model_id),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return NlpBertTokenizationConfig(
|
return NlpBertTokenizationConfig(
|
||||||
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
|
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
|
||||||
max_sequence_length=getattr(
|
max_sequence_length=_max_sequence_length,
|
||||||
self._tokenizer, "max_model_input_sizes", dict()
|
|
||||||
).get(self._model_id),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _find_max_sequence_length(self) -> int:
|
||||||
|
# Sometimes the max_... values are present but contain
|
||||||
|
# a random or very large value.
|
||||||
|
REASONABLE_MAX_LENGTH = 8192
|
||||||
|
max_len = getattr(self._tokenizer, "max_model_input_sizes", dict()).get(
|
||||||
|
self._model_id
|
||||||
|
)
|
||||||
|
if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
|
||||||
|
return int(max_len)
|
||||||
|
|
||||||
|
max_len = getattr(self._tokenizer, "model_max_length", None)
|
||||||
|
if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
|
||||||
|
return int(max_len)
|
||||||
|
|
||||||
|
model_config = getattr(self._traceable_model._model, "config", None)
|
||||||
|
if model_config is None:
|
||||||
|
raise ValueError("Cannot determine model max input length")
|
||||||
|
|
||||||
|
max_len = getattr(model_config, "max_position_embeddings", None)
|
||||||
|
if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
|
||||||
|
return int(max_len)
|
||||||
|
|
||||||
|
raise ValueError("Cannot determine model max input length")
|
||||||
|
|
||||||
def _create_config(
|
def _create_config(
|
||||||
self, es_version: Optional[Tuple[int, int, int]]
|
self, es_version: Optional[Tuple[int, int, int]]
|
||||||
) -> NlpTrainedModelConfig:
|
) -> NlpTrainedModelConfig:
|
||||||
@ -756,7 +772,7 @@ class TransformerModel:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_traceable_model(self) -> TraceableModel:
|
def _create_traceable_model(self) -> _TransformerTraceableModel:
|
||||||
if self._task_type == "auto":
|
if self._task_type == "auto":
|
||||||
model = transformers.AutoModel.from_pretrained(
|
model = transformers.AutoModel.from_pretrained(
|
||||||
self._model_id, token=self._access_token, torchscript=True
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
|
216
tests/ml/pytorch/test_pytorch_model_config.py
Normal file
216
tests/ml/pytorch/test_pytorch_model_config.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
# Licensed to Elasticsearch B.V. under one or more contributor
|
||||||
|
# license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright
|
||||||
|
# ownership. Elasticsearch B.V. licenses this file to you under
|
||||||
|
# the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sklearn # noqa: F401
|
||||||
|
|
||||||
|
HAS_SKLEARN = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_SKLEARN = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from eland.ml.pytorch.transformers import TransformerModel
|
||||||
|
|
||||||
|
HAS_TRANSFORMERS = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TRANSFORMERS = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch # noqa: F401
|
||||||
|
|
||||||
|
from eland.ml.pytorch import (
|
||||||
|
FillMaskInferenceOptions,
|
||||||
|
NerInferenceOptions,
|
||||||
|
NlpBertTokenizationConfig,
|
||||||
|
NlpMPNetTokenizationConfig,
|
||||||
|
NlpRobertaTokenizationConfig,
|
||||||
|
QuestionAnsweringInferenceOptions,
|
||||||
|
TextClassificationInferenceOptions,
|
||||||
|
TextEmbeddingInferenceOptions,
|
||||||
|
TextSimilarityInferenceOptions,
|
||||||
|
ZeroShotClassificationInferenceOptions,
|
||||||
|
)
|
||||||
|
|
||||||
|
HAS_PYTORCH = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_PYTORCH = False
|
||||||
|
|
||||||
|
|
||||||
|
from tests import ES_VERSION
|
||||||
|
|
||||||
|
pytestmark = [
|
||||||
|
pytest.mark.skipif(
|
||||||
|
ES_VERSION < (8, 7, 0),
|
||||||
|
reason="Eland uses Pytorch 1.13.1, versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
|
||||||
|
),
|
||||||
|
pytest.mark.skipif(
|
||||||
|
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||||
|
),
|
||||||
|
pytest.mark.skipif(
|
||||||
|
not HAS_TRANSFORMERS, reason="This test requires 'transformers' package to run"
|
||||||
|
),
|
||||||
|
pytest.mark.skipif(
|
||||||
|
not HAS_PYTORCH, reason="This test requires 'torch' package to run"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# If the required imports are missing the test will be skipped.
|
||||||
|
# Only define the test configurations if the referenced classes
|
||||||
|
# have been imported
|
||||||
|
if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
|
||||||
|
MODEL_CONFIGURATIONS = [
|
||||||
|
(
|
||||||
|
"intfloat/e5-small-v2",
|
||||||
|
"text_embedding",
|
||||||
|
TextEmbeddingInferenceOptions,
|
||||||
|
NlpBertTokenizationConfig,
|
||||||
|
512,
|
||||||
|
384,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"sentence-transformers/all-mpnet-base-v2",
|
||||||
|
"text_embedding",
|
||||||
|
TextEmbeddingInferenceOptions,
|
||||||
|
NlpMPNetTokenizationConfig,
|
||||||
|
512,
|
||||||
|
768,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"sentence-transformers/all-MiniLM-L12-v2",
|
||||||
|
"text_embedding",
|
||||||
|
TextEmbeddingInferenceOptions,
|
||||||
|
NlpBertTokenizationConfig,
|
||||||
|
512,
|
||||||
|
384,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"facebook/dpr-ctx_encoder-multiset-base",
|
||||||
|
"text_embedding",
|
||||||
|
TextEmbeddingInferenceOptions,
|
||||||
|
NlpBertTokenizationConfig,
|
||||||
|
512,
|
||||||
|
768,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"distilbert-base-uncased",
|
||||||
|
"fill_mask",
|
||||||
|
FillMaskInferenceOptions,
|
||||||
|
NlpBertTokenizationConfig,
|
||||||
|
512,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"bert-base-uncased",
|
||||||
|
"fill_mask",
|
||||||
|
FillMaskInferenceOptions,
|
||||||
|
NlpBertTokenizationConfig,
|
||||||
|
512,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"elastic/distilbert-base-uncased-finetuned-conll03-english",
|
||||||
|
"ner",
|
||||||
|
NerInferenceOptions,
|
||||||
|
NlpBertTokenizationConfig,
|
||||||
|
512,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SamLowe/roberta-base-go_emotions",
|
||||||
|
"text_classification",
|
||||||
|
TextClassificationInferenceOptions,
|
||||||
|
NlpRobertaTokenizationConfig,
|
||||||
|
512,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"distilbert-base-cased-distilled-squad",
|
||||||
|
"question_answering",
|
||||||
|
QuestionAnsweringInferenceOptions,
|
||||||
|
NlpBertTokenizationConfig,
|
||||||
|
386,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"cross-encoder/ms-marco-TinyBERT-L-2-v2",
|
||||||
|
"text_similarity",
|
||||||
|
TextSimilarityInferenceOptions,
|
||||||
|
NlpBertTokenizationConfig,
|
||||||
|
512,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"valhalla/distilbart-mnli-12-6",
|
||||||
|
"zero_shot_classification",
|
||||||
|
ZeroShotClassificationInferenceOptions,
|
||||||
|
NlpRobertaTokenizationConfig,
|
||||||
|
1024,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
MODEL_CONFIGURATIONS = []
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelConfguration:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
|
||||||
|
MODEL_CONFIGURATIONS,
|
||||||
|
)
|
||||||
|
def test_text_prediction(
|
||||||
|
self,
|
||||||
|
model_id,
|
||||||
|
task_type,
|
||||||
|
config_type,
|
||||||
|
tokenizer_type,
|
||||||
|
max_sequence_len,
|
||||||
|
embedding_size,
|
||||||
|
):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
print("loading model " + model_id)
|
||||||
|
tm = TransformerModel(
|
||||||
|
model_id=model_id,
|
||||||
|
task_type=task_type,
|
||||||
|
es_version=ES_VERSION,
|
||||||
|
quantize=False,
|
||||||
|
)
|
||||||
|
_, config, _ = tm.save(tmp_dir)
|
||||||
|
assert "pytorch" == config.model_type
|
||||||
|
assert ["text_field"] == config.input.field_names
|
||||||
|
assert isinstance(config.inference_config, config_type)
|
||||||
|
tokenization = config.inference_config.tokenization
|
||||||
|
assert isinstance(tokenization, tokenizer_type)
|
||||||
|
assert max_sequence_len == tokenization.max_sequence_length
|
||||||
|
|
||||||
|
if task_type == "text_classification":
|
||||||
|
assert isinstance(config.inference_config.classification_labels, list)
|
||||||
|
assert len(config.inference_config.classification_labels) > 0
|
||||||
|
|
||||||
|
if task_type == "text_embedding":
|
||||||
|
assert embedding_size == config.inference_config.embedding_size
|
||||||
|
|
||||||
|
if task_type == "question_answering":
|
||||||
|
assert tokenization.truncate == "none"
|
||||||
|
assert tokenization.span > 0
|
||||||
|
|
||||||
|
if task_type == "zero_shot_classification":
|
||||||
|
assert isinstance(config.inference_config.classification_labels, list)
|
||||||
|
assert len(config.inference_config.classification_labels) > 0
|
@ -24,13 +24,6 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from elasticsearch import NotFoundError
|
from elasticsearch import NotFoundError
|
||||||
|
|
||||||
try:
|
|
||||||
import sklearn # noqa: F401
|
|
||||||
|
|
||||||
HAS_SKLEARN = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_SKLEARN = False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
from torch import Tensor, nn # noqa: F401
|
from torch import Tensor, nn # noqa: F401
|
||||||
@ -67,9 +60,6 @@ pytestmark = [
|
|||||||
ES_VERSION < (8, 0, 0),
|
ES_VERSION < (8, 0, 0),
|
||||||
reason="This test requires at least Elasticsearch version 8.0.0",
|
reason="This test requires at least Elasticsearch version 8.0.0",
|
||||||
),
|
),
|
||||||
pytest.mark.skipif(
|
|
||||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
|
||||||
),
|
|
||||||
pytest.mark.skipif(
|
pytest.mark.skipif(
|
||||||
not HAS_PYTORCH, reason="This test requires 'pytorch' package to run"
|
not HAS_PYTORCH, reason="This test requires 'pytorch' package to run"
|
||||||
),
|
),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user