mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Hub model import script improvements (#461)
## Changes
### Better logging
Switched from `print` statements to `logging` for a cleaner and more informative output - timestamps and log level are shown. The logging is now a bit more verbose, but it will help users to better understand what the script is doing.
### Add support for ES authentication using username/password or api key
Instead of being limited to passing credentials in the URL, there are now 2 additional methods:
- username/password using `--es-username` and `--es-password`
- API key using `--es-api-key`
Credentials can also be specified as environment variables with `ES_USERNAME`/`ES_PASSWORD` or `ES_API_KEY`
### Graceful handling of missing PyTorch requirements
In order to use the `eland_import_hub_model` script, PyTorch extras are required to be installed. If the user does not have the required packages installed, a helpful message is logged with a hint to install `eland[pytorch]` with `pip`.
### Graceful handling of already existing trained model
If a trained model with the same ID as the one we're trying to import already exists, and `--clear-previous` was not specified, we now log a clearer message about why the script can't proceed along with a hint to use the `--clear-previous` flag.
Prior to this change, we were letting the API exception seep through and the user was faced with a stack trace.
### `tqdm` added to main dependencies
If the user doesn't have `eland[pytorch]` extras installed, the first module to be reported as missing is `tqdm`. Since this module is used directly in the eland codebase (`eland/ml/pytorch/_pytorch_model.py`, line 24), it makes sense to have it as part of the main set of requirements.
### Nit: Set tqdm unit to `parts` in `_pytorch_model.put_model`
The default unit is `it`, but `parts` better describes what the progress bar is tracking - uploading trained model definition parts.
This commit is contained in:
parent
b5ea1cf228
commit
fe3422100c
@ -24,39 +24,58 @@ uploading to Elasticsearch. This will also check that the task type is supported
|
||||
as well as the model and tokenizer types. All necessary configuration is
|
||||
uploaded along with the model.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
|
||||
import elasticsearch
|
||||
from elastic_transport.client_utils import DEFAULT
|
||||
|
||||
from eland.ml.pytorch import PyTorchModel
|
||||
from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel
|
||||
from elasticsearch import AuthenticationException, Elasticsearch
|
||||
|
||||
MODEL_HUB_URL = "https://huggingface.co"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="upload_hub_model")
|
||||
def get_arg_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
required=True,
|
||||
help="An Elasticsearch connection URL, e.g. http://user:secret@localhost:9200",
|
||||
default=os.environ.get("ES_URL"),
|
||||
help="An Elasticsearch connection URL, e.g. http://localhost:9200",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub-model-id",
|
||||
required=True,
|
||||
help="The model ID in the Hugging Face model hub, "
|
||||
"e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
"e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--elasticsearch-model-id",
|
||||
"--es-model-id",
|
||||
required=False,
|
||||
default=None,
|
||||
help="The model ID to use in Elasticsearch, "
|
||||
"e.g. bert-large-cased-finetuned-conll03-english."
|
||||
"When left unspecified, this will be auto-created from the `hub-id`",
|
||||
"e.g. bert-large-cased-finetuned-conll03-english."
|
||||
"When left unspecified, this will be auto-created from the `hub-id`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-u", "--es-username",
|
||||
required=False,
|
||||
default=os.environ.get("ES_USERNAME"),
|
||||
help="Username for Elasticsearch"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--es-password",
|
||||
required=False,
|
||||
default=os.environ.get("ES_PASSWORD"),
|
||||
help="Password for the Elasticsearch user specified with -u/--username"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--es-api-key",
|
||||
required=False,
|
||||
default=os.environ.get("ES_API_KEY"),
|
||||
help="Password for the Elasticsearch user specified with -u/--username"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task-type",
|
||||
@ -80,7 +99,7 @@ def main():
|
||||
"--clear-previous",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Should the model previously stored with `elasticsearch-model-id` be deleted"
|
||||
help="Should the model previously stored with `es-model-id` be deleted"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--insecure",
|
||||
@ -93,36 +112,100 @@ def main():
|
||||
required=False,
|
||||
default=DEFAULT,
|
||||
help="Path to CA bundle"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
)
|
||||
|
||||
es = elasticsearch.Elasticsearch(args.url, request_timeout=300, verify_certs=args.insecure, ca_certs=args.ca_certs) # 5 minute timeout
|
||||
return parser
|
||||
|
||||
# trace and save model, then upload it from temp file
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
print(f"Loading HuggingFace transformer tokenizer and model {args.hub_model_id}")
|
||||
tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
|
||||
model_path, config_path, vocab_path = tm.save(tmp_dir)
|
||||
|
||||
es_model_id = (
|
||||
args.elasticsearch_model_id
|
||||
if args.elasticsearch_model_id
|
||||
else tm.elasticsearch_model_id()
|
||||
)
|
||||
def get_es_client(cli_args):
    """Build and verify an Elasticsearch client from parsed CLI arguments.

    Supports three authentication methods: credentials embedded in the URL,
    an API key (``--es-api-key``), or username/password (``-u``/``-p``).
    Verifies connectivity by calling the cluster info API.

    Exits the process with status 1 when the password is missing for a given
    username or when authentication fails.
    """
    try:
        es_args = {
            'request_timeout': 300,  # 5 minute timeout for large model uploads
            'verify_certs': cli_args.insecure,
            'ca_certs': cli_args.ca_certs
        }

        # An API key takes precedence over basic auth when both are supplied.
        if cli_args.es_api_key:
            es_args['api_key'] = cli_args.es_api_key
        elif cli_args.es_username:
            if not cli_args.es_password:
                # Fix: use the module-level `logger` (was `logging.error`,
                # which bypasses the configured formatter/level).
                logger.error(f"Password for user {cli_args.es_username} was not specified.")
                sys.exit(1)

            es_args['basic_auth'] = (cli_args.es_username, cli_args.es_password)

        # Fix: read the URL from the `cli_args` parameter. Previously this
        # referenced the module-global `args`, which silently coupled the
        # function to the script's global state.
        es_client = Elasticsearch(cli_args.url, **es_args)
        es_info = es_client.info()
        logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")

        return es_client
    except AuthenticationException as e:
        logger.error(e)
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# Configure logging: timestamped, level-tagged output for every script message.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Import the PyTorch-dependent pieces inside a try block so that a missing
# optional dependency produces an actionable install hint instead of a
# raw traceback.
try:
    from eland.ml.pytorch import PyTorchModel
    from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel
except ModuleNotFoundError as e:
    logger.error(textwrap.dedent(f"""\
        \033[31mFailed to run because module '{e.name}' is not available.\033[0m

        This script requires PyTorch extras to run. You can install these by running:

        \033[1m{sys.executable} -m pip install 'eland[pytorch]'
        \033[0m"""))
    sys.exit(1)  # fix: prefer sys.exit over the site-provided exit()

# Parse arguments
args = get_arg_parser().parse_args()

# Connect to ES
logger.info("Establishing connection to Elasticsearch")
es = get_es_client(args)

# Trace and save model, then upload it from temp file
with tempfile.TemporaryDirectory() as tmp_dir:
    logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")

    tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
    model_path, config_path, vocab_path = tm.save(tmp_dir)

    ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
    # A 404 simply means the model does not exist yet, so suppress the error
    # and inspect the HTTP status instead of letting the client raise.
    model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200

    if model_exists:
        if args.clear_previous:
            logger.info(f"Stopping deployment for model with id '{ptm.model_id}'")
            ptm.stop()

            logger.info(f"Deleting model with id '{ptm.model_id}'")
            ptm.delete()
        else:
            # Without --clear-previous we refuse to overwrite, with a hint,
            # rather than letting the API exception surface as a stack trace.
            logger.error(f"Trained model with id '{ptm.model_id}' already exists")
            logger.info("Run the script with the '--clear-previous' flag if you want to overwrite the existing model.")
            sys.exit(1)

    logger.info(f"Creating model with id '{ptm.model_id}'")
    ptm.put_config(config_path)

    # fix: dropped the f-prefix on messages with no placeholders
    logger.info("Uploading model definition")
    ptm.put_model(model_path)

    logger.info("Uploading model vocabulary")
    ptm.put_vocab(vocab_path)

    # Start the deployed model
    if args.start:
        logger.info("Starting model deployment")
        ptm.start()

    logger.info(f"Model successfully imported with id '{ptm.model_id}'")
|
||||
|
||||
|
||||
|
@ -76,7 +76,9 @@ class PyTorchModel:
|
||||
break
|
||||
yield base64.b64encode(data).decode()
|
||||
|
||||
for i, data in tqdm(enumerate(model_file_chunk_generator()), total=total_parts):
|
||||
for i, data in tqdm(
|
||||
enumerate(model_file_chunk_generator()), unit=" parts", total=total_parts
|
||||
):
|
||||
self._client.ml.put_trained_model_definition_part(
|
||||
model_id=self.model_id,
|
||||
part=i,
|
||||
|
Loading…
x
Reference in New Issue
Block a user