Apply black to comply with the code style (#557)

Relates to https://github.com/elastic/eland/pull/552

**Issue**:

```console
C:\Users\YouheiSakurai\git\myeland>python -m black --version
python -m black, 23.3.0 (compiled: yes)
Python (CPython) 3.11.0

C:\Users\YouheiSakurai\git\myeland>python -m black --check --target-version=py38 bin\eland_import_hub_model
would reformat bin\eland_import_hub_model

Oh no! 💥 💔 💥
1 file would be reformatted.
```

**Solution**:

```console
C:\Users\YouheiSakurai\git\myeland>python -m black --target-version=py38 bin\eland_import_hub_model
reformatted bin\eland_import_hub_model

All done! ✨ 🍰 ✨
1 file reformatted.
```
This commit is contained in:
Youhei Sakurai 2023-07-13 16:55:00 +09:00 committed by GitHub
parent 77781b90ff
commit b5bcba713d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -69,22 +69,24 @@ def get_arg_parser():
"When left unspecified, this will be auto-created from the `hub-id`", "When left unspecified, this will be auto-created from the `hub-id`",
) )
parser.add_argument( parser.add_argument(
"-u", "--es-username", "-u",
"--es-username",
required=False, required=False,
default=os.environ.get("ES_USERNAME"), default=os.environ.get("ES_USERNAME"),
help="Username for Elasticsearch" help="Username for Elasticsearch",
) )
parser.add_argument( parser.add_argument(
"-p", "--es-password", "-p",
"--es-password",
required=False, required=False,
default=os.environ.get("ES_PASSWORD"), default=os.environ.get("ES_PASSWORD"),
help="Password for the Elasticsearch user specified with -u/--username" help="Password for the Elasticsearch user specified with -u/--username",
) )
parser.add_argument( parser.add_argument(
"--es-api-key", "--es-api-key",
required=False, required=False,
default=os.environ.get("ES_API_KEY"), default=os.environ.get("ES_API_KEY"),
help="API key for Elasticsearch" help="API key for Elasticsearch",
) )
parser.add_argument( parser.add_argument(
"--task-type", "--task-type",
@ -92,7 +94,7 @@ def get_arg_parser():
choices=SUPPORTED_TASK_TYPES, choices=SUPPORTED_TASK_TYPES,
help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. " help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. "
"Default: auto", "Default: auto",
default="auto" default="auto",
) )
parser.add_argument( parser.add_argument(
"--quantize", "--quantize",
@ -110,19 +112,16 @@ def get_arg_parser():
"--clear-previous", "--clear-previous",
action="store_true", action="store_true",
default=False, default=False,
help="Should the model previously stored with `es-model-id` be deleted" help="Should the model previously stored with `es-model-id` be deleted",
) )
parser.add_argument( parser.add_argument(
"--insecure", "--insecure",
action="store_false", action="store_false",
default=True, default=True,
help="Do not verify SSL certificates" help="Do not verify SSL certificates",
) )
parser.add_argument( parser.add_argument(
"--ca-certs", "--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
required=False,
default=DEFAULT,
help="Path to CA bundle"
) )
return parser return parser
@ -131,27 +130,29 @@ def get_arg_parser():
def get_es_client(cli_args, logger): def get_es_client(cli_args, logger):
try: try:
es_args = { es_args = {
'request_timeout': 300, "request_timeout": 300,
'verify_certs': cli_args.insecure, "verify_certs": cli_args.insecure,
'ca_certs': cli_args.ca_certs "ca_certs": cli_args.ca_certs,
} }
# Deployment location # Deployment location
if cli_args.url: if cli_args.url:
es_args['hosts'] = cli_args.url es_args["hosts"] = cli_args.url
if cli_args.cloud_id: if cli_args.cloud_id:
es_args['cloud_id'] = cli_args.cloud_id es_args["cloud_id"] = cli_args.cloud_id
# Authentication # Authentication
if cli_args.es_api_key: if cli_args.es_api_key:
es_args['api_key'] = cli_args.es_api_key es_args["api_key"] = cli_args.es_api_key
elif cli_args.es_username: elif cli_args.es_username:
if not cli_args.es_password: if not cli_args.es_password:
logging.error(f"Password for user {cli_args.es_username} was not specified.") logging.error(
f"Password for user {cli_args.es_username} was not specified."
)
exit(1) exit(1)
es_args['basic_auth'] = (cli_args.es_username, cli_args.es_password) es_args["basic_auth"] = (cli_args.es_username, cli_args.es_password)
es_client = Elasticsearch(**es_args) es_client = Elasticsearch(**es_args)
return es_client return es_client
@ -162,15 +163,19 @@ def get_es_client(cli_args, logger):
def check_cluster_version(es_client, logger): def check_cluster_version(es_client, logger):
es_info = es_client.info() es_info = es_client.info()
logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})") logger.info(
f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})"
)
sem_ver = parse_es_version(es_info['version']['number']) sem_ver = parse_es_version(es_info["version"]["number"])
major_version = sem_ver[0] major_version = sem_ver[0]
minor_version = sem_ver[1] minor_version = sem_ver[1]
# NLP models added in 8 # NLP models added in 8
if major_version < 8: if major_version < 8:
logger.error(f"Elasticsearch version {major_version} does not support NLP models. Please upgrade Elasticsearch to the latest version") logger.error(
f"Elasticsearch version {major_version} does not support NLP models. Please upgrade Elasticsearch to the latest version"
)
exit(1) exit(1)
# PyTorch was upgraded to version 1.13.1 in 8.7. # PyTorch was upgraded to version 1.13.1 in 8.7.
@ -178,7 +183,9 @@ def check_cluster_version(es_client, logger):
if major_version == 8 and minor_version < 7: if major_version == 8 and minor_version < 7:
import torch import torch
logger.error(f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7") logger.error(
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7"
)
exit(1) exit(1)
return sem_ver return sem_ver
@ -186,7 +193,7 @@ def check_cluster_version(es_client, logger):
def main(): def main():
# Configure logging # Configure logging
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s') logging.basicConfig(format="%(asctime)s %(levelname)s : %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -198,13 +205,17 @@ def main():
TransformerModel, TransformerModel,
) )
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
logger.error(textwrap.dedent(f"""\ logger.error(
textwrap.dedent(
f"""\
\033[31mFailed to run because module '{e.name}' is not available.\033[0m \033[31mFailed to run because module '{e.name}' is not available.\033[0m
This script requires PyTorch extras to run. You can install these by running: This script requires PyTorch extras to run. You can install these by running:
\033[1m{sys.executable} -m pip install 'eland[pytorch]' \033[1m{sys.executable} -m pip install 'eland[pytorch]'
\033[0m""")) \033[0m"""
)
)
exit(1) exit(1)
assert SUPPORTED_TASK_TYPES assert SUPPORTED_TASK_TYPES
@ -218,17 +229,33 @@ def main():
# Trace and save model, then upload it from temp file # Trace and save model, then upload it from temp file
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'") logger.info(
f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'"
)
try: try:
tm = TransformerModel(model_id=args.hub_model_id, task_type=args.task_type, es_version=cluster_version, quantize=args.quantize) tm = TransformerModel(
model_id=args.hub_model_id,
task_type=args.task_type,
es_version=cluster_version,
quantize=args.quantize,
)
model_path, config, vocab_path = tm.save(tmp_dir) model_path, config, vocab_path = tm.save(tmp_dir)
except TaskTypeError as err: except TaskTypeError as err:
logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}") logger.error(
f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}"
)
exit(1) exit(1)
ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()) ptm = PyTorchModel(
model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200 es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()
)
model_exists = (
es.options(ignore_status=404)
.ml.get_trained_models(model_id=ptm.model_id)
.meta.status
== 200
)
if model_exists: if model_exists:
if args.clear_previous: if args.clear_previous:
@ -239,7 +266,9 @@ def main():
ptm.delete() ptm.delete()
else: else:
logger.error(f"Trained model with id '{ptm.model_id}' already exists") logger.error(f"Trained model with id '{ptm.model_id}' already exists")
logger.info("Run the script with the '--clear-previous' flag if you want to overwrite the existing model.") logger.info(
"Run the script with the '--clear-previous' flag if you want to overwrite the existing model."
)
exit(1) exit(1)
logger.info(f"Creating model with id '{ptm.model_id}'") logger.info(f"Creating model with id '{ptm.model_id}'")