mirror of
https://github.com/elastic/eland.git
synced 2025-07-24 00:00:39 +08:00
Upgrade Sentence Transformers to v5 (#801)
Sentence Transformers v5 adds support for sparse embedding models and is now necessary for importing sparse models such as https://huggingface.co/naver/splade-v3-distilbert.
This commit is contained in:
parent
117f61b010
commit
bebb9d52e5
4
setup.py
4
setup.py
@ -62,10 +62,10 @@ extras = {
|
|||||||
"requests<3",
|
"requests<3",
|
||||||
"torch==2.5.1",
|
"torch==2.5.1",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"sentence-transformers>=2.1.0,<=2.7.0",
|
"sentence-transformers>=5.0.0,<6.0.0",
|
||||||
# sentencepiece is a required dependency for the slow tokenizers
|
# sentencepiece is a required dependency for the slow tokenizers
|
||||||
# https://huggingface.co/transformers/v4.4.2/migration.html#sentencepiece-is-removed-from-the-required-dependencies
|
# https://huggingface.co/transformers/v4.4.2/migration.html#sentencepiece-is-removed-from-the-required-dependencies
|
||||||
"transformers[sentencepiece]>=4.47.0",
|
"transformers[sentencepiece]>=4.47.0,<4.50.3",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
extras["all"] = list({dep for deps in extras.values() for dep in deps})
|
extras["all"] = list({dep for deps in extras.values() for dep in deps})
|
||||||
|
@ -65,6 +65,8 @@ TEXT_EMBEDDING_MODELS = [
|
|||||||
|
|
||||||
TEXT_SIMILARITY_MODELS = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
|
TEXT_SIMILARITY_MODELS = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
|
||||||
|
|
||||||
|
# Sparse embedding models used to parametrize test_text_expansion.
TEXT_EXPANSION_MODELS = ["naver/splade-v3-distilbert"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def setup_and_tear_down():
|
def setup_and_tear_down():
|
||||||
@ -155,3 +157,22 @@ class TestPytorchModel:
|
|||||||
|
|
||||||
assert result.body["inference_results"][0]["predicted_value"] < 0
|
assert result.body["inference_results"][0]["predicted_value"] < 0
|
||||||
assert result.body["inference_results"][1]["predicted_value"] > 0
|
assert result.body["inference_results"][1]["predicted_value"] > 0
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ES_VERSION < (9, 0, 0), reason="requires current major version")
@pytest.mark.parametrize("model_id", TEXT_EXPANSION_MODELS)
def test_text_expansion(self, model_id):
    """Run inference on two sample documents with a text expansion model.

    The model is obtained through ``download_model_and_start_deployment``,
    called with ``"text_expansion"`` as its final argument. The test passes
    when the inference results at positions 0 and 1 each carry a non-empty
    ``predicted_value``.
    """
    sample_docs = [
        {
            "text_field": "The Amazon rainforest covers most of the Amazon basin in South America"
        },
        {"text_field": "Paris is the capital of France"},
    ]

    with tempfile.TemporaryDirectory() as tmp_dir:
        ptm = download_model_and_start_deployment(
            tmp_dir, False, model_id, "text_expansion"
        )
        result = ptm.infer(docs=sample_docs)

        # Check the results at positions 0 and 1, matching the two documents sent.
        for position in (0, 1):
            assert (
                len(result.body["inference_results"][position]["predicted_value"]) > 0
            )
|
||||||
|
Loading…
x
Reference in New Issue
Block a user