mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[NLP] Tests for NLP model configurations (#623)
Add tests for generated Elasticsearch model configurations
This commit is contained in:
parent
0c0a8ab19f
commit
ab6e44f430
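
The new tests parameterize over a list of Hugging Face model ids, trace each model through TransformerModel, and assert on the Elasticsearch model configuration that eland generates: model type, input field names, inference options class, tokenization config class, max sequence length, and, for text embedding models, embedding size. As a rough sketch (not part of this commit), the new file can be run on its own with pytest, assuming an Elasticsearch 8.7+ cluster reachable by the test suite and the torch, transformers, and scikit-learn extras installed, since the skip markers in the file require all of these:

    python -m pytest tests/ml/pytorch/test_pytorch_model_config.py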
@@ -311,7 +311,7 @@ def ensure_es_client(
     if isinstance(es_client, tuple):
         es_client = list(es_client)
     if not isinstance(es_client, Elasticsearch):
-        es_client = Elasticsearch(es_client)  # type: ignore[arg-type]
+        es_client = Elasticsearch(es_client)
     return es_client
@@ -17,10 +17,17 @@

from eland.ml.pytorch._pytorch_model import PyTorchModel  # noqa: F401
from eland.ml.pytorch.nlp_ml_model import (
    FillMaskInferenceOptions,
    NerInferenceOptions,
    NlpBertTokenizationConfig,
    NlpMPNetTokenizationConfig,
    NlpRobertaTokenizationConfig,
    NlpTrainedModelConfig,
    QuestionAnsweringInferenceOptions,
    TextClassificationInferenceOptions,
    TextEmbeddingInferenceOptions,
    TextSimilarityInferenceOptions,
    ZeroShotClassificationInferenceOptions,
)
from eland.ml.pytorch.traceable_model import TraceableModel  # noqa: F401
from eland.ml.pytorch.transformers import task_type_from_model_config

@@ -28,10 +35,17 @@ from eland.ml.pytorch.transformers import task_type_from_model_config

__all__ = [
    "PyTorchModel",
    "TraceableModel",
    "FillMaskInferenceOptions",
    "NerInferenceOptions",
    "NlpTrainedModelConfig",
    "NlpBertTokenizationConfig",
    "NlpRobertaTokenizationConfig",
    "NlpXLMRobertaTokenizationConfig",
    "NlpMPNetTokenizationConfig",
    "QuestionAnsweringInferenceOptions",
    "TextClassificationInferenceOptions",
    "TextEmbeddingInferenceOptions",
    "TextSimilarityInferenceOptions",
    "ZeroShotClassificationInferenceOptions",
    "task_type_from_model_config",
]
@@ -317,11 +317,9 @@ class NlpTrainedModelConfig:
        input: TrainedModelInput = TrainedModelInput(field_names=["text_field"]),
        metadata: t.Optional[dict] = None,
        model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
        default_field_map: t.Optional[t.Mapping[str, str]] = None,
        tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
    ):
        self.tags = tags
        self.default_field_map = default_field_map
        self.description = description
        self.inference_config = inference_config
        self.input = input
@@ -664,27 +664,23 @@ class TransformerModel:
         return vocab_obj
 
     def _create_tokenization_config(self) -> NlpTokenizationConfig:
+        _max_sequence_length = self._find_max_sequence_length()
+
         if isinstance(self._tokenizer, transformers.MPNetTokenizer):
             return NlpMPNetTokenizationConfig(
                 do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length,
             )
         elif isinstance(
             self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
         ):
             return NlpRobertaTokenizationConfig(
                 add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length,
             )
         elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
             return NlpXLMRobertaTokenizationConfig(
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length
             )
         else:
             japanese_morphological_tokenizers = ["mecab"]
@@ -695,18 +691,38 @@ class TransformerModel:
         ):
             return NlpBertJapaneseTokenizationConfig(
                 do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length,
             )
         else:
             return NlpBertTokenizationConfig(
                 do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length,
             )
 
+    def _find_max_sequence_length(self) -> int:
+        # Sometimes the max_... values are present but contain
+        # a random or very large value.
+        REASONABLE_MAX_LENGTH = 8192
+        max_len = getattr(self._tokenizer, "max_model_input_sizes", dict()).get(
+            self._model_id
+        )
+        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+            return int(max_len)
+
+        max_len = getattr(self._tokenizer, "model_max_length", None)
+        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+            return int(max_len)
+
+        model_config = getattr(self._traceable_model._model, "config", None)
+        if model_config is None:
+            raise ValueError("Cannot determine model max input length")
+
+        max_len = getattr(model_config, "max_position_embeddings", None)
+        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+            return int(max_len)
+
+        raise ValueError("Cannot determine model max input length")
+
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
     ) -> NlpTrainedModelConfig:
@@ -756,7 +772,7 @@ class TransformerModel:
             ),
         )
 
-    def _create_traceable_model(self) -> TraceableModel:
+    def _create_traceable_model(self) -> _TransformerTraceableModel:
         if self._task_type == "auto":
             model = transformers.AutoModel.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
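
The _find_max_sequence_length helper added above tries, in order, the tokenizer's max_model_input_sizes entry for the model, the tokenizer's model_max_length, and finally the model config's max_position_embeddings, accepting a candidate only if it is below 8192. A minimal stand-alone sketch of that fallback order follows; the function and argument names are hypothetical stand-ins for illustration, not eland API:

    # Illustrative sketch of the fallback order in _find_max_sequence_length;
    # names are hypothetical stand-ins, not part of eland.
    def pick_max_sequence_length(
        sizes_entry, model_max_length, max_position_embeddings, cap=8192
    ):
        for candidate in (sizes_entry, model_max_length, max_position_embeddings):
            if candidate is not None and candidate < cap:
                return int(candidate)
        raise ValueError("Cannot determine model max input length")

    # A tokenizer that reports a huge sentinel for model_max_length but whose model
    # config has a sensible max_position_embeddings resolves to the latter:
    print(pick_max_sequence_length(None, 10**30, 512))  # -> 512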
tests/ml/pytorch/test_pytorch_model_config.py (new file, 216 lines)
@@ -0,0 +1,216 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tempfile

import pytest

try:
    import sklearn  # noqa: F401

    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False

try:
    from eland.ml.pytorch.transformers import TransformerModel

    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False

try:
    import torch  # noqa: F401

    from eland.ml.pytorch import (
        FillMaskInferenceOptions,
        NerInferenceOptions,
        NlpBertTokenizationConfig,
        NlpMPNetTokenizationConfig,
        NlpRobertaTokenizationConfig,
        QuestionAnsweringInferenceOptions,
        TextClassificationInferenceOptions,
        TextEmbeddingInferenceOptions,
        TextSimilarityInferenceOptions,
        ZeroShotClassificationInferenceOptions,
    )

    HAS_PYTORCH = True
except ImportError:
    HAS_PYTORCH = False


from tests import ES_VERSION

pytestmark = [
    pytest.mark.skipif(
        ES_VERSION < (8, 7, 0),
        reason="Eland uses Pytorch 1.13.1, versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
    ),
    pytest.mark.skipif(
        not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
    ),
    pytest.mark.skipif(
        not HAS_TRANSFORMERS, reason="This test requires 'transformers' package to run"
    ),
    pytest.mark.skipif(
        not HAS_PYTORCH, reason="This test requires 'torch' package to run"
    ),
]

# If the required imports are missing the test will be skipped.
# Only define the test configurations if the referenced classes
# have been imported
if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
    MODEL_CONFIGURATIONS = [
        (
            "intfloat/e5-small-v2",
            "text_embedding",
            TextEmbeddingInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            384,
        ),
        (
            "sentence-transformers/all-mpnet-base-v2",
            "text_embedding",
            TextEmbeddingInferenceOptions,
            NlpMPNetTokenizationConfig,
            512,
            768,
        ),
        (
            "sentence-transformers/all-MiniLM-L12-v2",
            "text_embedding",
            TextEmbeddingInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            384,
        ),
        (
            "facebook/dpr-ctx_encoder-multiset-base",
            "text_embedding",
            TextEmbeddingInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            768,
        ),
        (
            "distilbert-base-uncased",
            "fill_mask",
            FillMaskInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            None,
        ),
        (
            "bert-base-uncased",
            "fill_mask",
            FillMaskInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            None,
        ),
        (
            "elastic/distilbert-base-uncased-finetuned-conll03-english",
            "ner",
            NerInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            None,
        ),
        (
            "SamLowe/roberta-base-go_emotions",
            "text_classification",
            TextClassificationInferenceOptions,
            NlpRobertaTokenizationConfig,
            512,
            None,
        ),
        (
            "distilbert-base-cased-distilled-squad",
            "question_answering",
            QuestionAnsweringInferenceOptions,
            NlpBertTokenizationConfig,
            386,
            None,
        ),
        (
            "cross-encoder/ms-marco-TinyBERT-L-2-v2",
            "text_similarity",
            TextSimilarityInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            None,
        ),
        (
            "valhalla/distilbart-mnli-12-6",
            "zero_shot_classification",
            ZeroShotClassificationInferenceOptions,
            NlpRobertaTokenizationConfig,
            1024,
            None,
        ),
    ]
else:
    MODEL_CONFIGURATIONS = []


class TestModelConfguration:
    @pytest.mark.parametrize(
        "model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
        MODEL_CONFIGURATIONS,
    )
    def test_text_prediction(
        self,
        model_id,
        task_type,
        config_type,
        tokenizer_type,
        max_sequence_len,
        embedding_size,
    ):
        with tempfile.TemporaryDirectory() as tmp_dir:
            print("loading model " + model_id)
            tm = TransformerModel(
                model_id=model_id,
                task_type=task_type,
                es_version=ES_VERSION,
                quantize=False,
            )
            _, config, _ = tm.save(tmp_dir)
            assert "pytorch" == config.model_type
            assert ["text_field"] == config.input.field_names
            assert isinstance(config.inference_config, config_type)
            tokenization = config.inference_config.tokenization
            assert isinstance(tokenization, tokenizer_type)
            assert max_sequence_len == tokenization.max_sequence_length

            if task_type == "text_classification":
                assert isinstance(config.inference_config.classification_labels, list)
                assert len(config.inference_config.classification_labels) > 0

            if task_type == "text_embedding":
                assert embedding_size == config.inference_config.embedding_size

            if task_type == "question_answering":
                assert tokenization.truncate == "none"
                assert tokenization.span > 0

            if task_type == "zero_shot_classification":
                assert isinstance(config.inference_config.classification_labels, list)
                assert len(config.inference_config.classification_labels) > 0
@@ -24,13 +24,6 @@ import numpy as np
 import pytest
 from elasticsearch import NotFoundError
 
-try:
-    import sklearn  # noqa: F401
-
-    HAS_SKLEARN = True
-except ImportError:
-    HAS_SKLEARN = False
-
 try:
     import torch  # noqa: F401
     from torch import Tensor, nn  # noqa: F401
@@ -67,9 +60,6 @@ pytestmark = [
         ES_VERSION < (8, 0, 0),
         reason="This test requires at least Elasticsearch version 8.0.0",
     ),
-    pytest.mark.skipif(
-        not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
-    ),
     pytest.mark.skipif(
         not HAS_PYTORCH, reason="This test requires 'pytorch' package to run"
     ),