For many model types, we don't need to require the user to specify the task type; we can infer it from the model's configuration and architecture.
This commit makes the `task-type` parameter optional for the model upload script and adds logic for auto-detecting the task type based on the 🤗 model.
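To give a sense of what such detection can look like, here is a rough sketch of config-based inference. This is illustrative only, not eland's actual implementation (that logic lives in `eland.ml.pytorch.transformers`); the `guess_task_type` function and the architecture-suffix mapping below are hypothetical names chosen for the example.

# Illustrative sketch only: infer a task type from a Hugging Face model config.
# The mapping and function are hypothetical, not eland's actual detection logic.
from transformers import AutoConfig

# Hypothetical mapping from architecture-name suffix to an Elasticsearch task type.
ARCHITECTURE_TO_TASK = {
    "ForTokenClassification": "ner",
    "ForSequenceClassification": "text_classification",
    "ForQuestionAnswering": "question_answering",
    "ForMaskedLM": "fill_mask",
}


def guess_task_type(hub_model_id: str) -> str:
    """Best-effort task type for a Hugging Face hub model ID."""
    config = AutoConfig.from_pretrained(hub_model_id)
    for architecture in config.architectures or []:
        for suffix, task in ARCHITECTURE_TO_TASK.items():
            if architecture.endswith(suffix):
                return task
    raise ValueError(
        f"Could not infer a task type for '{hub_model_id}'; pass --task-type explicitly."
    )

With detection along these lines, `--task-type` only needs to be given explicitly when the architecture does not determine a single supported task.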
235 lines
7.8 KiB
Python
Executable File
#!/usr/bin/env python
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Copies a model from the Hugging Face model hub into an Elasticsearch cluster.
This will create local cached copies that will be traced (necessary) before
uploading to Elasticsearch. This will also check that the task type is supported
as well as the model and tokenizer types. All necessary configuration is
uploaded along with the model.
"""

import argparse
import logging
import os
import sys
import tempfile
import textwrap

from elastic_transport.client_utils import DEFAULT
from elasticsearch import AuthenticationException, Elasticsearch

MODEL_HUB_URL = "https://huggingface.co"


def get_arg_parser():
    parser = argparse.ArgumentParser()
    location_args = parser.add_mutually_exclusive_group(required=True)
    location_args.add_argument(
        "--url",
        default=os.environ.get("ES_URL"),
        help="An Elasticsearch connection URL, e.g. http://localhost:9200",
    )
    location_args.add_argument(
        "--cloud-id",
        default=os.environ.get("CLOUD_ID"),
        help="Cloud ID as found in the 'Manage Deployment' page of an Elastic Cloud deployment",
    )
    parser.add_argument(
        "--hub-model-id",
        required=True,
        help="The model ID in the Hugging Face model hub, "
        "e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
    )
    parser.add_argument(
        "--es-model-id",
        required=False,
        default=None,
        help="The model ID to use in Elasticsearch, "
        "e.g. bert-large-cased-finetuned-conll03-english. "
        "When left unspecified, this will be auto-created from the `hub-id`",
    )
    parser.add_argument(
        "-u", "--es-username",
        required=False,
        default=os.environ.get("ES_USERNAME"),
        help="Username for Elasticsearch",
    )
    parser.add_argument(
        "-p", "--es-password",
        required=False,
        default=os.environ.get("ES_PASSWORD"),
        help="Password for the Elasticsearch user specified with -u/--username",
    )
    parser.add_argument(
        "--es-api-key",
        required=False,
        default=os.environ.get("ES_API_KEY"),
        help="API key for Elasticsearch",
    )
    parser.add_argument(
        "--task-type",
        required=False,
        choices=SUPPORTED_TASK_TYPES,
        help="The task type for the model usage. Will attempt to auto-detect the "
        "task type for the model if not provided. Default: auto",
        default="auto",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        default=False,
        help="Quantize the model before uploading. Default: False",
    )
    parser.add_argument(
        "--start",
        action="store_true",
        default=False,
        help="Start the model deployment after uploading. Default: False",
    )
    parser.add_argument(
        "--clear-previous",
        action="store_true",
        default=False,
        help="Should the model previously stored with `es-model-id` be deleted",
    )
    parser.add_argument(
        "--insecure",
        action="store_false",
        default=True,
        help="Do not verify SSL certificates",
    )
    parser.add_argument(
        "--ca-certs",
        required=False,
        default=DEFAULT,
        help="Path to CA bundle",
    )

    return parser


def get_es_client(cli_args):
    try:
        es_args = {
            'request_timeout': 300,
            # --insecure uses store_false, so this is True (verify) unless the flag is passed
            'verify_certs': cli_args.insecure,
            'ca_certs': cli_args.ca_certs
        }

        # Deployment location
        if cli_args.url:
            es_args['hosts'] = cli_args.url

        if cli_args.cloud_id:
            es_args['cloud_id'] = cli_args.cloud_id

        # Authentication
        if cli_args.es_api_key:
            es_args['api_key'] = cli_args.es_api_key
        elif cli_args.es_username:
            if not cli_args.es_password:
                logging.error(f"Password for user {cli_args.es_username} was not specified.")
                exit(1)

            es_args['basic_auth'] = (cli_args.es_username, cli_args.es_password)

        es_client = Elasticsearch(**es_args)
        es_info = es_client.info()
        logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")

        return es_client
    except AuthenticationException as e:
        logger.error(e)
        exit(1)


if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    try:
        from eland.ml.pytorch import PyTorchModel
        from eland.ml.pytorch.transformers import (
            SUPPORTED_TASK_TYPES,
            TaskTypeError,
            TransformerModel,
        )
    except ModuleNotFoundError as e:
        logger.error(textwrap.dedent(f"""\
            \033[31mFailed to run because module '{e.name}' is not available.\033[0m

            This script requires PyTorch extras to run. You can install these by running:

            \033[1m{sys.executable} -m pip install 'eland[pytorch]'
            \033[0m"""))
        exit(1)

    # Parse arguments
    args = get_arg_parser().parse_args()

    # Connect to ES
    logger.info("Establishing connection to Elasticsearch")
    es = get_es_client(args)

    # Trace and save the model, then upload it from the temp file
    with tempfile.TemporaryDirectory() as tmp_dir:
        logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")

        try:
            tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
            model_path, config, vocab_path = tm.save(tmp_dir)
        except TaskTypeError as err:
            logger.error(
                f"Failed to get model for task type, please provide a valid task type via the '--task-type' parameter. Caused by {err}"
            )
            exit(1)

        ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
        # A 200 response means a trained model with this ID already exists
        model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200

        if model_exists:
            if args.clear_previous:
                logger.info(f"Stopping deployment for model with id '{ptm.model_id}'")
                ptm.stop()

                logger.info(f"Deleting model with id '{ptm.model_id}'")
                ptm.delete()
            else:
                logger.error(f"Trained model with id '{ptm.model_id}' already exists")
                logger.info("Run the script with the '--clear-previous' flag if you want to overwrite the existing model.")
                exit(1)

        logger.info(f"Creating model with id '{ptm.model_id}'")
        ptm.put_config(config=config)

        logger.info("Uploading model definition")
        ptm.put_model(model_path)

        logger.info("Uploading model vocabulary")
        ptm.put_vocab(vocab_path)

        # Start the deployed model
        if args.start:
            logger.info("Starting model deployment")
            ptm.start()

        logger.info(f"Model successfully imported with id '{ptm.model_id}'")