Hub model import script improvements (#461)

## Changes 
### Better logging
Switched from `print` statements to `logging` for a cleaner and more informative output - timestamps and log level are shown. The logging is now a bit more verbose, but it will help users to better understand what the script is doing.

### Add support for ES authentication using username/password or api key
Instead of being limited to passing credentials in the URL, there are now 2 additional methods:
- username/password using `--es-username` and `--es-password`
- API key using `--es-api-key`

Credentials can also be specified as environment variables with `ES_USERNAME`/`ES_PASSWORD` or `ES_API_KEY`

### Graceful handling of missing PyTorch requirements
In order to use the `eland_import_hub_model` script, PyTorch extras are required to be installed. If the user does not have the required packages installed, a helpful message is logged with a hint to install `eland[pytorch]` with `pip`.

### Graceful handling of already existing trained model
If a trained model with the same ID as the one we're trying to import already exists, and `--clear-previous` was not specified, we now log a clearer message about why the script can't proceed along with a hint to use the `--clear-previous` flag. 

Prior to this change, we were letting the API exception seep through and the user was faced with a stack trace.

### `tqdm` added to main dependencies
If the user doesn't have `eland[pytorch]` extras installed, the first module to be reported as missing is `tqdm`. Since this module is used directly in the eland codebase (`eland/ml/pytorch/_pytorch_model.py`), it makes sense to have it as part of the main set of requirements.

### Nit: Set tqdm unit to `parts` in `_pytorch_model.put_model`
The default unit is `it`, but `parts` better describes what the progress bar is tracking - uploading trained model definition parts.
This commit is contained in:
David Olaru 2022-04-27 15:13:58 +01:00 committed by GitHub
parent b5ea1cf228
commit fe3422100c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 124 additions and 39 deletions

View File

@ -24,25 +24,26 @@ uploading to Elasticsearch. This will also check that the task type is supported
as well as the model and tokenizer types. All necessary configuration is
uploaded along with the model.
"""
import argparse
import logging
import os
import sys
import tempfile
import textwrap
import elasticsearch
from elastic_transport.client_utils import DEFAULT
from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel
from elasticsearch import AuthenticationException, Elasticsearch
MODEL_HUB_URL = "https://huggingface.co"
def main():
parser = argparse.ArgumentParser(prog="upload_hub_model")
def get_arg_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--url",
required=True,
help="An Elasticsearch connection URL, e.g. http://user:secret@localhost:9200",
default=os.environ.get("ES_URL"),
help="An Elasticsearch connection URL, e.g. http://localhost:9200",
)
parser.add_argument(
"--hub-model-id",
@ -51,13 +52,31 @@ def main():
"e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
)
parser.add_argument(
"--elasticsearch-model-id",
"--es-model-id",
required=False,
default=None,
help="The model ID to use in Elasticsearch, "
"e.g. bert-large-cased-finetuned-conll03-english."
"When left unspecified, this will be auto-created from the `hub-id`",
)
parser.add_argument(
"-u", "--es-username",
required=False,
default=os.environ.get("ES_USERNAME"),
help="Username for Elasticsearch"
)
# Password can come from the CLI flag or the ES_PASSWORD environment variable.
parser.add_argument(
    "-p", "--es-password",
    required=False,
    default=os.environ.get("ES_PASSWORD"),
    # Fixed: the username flag is -u/--es-username, not -u/--username
    help="Password for the Elasticsearch user specified with -u/--es-username"
)
# API key can come from the CLI flag or the ES_API_KEY environment variable.
parser.add_argument(
    "--es-api-key",
    required=False,
    default=os.environ.get("ES_API_KEY"),
    # Fixed: help text was a copy-paste of the --es-password help
    help="API key for Elasticsearch authentication; "
         "may also be set via the ES_API_KEY environment variable"
)
parser.add_argument(
"--task-type",
required=True,
@ -80,7 +99,7 @@ def main():
"--clear-previous",
action="store_true",
default=False,
help="Should the model previously stored with `elasticsearch-model-id` be deleted"
help="Should the model previously stored with `es-model-id` be deleted"
)
parser.add_argument(
"--insecure",
@ -94,35 +113,99 @@ def main():
default=DEFAULT,
help="Path to CA bundle"
)
args = parser.parse_args()
es = elasticsearch.Elasticsearch(args.url, request_timeout=300, verify_certs=args.insecure, ca_certs=args.ca_certs) # 5 minute timeout
return parser
# trace and save model, then upload it from temp file
with tempfile.TemporaryDirectory() as tmp_dir:
print(f"Loading HuggingFace transformer tokenizer and model {args.hub_model_id}")
tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
model_path, config_path, vocab_path = tm.save(tmp_dir)
es_model_id = (
args.elasticsearch_model_id
if args.elasticsearch_model_id
else tm.elasticsearch_model_id()
)
def get_es_client(cli_args):
    """Build and verify an Elasticsearch client from parsed CLI arguments.

    Authentication precedence: API key (--es-api-key) wins over basic auth
    (--es-username/--es-password). Exits the process with status 1 on a
    missing password or an authentication failure.
    """
    try:
        es_args = {
            'request_timeout': 300,  # 5 minute timeout for large model uploads
            'verify_certs': cli_args.insecure,
            'ca_certs': cli_args.ca_certs
        }

        if cli_args.es_api_key:
            es_args['api_key'] = cli_args.es_api_key
        elif cli_args.es_username:
            if not cli_args.es_password:
                # Fixed: use the configured module logger instead of the root
                # logger (logging.error) for consistent output formatting.
                logger.error(f"Password for user {cli_args.es_username} was not specified.")
                exit(1)
            es_args['basic_auth'] = (cli_args.es_username, cli_args.es_password)

        # Fixed: read the URL from the function parameter rather than the
        # module-global `args`, which only exists when run under __main__.
        es_client = Elasticsearch(cli_args.url, **es_args)
        # .info() both validates connectivity/credentials and gives us
        # cluster details worth surfacing to the user.
        es_info = es_client.info()
        logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")

        return es_client
    except AuthenticationException as e:
        logger.error(e)
        exit(1)
if __name__ == "__main__":
    # Configure logging: timestamps and level prefixed to every message.
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # PyTorch extras are optional at install time; fail with an actionable
    # hint instead of a raw ImportError traceback.
    try:
        from eland.ml.pytorch import PyTorchModel
        from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel
    except ModuleNotFoundError as e:
        logger.error(textwrap.dedent(f"""\
            \033[31mFailed to run because module '{e.name}' is not available.\033[0m
            This script requires PyTorch extras to run. You can install these by running:
            \033[1m{sys.executable} -m pip install 'eland[pytorch]'
            \033[0m"""))
        exit(1)

    # Parse arguments
    args = get_arg_parser().parse_args()

    # Connect to ES
    logger.info("Establishing connection to Elasticsearch")
    es = get_es_client(args)

    # Trace and save model, then upload it from temp file
    with tempfile.TemporaryDirectory() as tmp_dir:
        logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")
        tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
        model_path, config_path, vocab_path = tm.save(tmp_dir)

        ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())

        # A 200 from get_trained_models means a model with this ID already
        # exists; ignore_status=404 keeps the "not found" case from raising.
        model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200

        if model_exists:
            if args.clear_previous:
                logger.info(f"Stopping deployment for model with id '{ptm.model_id}'")
                ptm.stop()
                logger.info(f"Deleting model with id '{ptm.model_id}'")
                ptm.delete()
            else:
                # Surface a clear message instead of letting the API
                # exception seep through as a stack trace.
                logger.error(f"Trained model with id '{ptm.model_id}' already exists")
                logger.info("Run the script with the '--clear-previous' flag if you want to overwrite the existing model.")
                exit(1)

        logger.info(f"Creating model with id '{ptm.model_id}'")
        ptm.put_config(config_path)
        # Fixed: dropped f-string prefixes on messages with no placeholders.
        logger.info("Uploading model definition")
        ptm.put_model(model_path)
        logger.info("Uploading model vocabulary")
        ptm.put_vocab(vocab_path)

        # Start the deployed model
        if args.start:
            logger.info("Starting model deployment")
            ptm.start()

        logger.info(f"Model successfully imported with id '{ptm.model_id}'")

View File

@ -76,7 +76,9 @@ class PyTorchModel:
break
yield base64.b64encode(data).decode()
for i, data in tqdm(enumerate(model_file_chunk_generator()), total=total_parts):
for i, data in tqdm(
enumerate(model_file_chunk_generator()), unit=" parts", total=total_parts
):
self._client.ml.put_trained_model_definition_part(
model_id=self.model_id,
part=i,