mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix direct usage of TransformerModel (#619)
This commit is contained in:
parent
48e290a927
commit
c6ce4b2c46
@ -584,10 +584,10 @@ class TransformerModel:
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
access_token: Optional[str],
|
|
||||||
task_type: str,
|
task_type: str,
|
||||||
es_version: Optional[Tuple[int, int, int]] = None,
|
es_version: Optional[Tuple[int, int, int]] = None,
|
||||||
quantize: bool = False,
|
quantize: bool = False,
|
||||||
|
access_token: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Loads a model from the Hugging Face repository or local file and creates
|
Loads a model from the Hugging Face repository or local file and creates
|
||||||
|
@ -112,6 +112,8 @@ def test(session, pandas_version: str):
|
|||||||
"python",
|
"python",
|
||||||
"-m",
|
"-m",
|
||||||
"pytest",
|
"pytest",
|
||||||
|
"-ra",
|
||||||
|
"--tb=native",
|
||||||
"--cov-report=term-missing",
|
"--cov-report=term-missing",
|
||||||
"--cov=eland/",
|
"--cov=eland/",
|
||||||
"--cov-config=setup.cfg",
|
"--cov-config=setup.cfg",
|
||||||
|
@ -22,7 +22,7 @@ lightgbm>=2,<4
|
|||||||
torch>=1.13.1,<2.0; python_version<'3.11'
|
torch>=1.13.1,<2.0; python_version<'3.11'
|
||||||
# Versions known to be compatible with PyTorch 1.13.1
|
# Versions known to be compatible with PyTorch 1.13.1
|
||||||
sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
|
sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
|
||||||
transformers[torch]>=4.12.0,<=4.27.4; python_version<'3.11'
|
transformers[torch]>=4.31.0,<=4.33.2; python_version<'3.11'
|
||||||
|
|
||||||
#
|
#
|
||||||
# Testing
|
# Testing
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
import platform
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -82,6 +83,14 @@ def setup_and_tear_down():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def quantize():
|
||||||
|
# quantization does not work on ARM processors
|
||||||
|
# TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
|
||||||
|
# revisit this when we upgrade to PyTorch 2.0.
|
||||||
|
return platform.machine() not in ["arm64", "aarch64"]
|
||||||
|
|
||||||
|
|
||||||
def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
|
def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
|
||||||
print("Loading HuggingFace transformer tokenizer and model")
|
print("Loading HuggingFace transformer tokenizer and model")
|
||||||
tm = TransformerModel(
|
tm = TransformerModel(
|
||||||
@ -103,31 +112,17 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
|
|||||||
|
|
||||||
|
|
||||||
class TestPytorchModel:
|
class TestPytorchModel:
|
||||||
def __init__(self):
|
|
||||||
# quantization does not work on ARM processors
|
|
||||||
# TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
|
|
||||||
# revisit this when we upgrade to PyTorch 2.0.
|
|
||||||
import platform
|
|
||||||
|
|
||||||
self.quantize = (
|
|
||||||
True if platform.machine() not in ["arm64", "aarch64"] else False
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
|
@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
|
||||||
def test_text_prediction(self, model_id, task, text_input, value):
|
def test_text_prediction(self, model_id, task, text_input, value, quantize):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
ptm = download_model_and_start_deployment(
|
ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
|
||||||
tmp_dir, self.quantize, model_id, task
|
results = ptm.infer(docs=[{"text_field": text_input}])
|
||||||
)
|
assert results.body["inference_results"][0]["predicted_value"] == value
|
||||||
result = ptm.infer(docs=[{"text_field": text_input}])
|
|
||||||
assert result["predicted_value"] == value
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
|
@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
|
||||||
def test_text_embedding(self, model_id, task, text_input):
|
def test_text_embedding(self, model_id, task, text_input, quantize):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
ptm = download_model_and_start_deployment(
|
ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
|
||||||
tmp_dir, self.quantize, model_id, task
|
|
||||||
)
|
|
||||||
ptm.infer(docs=[{"text_field": text_input}])
|
ptm.infer(docs=[{"text_field": text_input}])
|
||||||
|
|
||||||
if ES_VERSION >= (8, 8, 0):
|
if ES_VERSION >= (8, 8, 0):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user