mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Allow importing private HuggingFace models (#608)
This commit is contained in:
parent
5ec760635b
commit
566bb9e990
@ -160,7 +160,7 @@ underscores `__`.
|
|||||||
|
|
||||||
The following authentication options are available when using the import script:
|
The following authentication options are available when using the import script:
|
||||||
|
|
||||||
* username and password authentication (specified with the `-u` and `-p` options):
|
* Elasticsearch username and password authentication (specified with the `-u` and `-p` options):
|
||||||
+
|
+
|
||||||
--
|
--
|
||||||
[source,bash]
|
[source,bash]
|
||||||
@ -170,7 +170,7 @@ eland_import_hub_model -u <username> -p <password> --cloud-id <cloud-id> ...
|
|||||||
These `-u` and `-p` options also work when you use `--url`.
|
These `-u` and `-p` options also work when you use `--url`.
|
||||||
--
|
--
|
||||||
|
|
||||||
* username and password authentication (embedded in the URL):
|
* Elasticsearch username and password authentication (embedded in the URL):
|
||||||
+
|
+
|
||||||
--
|
--
|
||||||
[source,bash]
|
[source,bash]
|
||||||
@ -179,7 +179,7 @@ eland_import_hub_model --url https://<user>:<password>@<hostname>:<port> ...
|
|||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
--
|
--
|
||||||
|
|
||||||
* API key authentication:
|
* Elasticsearch API key authentication:
|
||||||
+
|
+
|
||||||
--
|
--
|
||||||
[source,bash]
|
[source,bash]
|
||||||
@ -187,3 +187,12 @@ eland_import_hub_model --url https://<user>:<password>@<hostname>:<port> ...
|
|||||||
eland_import_hub_model --es-api-key <api-key> --url https://<hostname>:<port> ...
|
eland_import_hub_model --es-api-key <api-key> --url https://<hostname>:<port> ...
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
--
|
--
|
||||||
|
|
||||||
|
* HuggingFace Hub access token (for private models):
|
||||||
|
+
|
||||||
|
--
|
||||||
|
[source,bash]
|
||||||
|
--------------------------------------------------
|
||||||
|
eland_import_hub_model --hub-access-token <access-token> ...
|
||||||
|
--------------------------------------------------
|
||||||
|
--
|
@ -58,6 +58,12 @@ def get_arg_parser():
|
|||||||
help="The model ID in the Hugging Face model hub, "
|
help="The model ID in the Hugging Face model hub, "
|
||||||
"e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
|
"e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hub-access-token",
|
||||||
|
required=False,
|
||||||
|
default=os.environ.get("HUB_ACCESS_TOKEN"),
|
||||||
|
help="The Hugging Face access token, needed to access private models",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--es-model-id",
|
"--es-model-id",
|
||||||
required=False,
|
required=False,
|
||||||
@ -234,6 +240,7 @@ def main():
|
|||||||
try:
|
try:
|
||||||
tm = TransformerModel(
|
tm = TransformerModel(
|
||||||
model_id=args.hub_model_id,
|
model_id=args.hub_model_id,
|
||||||
|
access_token=args.hub_access_token,
|
||||||
task_type=args.task_type,
|
task_type=args.task_type,
|
||||||
es_version=cluster_version,
|
es_version=cluster_version,
|
||||||
quantize=args.quantize,
|
quantize=args.quantize,
|
||||||
|
@ -179,9 +179,9 @@ class _QuestionAnsweringWrapperModule(nn.Module): # type: ignore
|
|||||||
self.config = model.config
|
self.config = model.config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(model_id: str) -> Optional[Any]:
|
def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
|
||||||
model = AutoModelForQuestionAnswering.from_pretrained(
|
model = AutoModelForQuestionAnswering.from_pretrained(
|
||||||
model_id, torchscript=True
|
model_id, token=token, torchscript=True
|
||||||
)
|
)
|
||||||
if isinstance(
|
if isinstance(
|
||||||
model.config,
|
model.config,
|
||||||
@ -292,9 +292,12 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
|
model_id: str,
|
||||||
|
*,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
output_key: str = DEFAULT_OUTPUT_KEY,
|
||||||
) -> Optional[Any]:
|
) -> Optional[Any]:
|
||||||
model = AutoModel.from_pretrained(model_id, torchscript=True)
|
model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
|
||||||
if isinstance(
|
if isinstance(
|
||||||
model.config,
|
model.config,
|
||||||
(
|
(
|
||||||
@ -393,8 +396,8 @@ class _DPREncoderWrapper(nn.Module): # type: ignore
|
|||||||
self.config = model.config
|
self.config = model.config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(model_id: str) -> Optional[Any]:
|
def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
|
||||||
config = AutoConfig.from_pretrained(model_id)
|
config = AutoConfig.from_pretrained(model_id, token=token)
|
||||||
|
|
||||||
def is_compatible() -> bool:
|
def is_compatible() -> bool:
|
||||||
is_dpr_model = config.model_type == "dpr"
|
is_dpr_model = config.model_type == "dpr"
|
||||||
@ -579,9 +582,10 @@ class _TraceableTextSimilarityModel(_TransformerTraceableModel):
|
|||||||
class TransformerModel:
|
class TransformerModel:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
|
||||||
task_type: str,
|
|
||||||
*,
|
*,
|
||||||
|
model_id: str,
|
||||||
|
access_token: Optional[str],
|
||||||
|
task_type: str,
|
||||||
es_version: Optional[Tuple[int, int, int]] = None,
|
es_version: Optional[Tuple[int, int, int]] = None,
|
||||||
quantize: bool = False,
|
quantize: bool = False,
|
||||||
):
|
):
|
||||||
@ -609,14 +613,14 @@ class TransformerModel:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
|
self._access_token = access_token
|
||||||
self._task_type = task_type.replace("-", "_")
|
self._task_type = task_type.replace("-", "_")
|
||||||
|
|
||||||
# load Hugging Face model and tokenizer
|
# load Hugging Face model and tokenizer
|
||||||
# use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
|
# use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
|
||||||
# - see: https://huggingface.co/transformers/serialization.html#dummy-inputs-and-standard-lengths
|
# - see: https://huggingface.co/transformers/serialization.html#dummy-inputs-and-standard-lengths
|
||||||
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
|
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||||
self._model_id,
|
self._model_id, token=self._access_token, use_fast=False
|
||||||
use_fast=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# check for a supported tokenizer
|
# check for a supported tokenizer
|
||||||
@ -755,7 +759,7 @@ class TransformerModel:
|
|||||||
def _create_traceable_model(self) -> TraceableModel:
|
def _create_traceable_model(self) -> TraceableModel:
|
||||||
if self._task_type == "auto":
|
if self._task_type == "auto":
|
||||||
model = transformers.AutoModel.from_pretrained(
|
model = transformers.AutoModel.from_pretrained(
|
||||||
self._model_id, torchscript=True
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
)
|
)
|
||||||
maybe_task_type = task_type_from_model_config(model.config)
|
maybe_task_type = task_type_from_model_config(model.config)
|
||||||
if maybe_task_type is None:
|
if maybe_task_type is None:
|
||||||
@ -767,54 +771,58 @@ class TransformerModel:
|
|||||||
|
|
||||||
if self._task_type == "fill_mask":
|
if self._task_type == "fill_mask":
|
||||||
model = transformers.AutoModelForMaskedLM.from_pretrained(
|
model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||||
self._model_id, torchscript=True
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
)
|
)
|
||||||
model = _DistilBertWrapper.try_wrapping(model)
|
model = _DistilBertWrapper.try_wrapping(model)
|
||||||
return _TraceableFillMaskModel(self._tokenizer, model)
|
return _TraceableFillMaskModel(self._tokenizer, model)
|
||||||
|
|
||||||
elif self._task_type == "ner":
|
elif self._task_type == "ner":
|
||||||
model = transformers.AutoModelForTokenClassification.from_pretrained(
|
model = transformers.AutoModelForTokenClassification.from_pretrained(
|
||||||
self._model_id, torchscript=True
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
)
|
)
|
||||||
model = _DistilBertWrapper.try_wrapping(model)
|
model = _DistilBertWrapper.try_wrapping(model)
|
||||||
return _TraceableNerModel(self._tokenizer, model)
|
return _TraceableNerModel(self._tokenizer, model)
|
||||||
|
|
||||||
elif self._task_type == "text_classification":
|
elif self._task_type == "text_classification":
|
||||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||||
self._model_id, torchscript=True
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
)
|
)
|
||||||
model = _DistilBertWrapper.try_wrapping(model)
|
model = _DistilBertWrapper.try_wrapping(model)
|
||||||
return _TraceableTextClassificationModel(self._tokenizer, model)
|
return _TraceableTextClassificationModel(self._tokenizer, model)
|
||||||
|
|
||||||
elif self._task_type == "text_embedding":
|
elif self._task_type == "text_embedding":
|
||||||
model = _DPREncoderWrapper.from_pretrained(self._model_id)
|
model = _DPREncoderWrapper.from_pretrained(
|
||||||
|
self._model_id, token=self._access_token
|
||||||
|
)
|
||||||
if not model:
|
if not model:
|
||||||
model = _SentenceTransformerWrapperModule.from_pretrained(
|
model = _SentenceTransformerWrapperModule.from_pretrained(
|
||||||
self._model_id
|
self._model_id, token=self._access_token
|
||||||
)
|
)
|
||||||
return _TraceableTextEmbeddingModel(self._tokenizer, model)
|
return _TraceableTextEmbeddingModel(self._tokenizer, model)
|
||||||
|
|
||||||
elif self._task_type == "zero_shot_classification":
|
elif self._task_type == "zero_shot_classification":
|
||||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||||
self._model_id, torchscript=True
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
)
|
)
|
||||||
model = _DistilBertWrapper.try_wrapping(model)
|
model = _DistilBertWrapper.try_wrapping(model)
|
||||||
return _TraceableZeroShotClassificationModel(self._tokenizer, model)
|
return _TraceableZeroShotClassificationModel(self._tokenizer, model)
|
||||||
|
|
||||||
elif self._task_type == "question_answering":
|
elif self._task_type == "question_answering":
|
||||||
model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
|
model = _QuestionAnsweringWrapperModule.from_pretrained(
|
||||||
|
self._model_id, token=self._access_token
|
||||||
|
)
|
||||||
return _TraceableQuestionAnsweringModel(self._tokenizer, model)
|
return _TraceableQuestionAnsweringModel(self._tokenizer, model)
|
||||||
|
|
||||||
elif self._task_type == "text_similarity":
|
elif self._task_type == "text_similarity":
|
||||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||||
self._model_id, torchscript=True
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
)
|
)
|
||||||
model = _DistilBertWrapper.try_wrapping(model)
|
model = _DistilBertWrapper.try_wrapping(model)
|
||||||
return _TraceableTextSimilarityModel(self._tokenizer, model)
|
return _TraceableTextSimilarityModel(self._tokenizer, model)
|
||||||
|
|
||||||
elif self._task_type == "pass_through":
|
elif self._task_type == "pass_through":
|
||||||
model = transformers.AutoModel.from_pretrained(
|
model = transformers.AutoModel.from_pretrained(
|
||||||
self._model_id, torchscript=True
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
)
|
)
|
||||||
return _TraceablePassThroughModel(self._tokenizer, model)
|
return _TraceablePassThroughModel(self._tokenizer, model)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user