Allow importing private HuggingFace models (#608)

Quentin Pradet authored on 2023-09-25 17:10:58 +04:00; committed by GitHub
parent 5ec760635b
commit 566bb9e990
3 changed files with 47 additions and 23 deletions

docs/guide/machine-learning.asciidoc

@@ -160,7 +160,7 @@ underscores `__`.
 
 The following authentication options are available when using the import script:
 
-* username and password authentication (specified with the `-u` and `-p` options):
+* Elasticsearch username and password authentication (specified with the `-u` and `-p` options):
 +
 --
 [source,bash]
@@ -170,7 +170,7 @@ eland_import_hub_model -u <username> -p <password> --cloud-id <cloud-id> ...
 These `-u` and `-p` options also work when you use `--url`.
 --
 
-* username and password authentication (embedded in the URL):
+* Elasticsearch username and password authentication (embedded in the URL):
 +
 --
 [source,bash]
@@ -179,7 +179,7 @@ eland_import_hub_model --url https://<user>:<password>@<hostname>:<port> ...
 --------------------------------------------------
 --
 
-* API key authentication:
+* Elasticsearch API key authentication:
 +
 --
 [source,bash]
@@ -187,3 +187,12 @@ eland_import_hub_model --url https://<user>:<password>@<hostname>:<port> ...
 eland_import_hub_model --es-api-key <api-key> --url https://<hostname>:<port> ...
 --------------------------------------------------
 --
+
+* HuggingFace Hub access token (for private models):
++
+--
+[source,bash]
+--------------------------------------------------
+eland_import_hub_model --hub-access-token <access-token> ...
+--------------------------------------------------
+--
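The documentation above covers the CLI surface; the two code diffs below show the plumbing: the CLI forwards the token to `TransformerModel`, which passes it straight through to the Hugging Face `from_pretrained` calls. A minimal sketch of that pass-through (the model ID and token here are placeholders, not values from this commit):

[source,python]
--------------------------------------------------
from transformers import AutoModel, AutoTokenizer

model_id = "my-org/private-model"  # placeholder private model ID
token = "<access-token>"  # or None for public models

# Same keyword arguments the eland code below uses: the token is forwarded
# to the Hugging Face Hub when downloading the tokenizer and the model.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=False)
model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
--------------------------------------------------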

eland/cli/eland_import_hub_model.py

@@ -58,6 +58,12 @@ def get_arg_parser():
         help="The model ID in the Hugging Face model hub, "
         "e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
     )
+    parser.add_argument(
+        "--hub-access-token",
+        required=False,
+        default=os.environ.get("HUB_ACCESS_TOKEN"),
+        help="The Hugging Face access token, needed to access private models",
+    )
     parser.add_argument(
         "--es-model-id",
         required=False,
@@ -234,6 +240,7 @@ def main():
     try:
         tm = TransformerModel(
             model_id=args.hub_model_id,
+            access_token=args.hub_access_token,
             task_type=args.task_type,
             es_version=cluster_version,
             quantize=args.quantize,
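Note the `default=os.environ.get("HUB_ACCESS_TOKEN")` idiom above: an explicit `--hub-access-token` flag wins, otherwise the `HUB_ACCESS_TOKEN` environment variable is used, and if neither is set the token is `None` (public models). A self-contained sketch of that precedence:

[source,python]
--------------------------------------------------
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hub-access-token",
    required=False,
    default=os.environ.get("HUB_ACCESS_TOKEN"),  # read once, when the parser is built
)

# With no flag on the command line, argparse falls back to the default:
# the HUB_ACCESS_TOKEN environment variable, or None if it is unset.
args = parser.parse_args([])
print(args.hub_access_token)
--------------------------------------------------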

eland/ml/pytorch/transformers.py

@@ -179,9 +179,9 @@ class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
         self.config = model.config
 
     @staticmethod
-    def from_pretrained(model_id: str) -> Optional[Any]:
+    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
         model = AutoModelForQuestionAnswering.from_pretrained(
-            model_id, torchscript=True
+            model_id, token=token, torchscript=True
         )
         if isinstance(
             model.config,
@@ -292,9 +292,12 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
 
     @staticmethod
     def from_pretrained(
-        model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
+        model_id: str,
+        *,
+        token: Optional[str] = None,
+        output_key: str = DEFAULT_OUTPUT_KEY,
     ) -> Optional[Any]:
-        model = AutoModel.from_pretrained(model_id, torchscript=True)
+        model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
         if isinstance(
             model.config,
             (
@@ -393,8 +396,8 @@ class _DPREncoderWrapper(nn.Module):  # type: ignore
         self.config = model.config
 
     @staticmethod
-    def from_pretrained(model_id: str) -> Optional[Any]:
-        config = AutoConfig.from_pretrained(model_id)
+    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
+        config = AutoConfig.from_pretrained(model_id, token=token)
 
         def is_compatible() -> bool:
             is_dpr_model = config.model_type == "dpr"
@@ -579,9 +582,10 @@ class _TraceableTextSimilarityModel(_TransformerTraceableModel):
 class TransformerModel:
     def __init__(
         self,
-        model_id: str,
-        task_type: str,
+        *,
+        model_id: str,
+        access_token: Optional[str],
+        task_type: str,
         es_version: Optional[Tuple[int, int, int]] = None,
         quantize: bool = False,
     ):
@@ -609,14 +613,14 @@ class TransformerModel:
         """
         self._model_id = model_id
+        self._access_token = access_token
         self._task_type = task_type.replace("-", "_")
 
         # load Hugging Face model and tokenizer
         # use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
         # - see: https://huggingface.co/transformers/serialization.html#dummy-inputs-and-standard-lengths
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(
-            self._model_id,
-            use_fast=False,
+            self._model_id, token=self._access_token, use_fast=False
         )
 
         # check for a supported tokenizer
@@ -755,7 +759,7 @@ class TransformerModel:
     def _create_traceable_model(self) -> TraceableModel:
         if self._task_type == "auto":
             model = transformers.AutoModel.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             maybe_task_type = task_type_from_model_config(model.config)
             if maybe_task_type is None:
@@ -767,54 +771,58 @@ class TransformerModel:
         if self._task_type == "fill_mask":
             model = transformers.AutoModelForMaskedLM.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableFillMaskModel(self._tokenizer, model)
         elif self._task_type == "ner":
             model = transformers.AutoModelForTokenClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableNerModel(self._tokenizer, model)
         elif self._task_type == "text_classification":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableTextClassificationModel(self._tokenizer, model)
         elif self._task_type == "text_embedding":
-            model = _DPREncoderWrapper.from_pretrained(self._model_id)
+            model = _DPREncoderWrapper.from_pretrained(
+                self._model_id, token=self._access_token
+            )
             if not model:
                 model = _SentenceTransformerWrapperModule.from_pretrained(
-                    self._model_id
+                    self._model_id, token=self._access_token
                 )
             return _TraceableTextEmbeddingModel(self._tokenizer, model)
         elif self._task_type == "zero_shot_classification":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableZeroShotClassificationModel(self._tokenizer, model)
         elif self._task_type == "question_answering":
-            model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
+            model = _QuestionAnsweringWrapperModule.from_pretrained(
+                self._model_id, token=self._access_token
+            )
             return _TraceableQuestionAnsweringModel(self._tokenizer, model)
         elif self._task_type == "text_similarity":
            model = transformers.AutoModelForSequenceClassification.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableTextSimilarityModel(self._tokenizer, model)
         elif self._task_type == "pass_through":
             model = transformers.AutoModel.from_pretrained(
-                self._model_id, torchscript=True
+                self._model_id, token=self._access_token, torchscript=True
             )
             return _TraceablePassThroughModel(self._tokenizer, model)
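With this change every `TransformerModel` constructor argument is keyword-only (note the leading `*`), so callers must pass them by name. A hypothetical call; the import path is assumed from the eland source layout, and the model ID, token, and version tuple are placeholders:

[source,python]
--------------------------------------------------
from eland.ml.pytorch.transformers import TransformerModel

# Constructing the model downloads the tokenizer and weights from the
# Hugging Face Hub, using the access token for private repositories.
tm = TransformerModel(
    model_id="my-org/private-model",  # placeholder private model ID
    access_token="<access-token>",  # None works for public models
    task_type="text_embedding",
    es_version=(8, 10, 0),  # assumed target Elasticsearch version
    quantize=False,
)
--------------------------------------------------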