Fix direct usage of TransformerModel (#619)

Quentin Pradet 2023-10-11 13:56:14 +04:00 committed by GitHub
parent 48e290a927
commit c6ce4b2c46
4 changed files with 19 additions and 22 deletions

View File

@@ -584,10 +584,10 @@ class TransformerModel:
         self,
         *,
         model_id: str,
-        access_token: Optional[str],
         task_type: str,
         es_version: Optional[Tuple[int, int, int]] = None,
         quantize: bool = False,
+        access_token: Optional[str] = None,
     ):
         """
         Loads a model from the Hugging Face repository or local file and creates
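
For direct callers, the practical effect is that access_token can now be omitted. A minimal sketch of such direct usage, assuming the eland.ml.pytorch.transformers module path and an arbitrary public Hugging Face model (neither appears in this diff):

    from eland.ml.pytorch.transformers import TransformerModel

    # access_token now defaults to None, so code written before the
    # parameter existed keeps working without passing it.
    tm = TransformerModel(
        model_id="sentence-transformers/msmarco-MiniLM-L-12-v3",
        task_type="text_embedding",
        es_version=(8, 10, 0),
    )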

View File

@@ -112,6 +112,8 @@ def test(session, pandas_version: str):
         "python",
         "-m",
         "pytest",
+        "-ra",
+        "--tb=native",
         "--cov-report=term-missing",
         "--cov=eland/",
         "--cov-config=setup.cfg",

View File

@@ -22,7 +22,7 @@ lightgbm>=2,<4
 torch>=1.13.1,<2.0; python_version<'3.11'
 # Versions known to be compatible with PyTorch 1.13.1
 sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
-transformers[torch]>=4.12.0,<=4.27.4; python_version<'3.11'
+transformers[torch]>=4.31.0,<=4.33.2; python_version<'3.11'
 #
 # Testing
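
If it helps to catch a mismatched environment early, the new pin can be checked at runtime; a sketch, assuming the packaging distribution is installed (it is not part of this diff):

    import transformers
    from packaging.version import Version

    # Fail fast when the installed transformers is outside the pinned range.
    v = Version(transformers.__version__)
    assert Version("4.31.0") <= v <= Version("4.33.2"), f"unsupported transformers {v}"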

View File

@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import platform
 import tempfile
 import pytest
@@ -82,6 +83,14 @@ def setup_and_tear_down():
     pass
 
 
+@pytest.fixture(scope="session")
+def quantize():
+    # quantization does not work on ARM processors
+    # TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
+    # revisit this when we upgrade to PyTorch 2.0.
+    return platform.machine() not in ["arm64", "aarch64"]
+
+
 def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
     print("Loading HuggingFace transformer tokenizer and model")
     tm = TransformerModel(
@@ -103,31 +112,17 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
 class TestPytorchModel:
-    def __init__(self):
-        # quantization does not work on ARM processors
-        # TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
-        # revisit this when we upgrade to PyTorch 2.0.
-        import platform
-
-        self.quantize = (
-            True if platform.machine() not in ["arm64", "aarch64"] else False
-        )
-
     @pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
-    def test_text_prediction(self, model_id, task, text_input, value):
+    def test_text_prediction(self, model_id, task, text_input, value, quantize):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            ptm = download_model_and_start_deployment(
-                tmp_dir, self.quantize, model_id, task
-            )
-            result = ptm.infer(docs=[{"text_field": text_input}])
-            assert result["predicted_value"] == value
+            ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
+            results = ptm.infer(docs=[{"text_field": text_input}])
+            assert results.body["inference_results"][0]["predicted_value"] == value
 
     @pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
-    def test_text_embedding(self, model_id, task, text_input):
+    def test_text_embedding(self, model_id, task, text_input, quantize):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            ptm = download_model_and_start_deployment(
-                tmp_dir, self.quantize, model_id, task
-            )
+            ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
             ptm.infer(docs=[{"text_field": text_input}])
 
             if ES_VERSION >= (8, 8, 0):
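
The removed constructor also matters for collection: pytest skips test classes that define an __init__ (it emits a collection warning), so the per-instance flag is replaced by dependency injection, where pytest matches the quantize parameter name in each test signature against the session-scoped fixture. A minimal self-contained sketch of that pattern (the test body is illustrative, not from this diff):

    import platform
    import pytest

    @pytest.fixture(scope="session")
    def quantize():
        # Computed once per session, then injected into every test that asks for it.
        return platform.machine() not in ["arm64", "aarch64"]

    def test_quantize_is_boolean(quantize):
        # Illustrative only: the fixture value arrives as a plain bool.
        assert isinstance(quantize, bool)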