mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Apply black to comply with the code style (#557)
Relates https://github.com/elastic/eland/pull/552 **Issue**: ```console C:\Users\YouheiSakurai\git\myeland>python -m black --version python -m black, 23.3.0 (compiled: yes) Python (CPython) 3.11.0 C:\Users\YouheiSakurai\git\myeland>python -m black --check --target-version=py38 bin\eland_import_hub_model would reformat bin\eland_import_hub_model Oh no! 💥 💔 💥 1 file would be reformatted. ``` **Solution**: ``` C:\Users\YouheiSakurai\git\myeland>python -m black --target-version=py38 bin\eland_import_hub_model reformatted bin\eland_import_hub_model All done! ✨ 🍰 ✨ 1 file reformatted. ```
This commit is contained in:
parent
77781b90ff
commit
b5bcba713d
@ -58,41 +58,43 @@ def get_arg_parser():
|
||||
"--hub-model-id",
|
||||
required=True,
|
||||
help="The model ID in the Hugging Face model hub, "
|
||||
"e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
"e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--es-model-id",
|
||||
required=False,
|
||||
default=None,
|
||||
help="The model ID to use in Elasticsearch, "
|
||||
"e.g. bert-large-cased-finetuned-conll03-english."
|
||||
"When left unspecified, this will be auto-created from the `hub-id`",
|
||||
"e.g. bert-large-cased-finetuned-conll03-english."
|
||||
"When left unspecified, this will be auto-created from the `hub-id`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-u", "--es-username",
|
||||
"-u",
|
||||
"--es-username",
|
||||
required=False,
|
||||
default=os.environ.get("ES_USERNAME"),
|
||||
help="Username for Elasticsearch"
|
||||
help="Username for Elasticsearch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--es-password",
|
||||
"-p",
|
||||
"--es-password",
|
||||
required=False,
|
||||
default=os.environ.get("ES_PASSWORD"),
|
||||
help="Password for the Elasticsearch user specified with -u/--username"
|
||||
help="Password for the Elasticsearch user specified with -u/--username",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--es-api-key",
|
||||
required=False,
|
||||
default=os.environ.get("ES_API_KEY"),
|
||||
help="API key for Elasticsearch"
|
||||
help="API key for Elasticsearch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task-type",
|
||||
required=False,
|
||||
choices=SUPPORTED_TASK_TYPES,
|
||||
help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. "
|
||||
"Default: auto",
|
||||
default="auto"
|
||||
"Default: auto",
|
||||
default="auto",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
@ -110,19 +112,16 @@ def get_arg_parser():
|
||||
"--clear-previous",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Should the model previously stored with `es-model-id` be deleted"
|
||||
help="Should the model previously stored with `es-model-id` be deleted",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--insecure",
|
||||
action="store_false",
|
||||
default=True,
|
||||
help="Do not verify SSL certificates"
|
||||
help="Do not verify SSL certificates",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ca-certs",
|
||||
required=False,
|
||||
default=DEFAULT,
|
||||
help="Path to CA bundle"
|
||||
"--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -131,27 +130,29 @@ def get_arg_parser():
|
||||
def get_es_client(cli_args, logger):
|
||||
try:
|
||||
es_args = {
|
||||
'request_timeout': 300,
|
||||
'verify_certs': cli_args.insecure,
|
||||
'ca_certs': cli_args.ca_certs
|
||||
"request_timeout": 300,
|
||||
"verify_certs": cli_args.insecure,
|
||||
"ca_certs": cli_args.ca_certs,
|
||||
}
|
||||
|
||||
# Deployment location
|
||||
if cli_args.url:
|
||||
es_args['hosts'] = cli_args.url
|
||||
es_args["hosts"] = cli_args.url
|
||||
|
||||
if cli_args.cloud_id:
|
||||
es_args['cloud_id'] = cli_args.cloud_id
|
||||
es_args["cloud_id"] = cli_args.cloud_id
|
||||
|
||||
# Authentication
|
||||
if cli_args.es_api_key:
|
||||
es_args['api_key'] = cli_args.es_api_key
|
||||
es_args["api_key"] = cli_args.es_api_key
|
||||
elif cli_args.es_username:
|
||||
if not cli_args.es_password:
|
||||
logging.error(f"Password for user {cli_args.es_username} was not specified.")
|
||||
logging.error(
|
||||
f"Password for user {cli_args.es_username} was not specified."
|
||||
)
|
||||
exit(1)
|
||||
|
||||
es_args['basic_auth'] = (cli_args.es_username, cli_args.es_password)
|
||||
es_args["basic_auth"] = (cli_args.es_username, cli_args.es_password)
|
||||
|
||||
es_client = Elasticsearch(**es_args)
|
||||
return es_client
|
||||
@ -162,15 +163,19 @@ def get_es_client(cli_args, logger):
|
||||
|
||||
def check_cluster_version(es_client, logger):
|
||||
es_info = es_client.info()
|
||||
logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")
|
||||
logger.info(
|
||||
f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})"
|
||||
)
|
||||
|
||||
sem_ver = parse_es_version(es_info['version']['number'])
|
||||
sem_ver = parse_es_version(es_info["version"]["number"])
|
||||
major_version = sem_ver[0]
|
||||
minor_version = sem_ver[1]
|
||||
|
||||
# NLP models added in 8
|
||||
if major_version < 8:
|
||||
logger.error(f"Elasticsearch version {major_version} does not support NLP models. Please upgrade Elasticsearch to the latest version")
|
||||
logger.error(
|
||||
f"Elasticsearch version {major_version} does not support NLP models. Please upgrade Elasticsearch to the latest version"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
# PyTorch was upgraded to version 1.13.1 in 8.7.
|
||||
@ -178,7 +183,9 @@ def check_cluster_version(es_client, logger):
|
||||
if major_version == 8 and minor_version < 7:
|
||||
import torch
|
||||
|
||||
logger.error(f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7")
|
||||
logger.error(
|
||||
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
return sem_ver
|
||||
@ -186,7 +193,7 @@ def check_cluster_version(es_client, logger):
|
||||
|
||||
def main():
|
||||
# Configure logging
|
||||
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
|
||||
logging.basicConfig(format="%(asctime)s %(levelname)s : %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
@ -198,13 +205,17 @@ def main():
|
||||
TransformerModel,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(textwrap.dedent(f"""\
|
||||
logger.error(
|
||||
textwrap.dedent(
|
||||
f"""\
|
||||
\033[31mFailed to run because module '{e.name}' is not available.\033[0m
|
||||
|
||||
This script requires PyTorch extras to run. You can install these by running:
|
||||
|
||||
\033[1m{sys.executable} -m pip install 'eland[pytorch]'
|
||||
\033[0m"""))
|
||||
\033[0m"""
|
||||
)
|
||||
)
|
||||
exit(1)
|
||||
assert SUPPORTED_TASK_TYPES
|
||||
|
||||
@ -218,17 +229,33 @@ def main():
|
||||
|
||||
# Trace and save model, then upload it from temp file
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")
|
||||
logger.info(
|
||||
f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'"
|
||||
)
|
||||
|
||||
try:
|
||||
tm = TransformerModel(model_id=args.hub_model_id, task_type=args.task_type, es_version=cluster_version, quantize=args.quantize)
|
||||
tm = TransformerModel(
|
||||
model_id=args.hub_model_id,
|
||||
task_type=args.task_type,
|
||||
es_version=cluster_version,
|
||||
quantize=args.quantize,
|
||||
)
|
||||
model_path, config, vocab_path = tm.save(tmp_dir)
|
||||
except TaskTypeError as err:
|
||||
logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}")
|
||||
logger.error(
|
||||
f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
|
||||
model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200
|
||||
ptm = PyTorchModel(
|
||||
es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()
|
||||
)
|
||||
model_exists = (
|
||||
es.options(ignore_status=404)
|
||||
.ml.get_trained_models(model_id=ptm.model_id)
|
||||
.meta.status
|
||||
== 200
|
||||
)
|
||||
|
||||
if model_exists:
|
||||
if args.clear_previous:
|
||||
@ -239,7 +266,9 @@ def main():
|
||||
ptm.delete()
|
||||
else:
|
||||
logger.error(f"Trained model with id '{ptm.model_id}' already exists")
|
||||
logger.info("Run the script with the '--clear-previous' flag if you want to overwrite the existing model.")
|
||||
logger.info(
|
||||
"Run the script with the '--clear-previous' flag if you want to overwrite the existing model."
|
||||
)
|
||||
exit(1)
|
||||
|
||||
logger.info(f"Creating model with id '{ptm.model_id}'")
|
||||
|
Loading…
x
Reference in New Issue
Block a user