Choose text_embedding from auto when task type is unknown but it's a sentence-transformers model (#516)

closes https://github.com/elastic/eland/issues/514
This commit is contained in:
Benjamin Trent 2023-02-09 12:50:30 -05:00 committed by GitHub
parent 0576114a1d
commit d5578637cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 0 deletions

View File

@ -130,6 +130,8 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]
for t in task_type: for t in task_type:
potential_task_types.add(t) potential_task_types.add(t)
if len(potential_task_types) == 0: if len(potential_task_types) == 0:
if model_config.name_or_path.startswith("sentence-transformers/"):
return "text_embedding"
return None return None
if ( if (
"text_classification" in potential_task_types "text_classification" in potential_task_types

View File

@ -230,6 +230,7 @@ AUTO_TASK_RESULTS = [
("sentence-transformers/any_bert", "BERTMaskedLM", None, "text_embedding"), ("sentence-transformers/any_bert", "BERTMaskedLM", None, "text_embedding"),
("sentence-transformers/any_roberta", "RoBERTaMaskedLM", None, "text_embedding"), ("sentence-transformers/any_roberta", "RoBERTaMaskedLM", None, "text_embedding"),
("sentence-transformers/mpnet", "MPNetMaskedLM", None, "text_embedding"), ("sentence-transformers/mpnet", "MPNetMaskedLM", None, "text_embedding"),
("sentence-transformers/any_bert", "BertModel", None, "text_embedding"),
("anynermodel", "BERTForTokenClassification", None, "ner"), ("anynermodel", "BERTForTokenClassification", None, "ner"),
("anynermodel", "MPNetForTokenClassification", None, "ner"), ("anynermodel", "MPNetForTokenClassification", None, "ner"),
("anynermodel", "RoBERTaForTokenClassification", None, "ner"), ("anynermodel", "RoBERTaForTokenClassification", None, "ner"),