mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] Add inference results tests for PyTorch transformer models
This commit is contained in:
parent
66e3e4eaad
commit
a3b0907c5b
@ -24,7 +24,7 @@
|
|||||||
- inject:
|
- inject:
|
||||||
properties-content: HOME=$JENKINS_HOME
|
properties-content: HOME=$JENKINS_HOME
|
||||||
concurrent: true
|
concurrent: true
|
||||||
node: ubuntu
|
node: flyweight
|
||||||
scm:
|
scm:
|
||||||
- git:
|
- git:
|
||||||
name: origin
|
name: origin
|
||||||
|
@ -121,6 +121,7 @@ if [[ "$ELASTICSEARCH_VERSION" != *oss* ]]; then
|
|||||||
--env xpack.security.enabled=false
|
--env xpack.security.enabled=false
|
||||||
--env xpack.security.http.ssl.enabled=false
|
--env xpack.security.http.ssl.enabled=false
|
||||||
--env xpack.security.transport.ssl.enabled=false
|
--env xpack.security.transport.ssl.enabled=false
|
||||||
|
--env xpack.ml.max_machine_memory_percent=90
|
||||||
END
|
END
|
||||||
))
|
))
|
||||||
fi
|
fi
|
||||||
|
@ -96,8 +96,8 @@ class PyTorchModel:
|
|||||||
) -> None:
|
) -> None:
|
||||||
# TODO: Implement some pre-flight checks on config, vocab, and model
|
# TODO: Implement some pre-flight checks on config, vocab, and model
|
||||||
self.put_config(config_path)
|
self.put_config(config_path)
|
||||||
self.put_vocab(vocab_path)
|
|
||||||
self.put_model(model_path, chunk_size)
|
self.put_model(model_path, chunk_size)
|
||||||
|
self.put_vocab(vocab_path)
|
||||||
|
|
||||||
def infer(
|
def infer(
|
||||||
self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT
|
self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT
|
||||||
@ -106,14 +106,14 @@ class PyTorchModel:
|
|||||||
"POST",
|
"POST",
|
||||||
f"/_ml/trained_models/{self.model_id}/deployment/_infer",
|
f"/_ml/trained_models/{self.model_id}/deployment/_infer",
|
||||||
body=body,
|
body=body,
|
||||||
params={"timeout": timeout},
|
params={"timeout": timeout, "request_timeout": 60},
|
||||||
)
|
)
|
||||||
|
|
||||||
def start(self, timeout: str = DEFAULT_TIMEOUT) -> None:
|
def start(self, timeout: str = DEFAULT_TIMEOUT) -> None:
|
||||||
self._client.transport.perform_request(
|
self._client.transport.perform_request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/_ml/trained_models/{self.model_id}/deployment/_start",
|
f"/_ml/trained_models/{self.model_id}/deployment/_start",
|
||||||
params={"timeout": timeout, "wait_for": "started"},
|
params={"timeout": timeout, "request_timeout": 60, "wait_for": "started"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
|
94
tests/ml/pytorch/test_pytorch_model_pytest.py
Normal file
94
tests/ml/pytorch/test_pytorch_model_pytest.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
# Licensed to Elasticsearch B.V. under one or more contributor
|
||||||
|
# license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright
|
||||||
|
# ownership. Elasticsearch B.V. licenses this file to you under
|
||||||
|
# the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sklearn # noqa: F401
|
||||||
|
|
||||||
|
HAS_SKLEARN = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_SKLEARN = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import transformers # noqa: F401
|
||||||
|
|
||||||
|
from eland.ml.pytorch import PyTorchModel
|
||||||
|
from eland.ml.pytorch.transformers import TransformerModel
|
||||||
|
|
||||||
|
HAS_TRANSFORMERS = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TRANSFORMERS = False
|
||||||
|
|
||||||
|
from tests import ES_TEST_CLIENT, ES_VERSION
|
||||||
|
|
||||||
|
# Module-wide skip conditions applied to every test in this file: the
# Elasticsearch version under test must be at least 8.0.0 and both optional
# packages probed above (scikit-learn, transformers) must have imported.
pytestmark = [
    pytest.mark.skipif(
        ES_VERSION < (8, 0, 0),
        reason="This test requires at least Elasticsearch version 8.0.0",
    ),
    pytest.mark.skipif(
        not HAS_SKLEARN,
        reason="This test requires 'scikit-learn' package to run",
    ),
    pytest.mark.skipif(
        not HAS_TRANSFORMERS,
        reason="This test requires 'transformers' package to run",
    ),
]

# One tuple per parametrized case, laid out as
# (model_id, task, text_input, expected predicted value).
TEXT_PREDICTION_MODELS = [
    (
        "distilbert-base-uncased",
        "fill_mask",
        "[MASK] is the capital of France.",
        "paris",
    ),
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
def setup_and_tear_down():
    """Enable ML debug logging before each test and remove test models after.

    Setup: sets the transient cluster setting
    ``logger.org.elasticsearch.xpack.ml`` to ``DEBUG``.

    Teardown: for every entry in ``TEXT_PREDICTION_MODELS``, stops and
    deletes the model whose Elasticsearch id is derived from the entry's
    model id.
    """
    ml_debug_logging = {"transient": {"logger.org.elasticsearch.xpack.ml": "DEBUG"}}
    ES_TEST_CLIENT.cluster.put_settings(body=ml_debug_logging)

    yield

    for hf_model_id, _task, _text_input, _expected in TEXT_PREDICTION_MODELS:
        # "/" -> "__", lower-cased, capped at 64 characters.
        # NOTE(review): presumably mirrors TransformerModel.elasticsearch_model_id()
        # -- confirm the two stay in sync.
        es_model_id = hf_model_id.replace("/", "__").lower()[:64]
        deployed = PyTorchModel(ES_TEST_CLIENT, es_model_id)
        deployed.stop()
        deployed.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
    """Export a transformer model, import it into Elasticsearch and start it.

    Args:
        tmp_dir: Directory handed to ``TransformerModel.save`` for the
            exported files.
        quantize: Passed through as the third ``TransformerModel`` argument.
        model_id: Model identifier given to ``TransformerModel``.
        task: Task name given to ``TransformerModel``.

    Returns:
        The ``PyTorchModel`` wrapper for the imported model, returned after
        ``start()`` has been called on it.
    """
    print("Loading HuggingFace transformer tokenizer and model")
    transformer = TransformerModel(model_id, task, quantize)
    saved_model, saved_config, saved_vocab = transformer.save(tmp_dir)

    es_model = PyTorchModel(ES_TEST_CLIENT, transformer.elasticsearch_model_id())

    # Stop and delete before importing.
    # NOTE(review): presumably clears a model left over from an earlier run and
    # relies on stop()/delete() tolerating a missing model -- confirm.
    es_model.stop()
    es_model.delete()

    print(f"Importing model: {es_model.model_id}")
    es_model.import_model(saved_model, saved_config, saved_vocab)
    es_model.start()
    return es_model
|
||||||
|
|
||||||
|
|
||||||
|
class TestPytorchModel:
    """Inference tests run against models deployed to the test cluster."""

    @pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
    def test_text_classification(self, model_id, task, text_input, value):
        """Deploy the model, infer on ``text_input`` and compare the prediction.

        Runs once per tuple in ``TEXT_PREDICTION_MODELS``; the model is
        exported with quantization enabled into a temporary directory.
        """
        with tempfile.TemporaryDirectory() as export_dir:
            deployed = download_model_and_start_deployment(
                export_dir, True, model_id, task
            )
            request_body = {"docs": [{"text_field": text_input}]}
            response = deployed.infer(request_body)
            assert response["predicted_value"] == value
|
Loading…
x
Reference in New Issue
Block a user