mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Allow importing private HuggingFace models (#608)
This commit is contained in:
parent
5ec760635b
commit
566bb9e990
@ -160,7 +160,7 @@ underscores `__`.
|
||||
|
||||
The following authentication options are available when using the import script:
|
||||
|
||||
* username and password authentication (specified with the `-u` and `-p` options):
|
||||
* Elasticsearch username and password authentication (specified with the `-u` and `-p` options):
|
||||
+
|
||||
--
|
||||
[source,bash]
|
||||
@ -170,7 +170,7 @@ eland_import_hub_model -u <username> -p <password> --cloud-id <cloud-id> ...
|
||||
These `-u` and `-p` options also work when you use `--url`.
|
||||
--
|
||||
|
||||
* username and password authentication (embedded in the URL):
|
||||
* Elasticsearch username and password authentication (embedded in the URL):
|
||||
+
|
||||
--
|
||||
[source,bash]
|
||||
@ -179,7 +179,7 @@ eland_import_hub_model --url https://<user>:<password>@<hostname>:<port> ...
|
||||
--------------------------------------------------
|
||||
--
|
||||
|
||||
* API key authentication:
|
||||
* Elasticsearch API key authentication:
|
||||
+
|
||||
--
|
||||
[source,bash]
|
||||
@ -187,3 +187,12 @@ eland_import_hub_model --url https://<user>:<password>@<hostname>:<port> ...
|
||||
eland_import_hub_model --es-api-key <api-key> --url https://<hostname>:<port> ...
|
||||
--------------------------------------------------
|
||||
--
|
||||
|
||||
* HuggingFace Hub access token (for private models):
|
||||
+
|
||||
--
|
||||
[source,bash]
|
||||
--------------------------------------------------
|
||||
eland_import_hub_model --hub-access-token <access-token> ...
|
||||
--------------------------------------------------
|
||||
--
|
@ -58,6 +58,12 @@ def get_arg_parser():
|
||||
help="The model ID in the Hugging Face model hub, "
|
||||
"e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub-access-token",
|
||||
required=False,
|
||||
default=os.environ.get("HUB_ACCESS_TOKEN"),
|
||||
help="The Hugging Face access token, needed to access private models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--es-model-id",
|
||||
required=False,
|
||||
@ -234,6 +240,7 @@ def main():
|
||||
try:
|
||||
tm = TransformerModel(
|
||||
model_id=args.hub_model_id,
|
||||
access_token=args.hub_access_token,
|
||||
task_type=args.task_type,
|
||||
es_version=cluster_version,
|
||||
quantize=args.quantize,
|
||||
|
@ -179,9 +179,9 @@ class _QuestionAnsweringWrapperModule(nn.Module): # type: ignore
|
||||
self.config = model.config
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(model_id: str) -> Optional[Any]:
|
||||
def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
|
||||
model = AutoModelForQuestionAnswering.from_pretrained(
|
||||
model_id, torchscript=True
|
||||
model_id, token=token, torchscript=True
|
||||
)
|
||||
if isinstance(
|
||||
model.config,
|
||||
@ -292,9 +292,12 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
|
||||
model_id: str,
|
||||
*,
|
||||
token: Optional[str] = None,
|
||||
output_key: str = DEFAULT_OUTPUT_KEY,
|
||||
) -> Optional[Any]:
|
||||
model = AutoModel.from_pretrained(model_id, torchscript=True)
|
||||
model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
|
||||
if isinstance(
|
||||
model.config,
|
||||
(
|
||||
@ -393,8 +396,8 @@ class _DPREncoderWrapper(nn.Module): # type: ignore
|
||||
self.config = model.config
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(model_id: str) -> Optional[Any]:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
|
||||
config = AutoConfig.from_pretrained(model_id, token=token)
|
||||
|
||||
def is_compatible() -> bool:
|
||||
is_dpr_model = config.model_type == "dpr"
|
||||
@ -579,9 +582,10 @@ class _TraceableTextSimilarityModel(_TransformerTraceableModel):
|
||||
class TransformerModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
task_type: str,
|
||||
*,
|
||||
model_id: str,
|
||||
access_token: Optional[str],
|
||||
task_type: str,
|
||||
es_version: Optional[Tuple[int, int, int]] = None,
|
||||
quantize: bool = False,
|
||||
):
|
||||
@ -609,14 +613,14 @@ class TransformerModel:
|
||||
"""
|
||||
|
||||
self._model_id = model_id
|
||||
self._access_token = access_token
|
||||
self._task_type = task_type.replace("-", "_")
|
||||
|
||||
# load Hugging Face model and tokenizer
|
||||
# use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
|
||||
# - see: https://huggingface.co/transformers/serialization.html#dummy-inputs-and-standard-lengths
|
||||
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
self._model_id,
|
||||
use_fast=False,
|
||||
self._model_id, token=self._access_token, use_fast=False
|
||||
)
|
||||
|
||||
# check for a supported tokenizer
|
||||
@ -755,7 +759,7 @@ class TransformerModel:
|
||||
def _create_traceable_model(self) -> TraceableModel:
|
||||
if self._task_type == "auto":
|
||||
model = transformers.AutoModel.from_pretrained(
|
||||
self._model_id, torchscript=True
|
||||
self._model_id, token=self._access_token, torchscript=True
|
||||
)
|
||||
maybe_task_type = task_type_from_model_config(model.config)
|
||||
if maybe_task_type is None:
|
||||
@ -767,54 +771,58 @@ class TransformerModel:
|
||||
|
||||
if self._task_type == "fill_mask":
|
||||
model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
self._model_id, torchscript=True
|
||||
self._model_id, token=self._access_token, torchscript=True
|
||||
)
|
||||
model = _DistilBertWrapper.try_wrapping(model)
|
||||
return _TraceableFillMaskModel(self._tokenizer, model)
|
||||
|
||||
elif self._task_type == "ner":
|
||||
model = transformers.AutoModelForTokenClassification.from_pretrained(
|
||||
self._model_id, torchscript=True
|
||||
self._model_id, token=self._access_token, torchscript=True
|
||||
)
|
||||
model = _DistilBertWrapper.try_wrapping(model)
|
||||
return _TraceableNerModel(self._tokenizer, model)
|
||||
|
||||
elif self._task_type == "text_classification":
|
||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||
self._model_id, torchscript=True
|
||||
self._model_id, token=self._access_token, torchscript=True
|
||||
)
|
||||
model = _DistilBertWrapper.try_wrapping(model)
|
||||
return _TraceableTextClassificationModel(self._tokenizer, model)
|
||||
|
||||
elif self._task_type == "text_embedding":
|
||||
model = _DPREncoderWrapper.from_pretrained(self._model_id)
|
||||
model = _DPREncoderWrapper.from_pretrained(
|
||||
self._model_id, token=self._access_token
|
||||
)
|
||||
if not model:
|
||||
model = _SentenceTransformerWrapperModule.from_pretrained(
|
||||
self._model_id
|
||||
self._model_id, token=self._access_token
|
||||
)
|
||||
return _TraceableTextEmbeddingModel(self._tokenizer, model)
|
||||
|
||||
elif self._task_type == "zero_shot_classification":
|
||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||
self._model_id, torchscript=True
|
||||
self._model_id, token=self._access_token, torchscript=True
|
||||
)
|
||||
model = _DistilBertWrapper.try_wrapping(model)
|
||||
return _TraceableZeroShotClassificationModel(self._tokenizer, model)
|
||||
|
||||
elif self._task_type == "question_answering":
|
||||
model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
|
||||
model = _QuestionAnsweringWrapperModule.from_pretrained(
|
||||
self._model_id, token=self._access_token
|
||||
)
|
||||
return _TraceableQuestionAnsweringModel(self._tokenizer, model)
|
||||
|
||||
elif self._task_type == "text_similarity":
|
||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||
self._model_id, torchscript=True
|
||||
self._model_id, token=self._access_token, torchscript=True
|
||||
)
|
||||
model = _DistilBertWrapper.try_wrapping(model)
|
||||
return _TraceableTextSimilarityModel(self._tokenizer, model)
|
||||
|
||||
elif self._task_type == "pass_through":
|
||||
model = transformers.AutoModel.from_pretrained(
|
||||
self._model_id, torchscript=True
|
||||
self._model_id, token=self._access_token, torchscript=True
|
||||
)
|
||||
return _TraceablePassThroughModel(self._tokenizer, model)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user