diff --git a/docs/guide/machine-learning.asciidoc b/docs/guide/machine-learning.asciidoc
index 34e14ff..d8b1307 100644
--- a/docs/guide/machine-learning.asciidoc
+++ b/docs/guide/machine-learning.asciidoc
@@ -160,7 +160,7 @@ underscores `__`.
 
 The following authentication options are available when using the import script:
 
-* username and password authentication (specified with the `-u` and `-p` options):
+* Elasticsearch username and password authentication (specified with the `-u` and `-p` options):
 +
 --
 [source,bash]
@@ -170,7 +170,7 @@ eland_import_hub_model -u <username> -p <password> --cloud-id <cloud-id> ...
 These `-u` and `-p` options also work when you use `--url`.
 --
 
-* username and password authentication (embedded in the URL):
+* Elasticsearch username and password authentication (embedded in the URL):
 +
 --
 [source,bash]
@@ -179,7 +179,7 @@ eland_import_hub_model --url https://<user>:<password>@<hostname>:<port> ...
 --------------------------------------------------
 --
 
-* API key authentication:
+* Elasticsearch API key authentication:
 +
 --
 [source,bash]
@@ -187,3 +187,12 @@ eland_import_hub_model --url https://<user>:<password>@<hostname>:<port> ...
 eland_import_hub_model --es-api-key <api-key> --url https://<hostname>:<port> ...
 --------------------------------------------------
 --
+
+* HuggingFace Hub access token (for private models):
++
+--
+[source,bash]
+--------------------------------------------------
+eland_import_hub_model --hub-access-token <access-token> ...
+--------------------------------------------------
+--
\ No newline at end of file
diff --git a/eland/cli/eland_import_hub_model.py b/eland/cli/eland_import_hub_model.py
index b763d35..b8d40b4 100755
--- a/eland/cli/eland_import_hub_model.py
+++ b/eland/cli/eland_import_hub_model.py
@@ -58,6 +58,12 @@ def get_arg_parser():
         help="The model ID in the Hugging Face model hub, "
         "e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
     )
+    parser.add_argument(
+        "--hub-access-token",
+        required=False,
+        default=os.environ.get("HUB_ACCESS_TOKEN"),
+        help="The Hugging Face access token, needed to access private models",
+    )
     parser.add_argument(
         "--es-model-id",
         required=False,
@@ -234,6 +240,7 @@ def main():
     try:
         tm = TransformerModel(
             model_id=args.hub_model_id,
+            access_token=args.hub_access_token,
             task_type=args.task_type,
             es_version=cluster_version,
             quantize=args.quantize,
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 2b16b1f..67d4253 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -179,9 +179,9 @@ class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
         self.config = model.config
 
     @staticmethod
-    def from_pretrained(model_id: str) -> Optional[Any]:
+    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
         model = AutoModelForQuestionAnswering.from_pretrained(
-            model_id, torchscript=True
+            model_id, token=token, torchscript=True
         )
         if isinstance(
             model.config,
@@ -292,9 +292,12 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
 
     @staticmethod
     def from_pretrained(
-        model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
+        model_id: str,
+        *,
+        token: Optional[str] = None,
+        output_key: str = DEFAULT_OUTPUT_KEY,
     ) -> Optional[Any]:
-        model = AutoModel.from_pretrained(model_id, torchscript=True)
+        model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
         if isinstance(
             model.config,
             (
@@ -393,8 +396,8 @@ class _DPREncoderWrapper(nn.Module):  # type: ignore
         self.config = model.config
 
     @staticmethod
-    def from_pretrained(model_id: str) -> Optional[Any]:
-        config = AutoConfig.from_pretrained(model_id)
+    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
+        config = AutoConfig.from_pretrained(model_id, token=token)
 
         def is_compatible() -> bool:
             is_dpr_model = config.model_type == "dpr"
@@ -579,9 +582,10 @@ class _TraceableTextSimilarityModel(_TransformerTraceableModel):
 class TransformerModel:
     def __init__(
         self,
-        model_id: str,
-        task_type: str,
         *,
+        model_id: str,
+        access_token: Optional[str],
+        task_type: str,
         es_version: Optional[Tuple[int, int, int]] = None,
         quantize: bool = False,
     ):
@@ -609,14 +613,14 @@ class TransformerModel:
         """
 
         self._model_id = model_id
+        self._access_token = access_token
         self._task_type = task_type.replace("-", "_")
 
         # load Hugging Face model and tokenizer
         # use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
         #  - see: https://huggingface.co/transformers/serialization.html#dummy-inputs-and-standard-lengths
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(
-            self._model_id,
-            use_fast=False,
+            self._model_id, token=self._access_token, use_fast=False
         )
 
         # check for a supported tokenizer
@@ -755,7 +759,7 @@ class TransformerModel:
     def _create_traceable_model(self) -> TraceableModel:
         if self._task_type == "auto":
             model = transformers.AutoModel.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             maybe_task_type = task_type_from_model_config(model.config)
             if maybe_task_type is None:
@@ -767,54 +771,58 @@ class TransformerModel:
 
         if self._task_type == "fill_mask":
             model = transformers.AutoModelForMaskedLM.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableFillMaskModel(self._tokenizer, model)
 
         elif self._task_type == "ner":
             model = transformers.AutoModelForTokenClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableNerModel(self._tokenizer, model)
 
         elif self._task_type == "text_classification":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableTextClassificationModel(self._tokenizer, model)
 
         elif self._task_type == "text_embedding":
-            model = _DPREncoderWrapper.from_pretrained(self._model_id)
+            model = _DPREncoderWrapper.from_pretrained(
+                self._model_id, token=self._access_token
+            )
             if not model:
                 model = _SentenceTransformerWrapperModule.from_pretrained(
-                    self._model_id
+                    self._model_id, token=self._access_token
                 )
             return _TraceableTextEmbeddingModel(self._tokenizer, model)
 
         elif self._task_type == "zero_shot_classification":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableZeroShotClassificationModel(self._tokenizer, model)
 
         elif self._task_type == "question_answering":
-            model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
+            model = _QuestionAnsweringWrapperModule.from_pretrained(
+                self._model_id, token=self._access_token
+            )
            return _TraceableQuestionAnsweringModel(self._tokenizer, model)
 
         elif self._task_type == "text_similarity":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableTextSimilarityModel(self._tokenizer, model)
 
         elif self._task_type == "pass_through":
             model = transformers.AutoModel.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             return _TraceablePassThroughModel(self._tokenizer, model)
 
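
For reference, a minimal sketch (not part of this diff) of driving the new keyword-only `TransformerModel` constructor from a script rather than the CLI. The model ID is the public one quoted in the CLI help text, reading `HUB_ACCESS_TOKEN` mirrors the new `--hub-access-token` default, and the `save()` step follows eland's documented tracing flow:

[source,python]
--------------------------------------------------
import os
import tempfile

from eland.ml.pytorch.transformers import TransformerModel

# Token resolution matches the new CLI default: --hub-access-token falls
# back to the HUB_ACCESS_TOKEN environment variable, and may be None for
# public models.
tm = TransformerModel(
    model_id="dbmdz/bert-large-cased-finetuned-conll03-english",
    access_token=os.environ.get("HUB_ACCESS_TOKEN"),
    task_type="ner",
)

# Trace the model and write the TorchScript artifacts to disk, as the
# importer does before uploading them to Elasticsearch.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path, config, vocab_path = tm.save(tmp_dir)
--------------------------------------------------

Because `access_token` has no default in the new signature, library callers must pass it explicitly (possibly as `None`), which keeps the private-model code path visible at every call site.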