mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix direct usage of TransformerModel (#619)
This commit is contained in:
parent
48e290a927
commit
c6ce4b2c46
@ -584,10 +584,10 @@ class TransformerModel:
|
||||
self,
|
||||
*,
|
||||
model_id: str,
|
||||
access_token: Optional[str],
|
||||
task_type: str,
|
||||
es_version: Optional[Tuple[int, int, int]] = None,
|
||||
quantize: bool = False,
|
||||
access_token: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Loads a model from the Hugging Face repository or local file and creates
|
||||
|
@ -112,6 +112,8 @@ def test(session, pandas_version: str):
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
"-ra",
|
||||
"--tb=native",
|
||||
"--cov-report=term-missing",
|
||||
"--cov=eland/",
|
||||
"--cov-config=setup.cfg",
|
||||
|
@ -22,7 +22,7 @@ lightgbm>=2,<4
|
||||
torch>=1.13.1,<2.0; python_version<'3.11'
|
||||
# Versions known to be compatible with PyTorch 1.13.1
|
||||
sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
|
||||
transformers[torch]>=4.12.0,<=4.27.4; python_version<'3.11'
|
||||
transformers[torch]>=4.31.0,<=4.33.2; python_version<'3.11'
|
||||
|
||||
#
|
||||
# Testing
|
||||
|
@ -14,6 +14,7 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import platform
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
@ -82,6 +83,14 @@ def setup_and_tear_down():
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def quantize():
|
||||
# quantization does not work on ARM processors
|
||||
# TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
|
||||
# revisit this when we upgrade to PyTorch 2.0.
|
||||
return platform.machine() not in ["arm64", "aarch64"]
|
||||
|
||||
|
||||
def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
|
||||
print("Loading HuggingFace transformer tokenizer and model")
|
||||
tm = TransformerModel(
|
||||
@ -103,31 +112,17 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
|
||||
|
||||
|
||||
class TestPytorchModel:
|
||||
def __init__(self):
|
||||
# quantization does not work on ARM processors
|
||||
# TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
|
||||
# revisit this when we upgrade to PyTorch 2.0.
|
||||
import platform
|
||||
|
||||
self.quantize = (
|
||||
True if platform.machine() not in ["arm64", "aarch64"] else False
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
|
||||
def test_text_prediction(self, model_id, task, text_input, value):
|
||||
def test_text_prediction(self, model_id, task, text_input, value, quantize):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
ptm = download_model_and_start_deployment(
|
||||
tmp_dir, self.quantize, model_id, task
|
||||
)
|
||||
result = ptm.infer(docs=[{"text_field": text_input}])
|
||||
assert result["predicted_value"] == value
|
||||
ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
|
||||
results = ptm.infer(docs=[{"text_field": text_input}])
|
||||
assert results.body["inference_results"][0]["predicted_value"] == value
|
||||
|
||||
@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
|
||||
def test_text_embedding(self, model_id, task, text_input):
|
||||
def test_text_embedding(self, model_id, task, text_input, quantize):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
ptm = download_model_and_start_deployment(
|
||||
tmp_dir, self.quantize, model_id, task
|
||||
)
|
||||
ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
|
||||
ptm.infer(docs=[{"text_field": text_input}])
|
||||
|
||||
if ES_VERSION >= (8, 8, 0):
|
||||
|
Loading…
x
Reference in New Issue
Block a user