[NLP] Tests for NLP model configurations (#623)

Add tests for generated Elasticsearch model configurations
David Kyle 2023-10-19 12:39:57 +01:00 committed by GitHub
parent 0c0a8ab19f
commit ab6e44f430
6 changed files with 263 additions and 29 deletions


@@ -311,7 +311,7 @@ def ensure_es_client(
     if isinstance(es_client, tuple):
         es_client = list(es_client)
     if not isinstance(es_client, Elasticsearch):
-        es_client = Elasticsearch(es_client)  # type: ignore[arg-type]
+        es_client = Elasticsearch(es_client)
     return es_client
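
For context, ensure_es_client normalizes whatever the caller passes (an existing client, a hosts string, or a list/tuple of hosts) into an Elasticsearch instance; the hunk above only drops a now-unneeded mypy suppression. A minimal usage sketch, where the import path and cluster URL are assumptions for illustration:

    from elasticsearch import Elasticsearch

    from eland.common import ensure_es_client  # assumed import path

    # Each call should come back as an Elasticsearch instance.
    es_from_url = ensure_es_client("http://localhost:9200")
    es_from_tuple = ensure_es_client(("http://localhost:9200",))
    es_passthrough = ensure_es_client(Elasticsearch("http://localhost:9200"))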


@@ -17,10 +17,17 @@
 from eland.ml.pytorch._pytorch_model import PyTorchModel  # noqa: F401
 from eland.ml.pytorch.nlp_ml_model import (
+    FillMaskInferenceOptions,
+    NerInferenceOptions,
     NlpBertTokenizationConfig,
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
     NlpTrainedModelConfig,
+    QuestionAnsweringInferenceOptions,
+    TextClassificationInferenceOptions,
+    TextEmbeddingInferenceOptions,
+    TextSimilarityInferenceOptions,
+    ZeroShotClassificationInferenceOptions,
 )
 from eland.ml.pytorch.traceable_model import TraceableModel  # noqa: F401
 from eland.ml.pytorch.transformers import task_type_from_model_config
@@ -28,10 +35,17 @@ from eland.ml.pytorch.transformers import task_type_from_model_config
 __all__ = [
     "PyTorchModel",
     "TraceableModel",
+    "FillMaskInferenceOptions",
+    "NerInferenceOptions",
     "NlpTrainedModelConfig",
     "NlpBertTokenizationConfig",
     "NlpRobertaTokenizationConfig",
     "NlpXLMRobertaTokenizationConfig",
     "NlpMPNetTokenizationConfig",
+    "QuestionAnsweringInferenceOptions",
+    "TextClassificationInferenceOptions",
+    "TextEmbeddingInferenceOptions",
+    "TextSimilarityInferenceOptions",
+    "ZeroShotClassificationInferenceOptions",
     "task_type_from_model_config",
 ]
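
With these exports in place, callers can import the inference-option and tokenization classes from the package root instead of the private nlp_ml_model module; a small sketch of what the widened public surface allows:

    # Both import styles now resolve to the same classes.
    from eland.ml.pytorch import (
        TextEmbeddingInferenceOptions,
        ZeroShotClassificationInferenceOptions,
    )
    from eland.ml.pytorch.nlp_ml_model import (
        TextEmbeddingInferenceOptions as PrivatePath,
    )

    assert PrivatePath is TextEmbeddingInferenceOptions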


@@ -317,11 +317,9 @@ class NlpTrainedModelConfig:
         input: TrainedModelInput = TrainedModelInput(field_names=["text_field"]),
         metadata: t.Optional[dict] = None,
         model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
-        default_field_map: t.Optional[t.Mapping[str, str]] = None,
         tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
     ):
         self.tags = tags
-        self.default_field_map = default_field_map
         self.description = description
         self.inference_config = inference_config
         self.input = input
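
Since default_field_map is gone from the constructor, any caller still passing it will now get a TypeError; configs are built from the remaining keyword arguments. A hypothetical construction under the new signature (inference_options here stands in for any of the *InferenceOptions instances exported above, built elsewhere):

    # Hypothetical example; inference_options is assumed to be, e.g.,
    # a TextEmbeddingInferenceOptions instance created beforehand.
    config = NlpTrainedModelConfig(
        description="example text-embedding model",
        inference_config=inference_options,
        tags=["example"],
    )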


@@ -664,27 +664,23 @@ class TransformerModel:
         return vocab_obj

     def _create_tokenization_config(self) -> NlpTokenizationConfig:
+        _max_sequence_length = self._find_max_sequence_length()
         if isinstance(self._tokenizer, transformers.MPNetTokenizer):
             return NlpMPNetTokenizationConfig(
                 do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length,
             )
         elif isinstance(
             self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
         ):
             return NlpRobertaTokenizationConfig(
                 add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length,
             )
         elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
             return NlpXLMRobertaTokenizationConfig(
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
+                max_sequence_length=_max_sequence_length
             )
         else:
             japanese_morphological_tokenizers = ["mecab"]
@@ -695,18 +691,38 @@ class TransformerModel:
             ):
                 return NlpBertJapaneseTokenizationConfig(
                     do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                    max_sequence_length=getattr(
-                        self._tokenizer, "max_model_input_sizes", dict()
-                    ).get(self._model_id),
+                    max_sequence_length=_max_sequence_length,
                 )
             else:
                 return NlpBertTokenizationConfig(
                     do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                    max_sequence_length=getattr(
-                        self._tokenizer, "max_model_input_sizes", dict()
-                    ).get(self._model_id),
+                    max_sequence_length=_max_sequence_length,
                 )

+    def _find_max_sequence_length(self) -> int:
+        # Sometimes the max_... values are present but contain
+        # a random or very large value.
+        REASONABLE_MAX_LENGTH = 8192
+
+        max_len = getattr(self._tokenizer, "max_model_input_sizes", dict()).get(
+            self._model_id
+        )
+        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+            return int(max_len)
+
+        max_len = getattr(self._tokenizer, "model_max_length", None)
+        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+            return int(max_len)
+
+        model_config = getattr(self._traceable_model._model, "config", None)
+        if model_config is None:
+            raise ValueError("Cannot determine model max input length")
+
+        max_len = getattr(model_config, "max_position_embeddings", None)
+        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+            return int(max_len)
+
+        raise ValueError("Cannot determine model max input length")
+
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
     ) -> NlpTrainedModelConfig:
@@ -756,7 +772,7 @@ class TransformerModel:
             ),
         )

-    def _create_traceable_model(self) -> TraceableModel:
+    def _create_traceable_model(self) -> _TransformerTraceableModel:
         if self._task_type == "auto":
             model = transformers.AutoModel.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
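
The new _find_max_sequence_length helper centralizes a lookup that was previously repeated in every tokenizer branch, and adds two fallbacks: it tries the tokenizer's max_model_input_sizes table first, then model_max_length, then the model config's max_position_embeddings, treating any value of 8192 or more as an unreliable sentinel (Hugging Face tokenizers often report a huge placeholder there). A standalone sketch of the same fallback chain against a plain transformers tokenizer; the model id and function name are illustrative, not part of the commit:

    from transformers import AutoConfig, AutoTokenizer

    REASONABLE_MAX_LENGTH = 8192

    def find_max_sequence_length(model_id: str) -> int:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # 1. Per-model table shipped with some tokenizer classes.
        max_len = getattr(tokenizer, "max_model_input_sizes", dict()).get(model_id)
        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
            return int(max_len)
        # 2. Generic attribute; may be a very large placeholder value.
        max_len = getattr(tokenizer, "model_max_length", None)
        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
            return int(max_len)
        # 3. Fall back to the model configuration itself.
        config = AutoConfig.from_pretrained(model_id)
        max_len = getattr(config, "max_position_embeddings", None)
        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
            return int(max_len)
        raise ValueError("Cannot determine model max input length")

    print(find_max_sequence_length("bert-base-uncased"))  # expected: 512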


@@ -0,0 +1,216 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tempfile

import pytest

try:
    import sklearn  # noqa: F401

    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False

try:
    from eland.ml.pytorch.transformers import TransformerModel

    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False

try:
    import torch  # noqa: F401

    from eland.ml.pytorch import (
        FillMaskInferenceOptions,
        NerInferenceOptions,
        NlpBertTokenizationConfig,
        NlpMPNetTokenizationConfig,
        NlpRobertaTokenizationConfig,
        QuestionAnsweringInferenceOptions,
        TextClassificationInferenceOptions,
        TextEmbeddingInferenceOptions,
        TextSimilarityInferenceOptions,
        ZeroShotClassificationInferenceOptions,
    )

    HAS_PYTORCH = True
except ImportError:
    HAS_PYTORCH = False

from tests import ES_VERSION

pytestmark = [
    pytest.mark.skipif(
        ES_VERSION < (8, 7, 0),
        reason="Eland uses PyTorch 1.13.1, versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
    ),
    pytest.mark.skipif(
        not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
    ),
    pytest.mark.skipif(
        not HAS_TRANSFORMERS, reason="This test requires 'transformers' package to run"
    ),
    pytest.mark.skipif(
        not HAS_PYTORCH, reason="This test requires 'torch' package to run"
    ),
]

# If the required imports are missing, the tests will be skipped.
# Only define the test configurations if the referenced classes
# have been imported.
if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
    MODEL_CONFIGURATIONS = [
        (
            "intfloat/e5-small-v2",
            "text_embedding",
            TextEmbeddingInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            384,
        ),
        (
            "sentence-transformers/all-mpnet-base-v2",
            "text_embedding",
            TextEmbeddingInferenceOptions,
            NlpMPNetTokenizationConfig,
            512,
            768,
        ),
        (
            "sentence-transformers/all-MiniLM-L12-v2",
            "text_embedding",
            TextEmbeddingInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            384,
        ),
        (
            "facebook/dpr-ctx_encoder-multiset-base",
            "text_embedding",
            TextEmbeddingInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            768,
        ),
        (
            "distilbert-base-uncased",
            "fill_mask",
            FillMaskInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            None,
        ),
        (
            "bert-base-uncased",
            "fill_mask",
            FillMaskInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            None,
        ),
        (
            "elastic/distilbert-base-uncased-finetuned-conll03-english",
            "ner",
            NerInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            None,
        ),
        (
            "SamLowe/roberta-base-go_emotions",
            "text_classification",
            TextClassificationInferenceOptions,
            NlpRobertaTokenizationConfig,
            512,
            None,
        ),
        (
            "distilbert-base-cased-distilled-squad",
            "question_answering",
            QuestionAnsweringInferenceOptions,
            NlpBertTokenizationConfig,
            386,
            None,
        ),
        (
            "cross-encoder/ms-marco-TinyBERT-L-2-v2",
            "text_similarity",
            TextSimilarityInferenceOptions,
            NlpBertTokenizationConfig,
            512,
            None,
        ),
        (
            "valhalla/distilbart-mnli-12-6",
            "zero_shot_classification",
            ZeroShotClassificationInferenceOptions,
            NlpRobertaTokenizationConfig,
            1024,
            None,
        ),
    ]
else:
    MODEL_CONFIGURATIONS = []
class TestModelConfiguration:
    @pytest.mark.parametrize(
        "model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
        MODEL_CONFIGURATIONS,
    )
    def test_text_prediction(
        self,
        model_id,
        task_type,
        config_type,
        tokenizer_type,
        max_sequence_len,
        embedding_size,
    ):
        with tempfile.TemporaryDirectory() as tmp_dir:
            print("loading model " + model_id)
            tm = TransformerModel(
                model_id=model_id,
                task_type=task_type,
                es_version=ES_VERSION,
                quantize=False,
            )
            _, config, _ = tm.save(tmp_dir)

            assert "pytorch" == config.model_type
            assert ["text_field"] == config.input.field_names
            assert isinstance(config.inference_config, config_type)
            tokenization = config.inference_config.tokenization
            assert isinstance(tokenization, tokenizer_type)
            assert max_sequence_len == tokenization.max_sequence_length

            if task_type == "text_classification":
                assert isinstance(config.inference_config.classification_labels, list)
                assert len(config.inference_config.classification_labels) > 0

            if task_type == "text_embedding":
                assert embedding_size == config.inference_config.embedding_size

            if task_type == "question_answering":
                assert tokenization.truncate == "none"
                assert tokenization.span > 0

            if task_type == "zero_shot_classification":
                assert isinstance(config.inference_config.classification_labels, list)
                assert len(config.inference_config.classification_labels) > 0
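
Outside pytest, the same check can be run ad hoc for a single model; a minimal sketch that mirrors one parametrized case, using one row from the table above (the es_version tuple is an assumed cluster version for illustration):

    import tempfile

    from eland.ml.pytorch import (
        NlpBertTokenizationConfig,
        TextEmbeddingInferenceOptions,
    )
    from eland.ml.pytorch.transformers import TransformerModel

    with tempfile.TemporaryDirectory() as tmp_dir:
        tm = TransformerModel(
            model_id="intfloat/e5-small-v2",
            task_type="text_embedding",
            es_version=(8, 10, 0),  # assumed version, not from the commit
            quantize=False,
        )
        # save() traces the model and returns the generated Elasticsearch
        # config as the second element of the tuple.
        _, config, _ = tm.save(tmp_dir)
        assert isinstance(config.inference_config, TextEmbeddingInferenceOptions)
        assert isinstance(config.inference_config.tokenization, NlpBertTokenizationConfig)
        assert config.inference_config.tokenization.max_sequence_length == 512
        assert config.inference_config.embedding_size == 384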


@@ -24,13 +24,6 @@ import numpy as np
 import pytest
 from elasticsearch import NotFoundError

-try:
-    import sklearn  # noqa: F401
-
-    HAS_SKLEARN = True
-except ImportError:
-    HAS_SKLEARN = False
-
 try:
     import torch  # noqa: F401
     from torch import Tensor, nn  # noqa: F401
@@ -67,9 +60,6 @@ pytestmark = [
         ES_VERSION < (8, 0, 0),
         reason="This test requires at least Elasticsearch version 8.0.0",
     ),
-    pytest.mark.skipif(
-        not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
-    ),
     pytest.mark.skipif(
         not HAS_PYTORCH, reason="This test requires 'pytorch' package to run"
     ),