From b5bcba713dedf2a8665d73465291562bf05b1b34 Mon Sep 17 00:00:00 2001
From: Youhei Sakurai
Date: Thu, 13 Jul 2023 16:55:00 +0900
Subject: [PATCH] Apply black to comply with the code style (#557)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Relates https://github.com/elastic/eland/pull/552

**Issue**:

```console
C:\Users\YouheiSakurai\git\myeland>python -m black --version
python -m black, 23.3.0 (compiled: yes)
Python (CPython) 3.11.0

C:\Users\YouheiSakurai\git\myeland>python -m black --check --target-version=py38 bin\eland_import_hub_model
would reformat bin\eland_import_hub_model

Oh no! 💥 💔 💥
1 file would be reformatted.
```

**Solution**:

```
C:\Users\YouheiSakurai\git\myeland>python -m black --target-version=py38 bin\eland_import_hub_model
reformatted bin\eland_import_hub_model

All done! ✨ 🍰 ✨
1 file reformatted.
```
---
 bin/eland_import_hub_model | 103 ++++++++++++++++++++++++-------------
 1 file changed, 66 insertions(+), 37 deletions(-)

diff --git a/bin/eland_import_hub_model b/bin/eland_import_hub_model
index d0da973..94ec4d0 100755
--- a/bin/eland_import_hub_model
+++ b/bin/eland_import_hub_model
@@ -58,41 +58,43 @@ def get_arg_parser():
         "--hub-model-id",
         required=True,
         help="The model ID in the Hugging Face model hub, "
-             "e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
+        "e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
     )
     parser.add_argument(
         "--es-model-id",
         required=False,
         default=None,
         help="The model ID to use in Elasticsearch, "
-             "e.g. bert-large-cased-finetuned-conll03-english."
-             "When left unspecified, this will be auto-created from the `hub-id`",
+        "e.g. bert-large-cased-finetuned-conll03-english."
+        "When left unspecified, this will be auto-created from the `hub-id`",
     )
     parser.add_argument(
-        "-u", "--es-username",
+        "-u",
+        "--es-username",
         required=False,
         default=os.environ.get("ES_USERNAME"),
-        help="Username for Elasticsearch"
+        help="Username for Elasticsearch",
     )
     parser.add_argument(
-        "-p", "--es-password",
+        "-p",
+        "--es-password",
         required=False,
         default=os.environ.get("ES_PASSWORD"),
-        help="Password for the Elasticsearch user specified with -u/--username"
+        help="Password for the Elasticsearch user specified with -u/--username",
     )
     parser.add_argument(
         "--es-api-key",
         required=False,
         default=os.environ.get("ES_API_KEY"),
-        help="API key for Elasticsearch"
+        help="API key for Elasticsearch",
     )
     parser.add_argument(
         "--task-type",
         required=False,
         choices=SUPPORTED_TASK_TYPES,
         help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. "
-             "Default: auto",
-        default="auto"
+        "Default: auto",
+        default="auto",
     )
     parser.add_argument(
         "--quantize",
@@ -110,19 +112,16 @@ def get_arg_parser():
         "--clear-previous",
         action="store_true",
         default=False,
-        help="Should the model previously stored with `es-model-id` be deleted"
+        help="Should the model previously stored with `es-model-id` be deleted",
     )
     parser.add_argument(
         "--insecure",
         action="store_false",
         default=True,
-        help="Do not verify SSL certificates"
+        help="Do not verify SSL certificates",
     )
     parser.add_argument(
-        "--ca-certs",
-        required=False,
-        default=DEFAULT,
-        help="Path to CA bundle"
+        "--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
     )
 
     return parser
@@ -131,27 +130,29 @@ def get_arg_parser():
 def get_es_client(cli_args, logger):
     try:
         es_args = {
-            'request_timeout': 300,
-            'verify_certs': cli_args.insecure,
-            'ca_certs': cli_args.ca_certs
+            "request_timeout": 300,
+            "verify_certs": cli_args.insecure,
+            "ca_certs": cli_args.ca_certs,
         }
 
         # Deployment location
         if cli_args.url:
-            es_args['hosts'] = cli_args.url
+            es_args["hosts"] = cli_args.url
         if cli_args.cloud_id:
-            es_args['cloud_id'] = cli_args.cloud_id
+            es_args["cloud_id"] = cli_args.cloud_id
 
         # Authentication
         if cli_args.es_api_key:
-            es_args['api_key'] = cli_args.es_api_key
+            es_args["api_key"] = cli_args.es_api_key
         elif cli_args.es_username:
             if not cli_args.es_password:
-                logging.error(f"Password for user {cli_args.es_username} was not specified.")
+                logging.error(
+                    f"Password for user {cli_args.es_username} was not specified."
+                )
                 exit(1)
 
-            es_args['basic_auth'] = (cli_args.es_username, cli_args.es_password)
+            es_args["basic_auth"] = (cli_args.es_username, cli_args.es_password)
 
         es_client = Elasticsearch(**es_args)
         return es_client
@@ -162,15 +163,19 @@ def get_es_client(cli_args, logger):
 
 def check_cluster_version(es_client, logger):
     es_info = es_client.info()
-    logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")
+    logger.info(
+        f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})"
+    )
 
-    sem_ver = parse_es_version(es_info['version']['number'])
+    sem_ver = parse_es_version(es_info["version"]["number"])
     major_version = sem_ver[0]
     minor_version = sem_ver[1]
 
     # NLP models added in 8
     if major_version < 8:
-        logger.error(f"Elasticsearch version {major_version} does not support NLP models. Please upgrade Elasticsearch to the latest version")
+        logger.error(
+            f"Elasticsearch version {major_version} does not support NLP models. Please upgrade Elasticsearch to the latest version"
+        )
         exit(1)
 
     # PyTorch was upgraded to version 1.13.1 in 8.7.
@@ -178,7 +183,9 @@ def check_cluster_version(es_client, logger):
     if major_version == 8 and minor_version < 7:
         import torch
 
-        logger.error(f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7")
+        logger.error(
+            f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7"
+        )
         exit(1)
 
     return sem_ver
@@ -186,7 +193,7 @@ def check_cluster_version(es_client, logger):
 
 def main():
     # Configure logging
-    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
+    logging.basicConfig(format="%(asctime)s %(levelname)s : %(message)s")
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)
 
@@ -198,13 +205,17 @@ def main():
             TransformerModel,
         )
     except ModuleNotFoundError as e:
-        logger.error(textwrap.dedent(f"""\
+        logger.error(
+            textwrap.dedent(
+                f"""\
             \033[31mFailed to run because module '{e.name}' is not available.\033[0m
 
             This script requires PyTorch extras to run. You can install these by running:
 
                 \033[1m{sys.executable} -m pip install 'eland[pytorch]'
-            \033[0m"""))
+            \033[0m"""
+            )
+        )
         exit(1)
 
     assert SUPPORTED_TASK_TYPES
@@ -218,17 +229,33 @@ def main():
 
     # Trace and save model, then upload it from temp file
     with tempfile.TemporaryDirectory() as tmp_dir:
-        logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")
+        logger.info(
+            f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'"
+        )
 
         try:
-            tm = TransformerModel(model_id=args.hub_model_id, task_type=args.task_type, es_version=cluster_version, quantize=args.quantize)
+            tm = TransformerModel(
+                model_id=args.hub_model_id,
+                task_type=args.task_type,
+                es_version=cluster_version,
+                quantize=args.quantize,
+            )
             model_path, config, vocab_path = tm.save(tmp_dir)
         except TaskTypeError as err:
-            logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}")
+            logger.error(
+                f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}"
+            )
             exit(1)
 
-        ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
-        model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200
+        ptm = PyTorchModel(
+            es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()
+        )
+        model_exists = (
+            es.options(ignore_status=404)
+            .ml.get_trained_models(model_id=ptm.model_id)
+            .meta.status
+            == 200
+        )
 
         if model_exists:
             if args.clear_previous:
@@ -239,7 +266,9 @@ def main():
                 ptm.delete()
             else:
                 logger.error(f"Trained model with id '{ptm.model_id}' already exists")
-                logger.info("Run the script with the '--clear-previous' flag if you want to overwrite the existing model.")
+                logger.info(
+                    "Run the script with the '--clear-previous' flag if you want to overwrite the existing model."
+                )
                 exit(1)
 
         logger.info(f"Creating model with id '{ptm.model_id}'")