Mirror of https://github.com/elastic/eland.git,
synced 2025-07-11 00:02:14 +08:00.
[ML] Add inference results tests for PyTorch transformer models
This commit is contained in:
parent
66e3e4eaad
commit
a3b0907c5b
@ -24,7 +24,7 @@
|
||||
- inject:
|
||||
properties-content: HOME=$JENKINS_HOME
|
||||
concurrent: true
|
||||
node: ubuntu
|
||||
node: flyweight
|
||||
scm:
|
||||
- git:
|
||||
name: origin
|
||||
|
@ -121,6 +121,7 @@ if [[ "$ELASTICSEARCH_VERSION" != *oss* ]]; then
|
||||
--env xpack.security.enabled=false
|
||||
--env xpack.security.http.ssl.enabled=false
|
||||
--env xpack.security.transport.ssl.enabled=false
|
||||
--env xpack.ml.max_machine_memory_percent=90
|
||||
END
|
||||
))
|
||||
fi
|
||||
|
@ -96,8 +96,8 @@ class PyTorchModel:
|
||||
) -> None:
|
||||
# TODO: Implement some pre-flight checks on config, vocab, and model
|
||||
self.put_config(config_path)
|
||||
self.put_vocab(vocab_path)
|
||||
self.put_model(model_path, chunk_size)
|
||||
self.put_vocab(vocab_path)
|
||||
|
||||
def infer(
|
||||
self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT
|
||||
@ -106,14 +106,14 @@ class PyTorchModel:
|
||||
"POST",
|
||||
f"/_ml/trained_models/{self.model_id}/deployment/_infer",
|
||||
body=body,
|
||||
params={"timeout": timeout},
|
||||
params={"timeout": timeout, "request_timeout": 60},
|
||||
)
|
||||
|
||||
def start(self, timeout: str = DEFAULT_TIMEOUT) -> None:
|
||||
self._client.transport.perform_request(
|
||||
"POST",
|
||||
f"/_ml/trained_models/{self.model_id}/deployment/_start",
|
||||
params={"timeout": timeout, "wait_for": "started"},
|
||||
params={"timeout": timeout, "request_timeout": 60, "wait_for": "started"},
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
|
94
tests/ml/pytorch/test_pytorch_model_pytest.py
Normal file
94
tests/ml/pytorch/test_pytorch_model_pytest.py
Normal file
@ -0,0 +1,94 @@
|
||||
# Licensed to Elasticsearch B.V. under one or more contributor
|
||||
# license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright
|
||||
# ownership. Elasticsearch B.V. licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import sklearn # noqa: F401
|
||||
|
||||
HAS_SKLEARN = True
|
||||
except ImportError:
|
||||
HAS_SKLEARN = False
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
from eland.ml.pytorch import PyTorchModel
|
||||
from eland.ml.pytorch.transformers import TransformerModel
|
||||
|
||||
HAS_TRANSFORMERS = True
|
||||
except ImportError:
|
||||
HAS_TRANSFORMERS = False
|
||||
|
||||
from tests import ES_TEST_CLIENT, ES_VERSION
|
||||
|
||||
# Module-wide skip conditions: every test here needs an Elasticsearch 8.x
# cluster plus the optional 'scikit-learn' and 'transformers' packages.
_SKIP_CONDITIONS = [
    (
        ES_VERSION < (8, 0, 0),
        "This test requires at least Elasticsearch version 8.0.0",
    ),
    (not HAS_SKLEARN, "This test requires 'scikit-learn' package to run"),
    (not HAS_TRANSFORMERS, "This test requires 'transformers' package to run"),
]

pytestmark = [
    pytest.mark.skipif(condition, reason=reason)
    for condition, reason in _SKIP_CONDITIONS
]
|
||||
|
||||
# Each entry is (hub model id, task name, input text, expected top prediction);
# these tuples drive the parametrized inference tests below.
TEXT_PREDICTION_MODELS = [
    (
        "distilbert-base-uncased",
        "fill_mask",
        "[MASK] is the capital of France.",
        "paris",
    )
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
def setup_and_tear_down():
    """Raise ML logging before each test and remove deployed models after.

    Runs automatically around every test in this module: first bumps the ML
    logger to DEBUG so deployment failures are diagnosable from cluster logs,
    then, after the test, stops and deletes each model listed in
    TEXT_PREDICTION_MODELS.
    """
    ES_TEST_CLIENT.cluster.put_settings(
        body={"transient": {"logger.org.elasticsearch.xpack.ml": "DEBUG"}}
    )
    yield
    for entry in TEXT_PREDICTION_MODELS:
        hub_id = entry[0]
        # Normalize the hub id (slashes -> "__", lowercase, 64-char cap);
        # NOTE(review): presumably this mirrors
        # TransformerModel.elasticsearch_model_id() — confirm they stay in sync.
        es_model_id = hub_id.replace("/", "__").lower()[:64]
        deployed = PyTorchModel(ES_TEST_CLIENT, es_model_id)
        deployed.stop()
        deployed.delete()
|
||||
|
||||
|
||||
def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
    """Fetch a HuggingFace model, import it into Elasticsearch, and deploy it.

    Saves the traced model, its config, and its vocab under *tmp_dir*, uploads
    them under the Elasticsearch-normalized model id, starts the deployment,
    and returns the PyTorchModel handle.
    """
    print("Loading HuggingFace transformer tokenizer and model")
    transformer = TransformerModel(model_id, task, quantize)
    model_path, config_path, vocab_path = transformer.save(tmp_dir)
    pytorch_model = PyTorchModel(ES_TEST_CLIENT, transformer.elasticsearch_model_id())
    # Best-effort cleanup of any deployment left over from a previous run.
    # NOTE(review): it is unclear whether stop()/delete() tolerate a missing
    # model — confirm they do not raise on a fresh cluster.
    pytorch_model.stop()
    pytorch_model.delete()
    print(f"Importing model: {pytorch_model.model_id}")
    pytorch_model.import_model(model_path, config_path, vocab_path)
    pytorch_model.start()
    return pytorch_model
|
||||
|
||||
|
||||
class TestPytorchModel:
    """End-to-end inference checks for imported PyTorch transformer models."""

    @pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
    def test_text_classification(self, model_id, task, text_input, value):
        # Use a throwaway directory so traced-model artifacts are cleaned up
        # even when the deployment or assertion fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            deployed = download_model_and_start_deployment(tmp_dir, True, model_id, task)
            request_body = {"docs": [{"text_field": text_input}]}
            result = deployed.infer(request_body)
            assert result["predicted_value"] == value
|
Loading…
x
Reference in New Issue
Block a user