Fix direct usage of TransformerModel (#619)

Quentin Pradet 2023-10-11 13:56:14 +04:00 committed by GitHub
parent 48e290a927
commit c6ce4b2c46
4 changed files with 19 additions and 22 deletions


@@ -584,10 +584,10 @@ class TransformerModel:
         self,
         *,
         model_id: str,
-        access_token: Optional[str],
         task_type: str,
         es_version: Optional[Tuple[int, int, int]] = None,
         quantize: bool = False,
+        access_token: Optional[str] = None,
     ):
         """
         Loads a model from the Hugging Face repository or local file and creates
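
With access_token moved behind the existing parameters and given a default of None, code that instantiates TransformerModel directly no longer breaks when it omits the token. A minimal sketch of such direct usage, assuming the eland.ml.pytorch.transformers import path and an illustrative model id that is not taken from this diff:

    from eland.ml.pytorch.transformers import TransformerModel

    # access_token now defaults to None, so pre-existing direct callers
    # need no change. The model id below is illustrative only.
    tm = TransformerModel(
        model_id="sentence-transformers/msmarco-MiniLM-L-12-v3",
        task_type="text_embedding",
        quantize=False,
    )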


@@ -112,6 +112,8 @@ def test(session, pandas_version: str):
         "python",
         "-m",
         "pytest",
+        "-ra",
+        "--tb=native",
         "--cov-report=term-missing",
         "--cov=eland/",
         "--cov-config=setup.cfg",

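For context on the new pytest flags: -ra appends a short summary of all non-passed outcomes (failures, errors, skips, xfails) to the end of the run report, and --tb=native renders tracebacks in the standard Python interpreter format, which tends to be easier to read in CI logs.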

@@ -22,7 +22,7 @@ lightgbm>=2,<4
 torch>=1.13.1,<2.0; python_version<'3.11'
 # Versions known to be compatible with PyTorch 1.13.1
 sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
-transformers[torch]>=4.12.0,<=4.27.4; python_version<'3.11'
+transformers[torch]>=4.31.0,<=4.33.2; python_version<'3.11'
 
 #
 # Testing


@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import platform
 import tempfile
 
 import pytest
@@ -82,6 +83,14 @@ def setup_and_tear_down():
     pass
 
 
+@pytest.fixture(scope="session")
+def quantize():
+    # quantization does not work on ARM processors
+    # TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
+    # revisit this when we upgrade to PyTorch 2.0.
+    return platform.machine() not in ["arm64", "aarch64"]
+
+
 def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
     print("Loading HuggingFace transformer tokenizer and model")
     tm = TransformerModel(
@@ -103,31 +112,17 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
 
 
 class TestPytorchModel:
-    def __init__(self):
-        # quantization does not work on ARM processors
-        # TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
-        # revisit this when we upgrade to PyTorch 2.0.
-        import platform
-
-        self.quantize = (
-            True if platform.machine() not in ["arm64", "aarch64"] else False
-        )
-
     @pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
-    def test_text_prediction(self, model_id, task, text_input, value):
+    def test_text_prediction(self, model_id, task, text_input, value, quantize):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            ptm = download_model_and_start_deployment(
-                tmp_dir, self.quantize, model_id, task
-            )
-            result = ptm.infer(docs=[{"text_field": text_input}])
-            assert result["predicted_value"] == value
+            ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
+            results = ptm.infer(docs=[{"text_field": text_input}])
+            assert results.body["inference_results"][0]["predicted_value"] == value
 
     @pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
-    def test_text_embedding(self, model_id, task, text_input):
+    def test_text_embedding(self, model_id, task, text_input, quantize):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            ptm = download_model_and_start_deployment(
-                tmp_dir, self.quantize, model_id, task
-            )
+            ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
             ptm.infer(docs=[{"text_field": text_input}])
 
         if ES_VERSION >= (8, 8, 0):
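
The test refactor replaces TestPytorchModel.__init__ with a session-scoped fixture. This matters beyond style: pytest does not collect test classes that define __init__, so computing quantize in a constructor would keep these tests from running at all. A minimal standalone sketch of the fixture pattern, with illustrative test names that are not part of eland:

    import platform

    import pytest


    @pytest.fixture(scope="session")
    def quantize():
        # Evaluated once per test session; pytest injects the value into
        # any test that declares a "quantize" parameter.
        return platform.machine() not in ["arm64", "aarch64"]


    def test_quantize_flag(quantize):
        # True on x86_64, False on arm64/aarch64 machines.
        assert isinstance(quantize, bool)

The updated assertion also reads the prediction through results.body["inference_results"][0]["predicted_value"], which suggests infer here returns an elasticsearch-py 8.x response object whose dict payload is exposed via .body, rather than a bare dict.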