Mirror of https://github.com/elastic/eland.git
Apply black to comply with the code style (#557)
Relates https://github.com/elastic/eland/pull/552

**Issue**:

```console
C:\Users\YouheiSakurai\git\myeland>python -m black --version
python -m black, 23.3.0 (compiled: yes)
Python (CPython) 3.11.0

C:\Users\YouheiSakurai\git\myeland>python -m black --check --target-version=py38 bin\eland_import_hub_model
would reformat bin\eland_import_hub_model

Oh no! 💥 💔 💥
1 file would be reformatted.
```

**Solution**:

```console
C:\Users\YouheiSakurai\git\myeland>python -m black --target-version=py38 bin\eland_import_hub_model
reformatted bin\eland_import_hub_model

All done! ✨ 🍰 ✨
1 file reformatted.
```
Parent: 77781b90ff
Commit: b5bcba713d
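For reference, the same check can be scripted, e.g. as a small CI helper. This is only a sketch, assuming black is installed in the active environment; the flags and file path are the ones from the commit message above.

```python
# Sketch only: run the black check from the commit message via Python,
# e.g. as a CI helper. Assumes black is installed in the current environment.
import subprocess
import sys

result = subprocess.run(
    [
        sys.executable,
        "-m",
        "black",
        "--check",
        "--target-version=py38",
        "bin/eland_import_hub_model",
    ],
    capture_output=True,
    text=True,
)
if result.returncode != 0:
    # black exits non-zero when the file would be reformatted
    print(result.stdout + result.stderr)
    raise SystemExit(result.returncode)
print("black check passed")
```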
Diff of `bin/eland_import_hub_model`:

```diff
@@ -58,41 +58,43 @@ def get_arg_parser():
         "--hub-model-id",
         required=True,
         help="The model ID in the Hugging Face model hub, "
         "e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
     )
     parser.add_argument(
         "--es-model-id",
         required=False,
         default=None,
         help="The model ID to use in Elasticsearch, "
         "e.g. bert-large-cased-finetuned-conll03-english."
         "When left unspecified, this will be auto-created from the `hub-id`",
     )
     parser.add_argument(
-        "-u", "--es-username",
+        "-u",
+        "--es-username",
         required=False,
         default=os.environ.get("ES_USERNAME"),
-        help="Username for Elasticsearch"
+        help="Username for Elasticsearch",
     )
     parser.add_argument(
-        "-p", "--es-password",
+        "-p",
+        "--es-password",
         required=False,
         default=os.environ.get("ES_PASSWORD"),
-        help="Password for the Elasticsearch user specified with -u/--username"
+        help="Password for the Elasticsearch user specified with -u/--username",
     )
     parser.add_argument(
         "--es-api-key",
         required=False,
         default=os.environ.get("ES_API_KEY"),
-        help="API key for Elasticsearch"
+        help="API key for Elasticsearch",
     )
     parser.add_argument(
         "--task-type",
         required=False,
         choices=SUPPORTED_TASK_TYPES,
         help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. "
         "Default: auto",
-        default="auto"
+        default="auto",
     )
     parser.add_argument(
         "--quantize",
@@ -110,19 +112,16 @@ def get_arg_parser():
         "--clear-previous",
         action="store_true",
         default=False,
-        help="Should the model previously stored with `es-model-id` be deleted"
+        help="Should the model previously stored with `es-model-id` be deleted",
     )
     parser.add_argument(
         "--insecure",
         action="store_false",
         default=True,
-        help="Do not verify SSL certificates"
+        help="Do not verify SSL certificates",
     )
     parser.add_argument(
-        "--ca-certs",
-        required=False,
-        default=DEFAULT,
-        help="Path to CA bundle"
+        "--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
     )
 
     return parser
```
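The `-u/--es-username` and `-p/--es-password` options split by black above default to environment variables. A cut-down, standalone sketch of that pattern (illustration only, not the eland parser; just two of the options are reproduced):

```python
# Standalone sketch of the argument pattern used above: a required model id,
# plus optional credentials that fall back to environment variables.
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="cut-down illustration only")
    parser.add_argument(
        "--hub-model-id",
        required=True,
        help="The model ID in the Hugging Face model hub",
    )
    parser.add_argument(
        "-u",
        "--es-username",
        required=False,
        default=os.environ.get("ES_USERNAME"),
        help="Username for Elasticsearch",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args.hub_model_id, args.es_username)
```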
```diff
@@ -131,27 +130,29 @@ def get_arg_parser():
 def get_es_client(cli_args, logger):
     try:
         es_args = {
-            'request_timeout': 300,
-            'verify_certs': cli_args.insecure,
-            'ca_certs': cli_args.ca_certs
+            "request_timeout": 300,
+            "verify_certs": cli_args.insecure,
+            "ca_certs": cli_args.ca_certs,
         }
 
         # Deployment location
         if cli_args.url:
-            es_args['hosts'] = cli_args.url
+            es_args["hosts"] = cli_args.url
 
         if cli_args.cloud_id:
-            es_args['cloud_id'] = cli_args.cloud_id
+            es_args["cloud_id"] = cli_args.cloud_id
 
         # Authentication
         if cli_args.es_api_key:
-            es_args['api_key'] = cli_args.es_api_key
+            es_args["api_key"] = cli_args.es_api_key
         elif cli_args.es_username:
             if not cli_args.es_password:
-                logging.error(f"Password for user {cli_args.es_username} was not specified.")
+                logging.error(
+                    f"Password for user {cli_args.es_username} was not specified."
+                )
                 exit(1)
 
-            es_args['basic_auth'] = (cli_args.es_username, cli_args.es_password)
+            es_args["basic_auth"] = (cli_args.es_username, cli_args.es_password)
 
         es_client = Elasticsearch(**es_args)
         return es_client
```
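`get_es_client` collects these keyword arguments and passes them straight to the `Elasticsearch` constructor. A minimal standalone sketch of the same connection pattern, assuming elasticsearch-py 8.x; the URL below is a placeholder, not a real deployment:

```python
# Minimal sketch of the connection pattern in get_es_client above,
# assuming elasticsearch-py 8.x. The URL is a placeholder.
import os

from elasticsearch import Elasticsearch

es_args = {
    "request_timeout": 300,
    "verify_certs": True,
    "hosts": "https://localhost:9200",  # placeholder URL
}

if os.environ.get("ES_API_KEY"):
    es_args["api_key"] = os.environ["ES_API_KEY"]
elif os.environ.get("ES_USERNAME"):
    # Mirrors the elif branch above: basic auth when no API key is given.
    # (The real script exits if the password is missing; here we just default.)
    es_args["basic_auth"] = (os.environ["ES_USERNAME"], os.environ.get("ES_PASSWORD", ""))

es_client = Elasticsearch(**es_args)
print(es_client.info()["cluster_name"])
```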
```diff
@@ -162,15 +163,19 @@ def get_es_client(cli_args, logger):
 
 def check_cluster_version(es_client, logger):
     es_info = es_client.info()
-    logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")
+    logger.info(
+        f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})"
+    )
 
-    sem_ver = parse_es_version(es_info['version']['number'])
+    sem_ver = parse_es_version(es_info["version"]["number"])
     major_version = sem_ver[0]
     minor_version = sem_ver[1]
 
     # NLP models added in 8
     if major_version < 8:
-        logger.error(f"Elasticsearch version {major_version} does not support NLP models. Please upgrade Elasticsearch to the latest version")
+        logger.error(
+            f"Elasticsearch version {major_version} does not support NLP models. Please upgrade Elasticsearch to the latest version"
+        )
         exit(1)
 
     # PyTorch was upgraded to version 1.13.1 in 8.7.
@@ -178,7 +183,9 @@ def check_cluster_version(es_client, logger):
     if major_version == 8 and minor_version < 7:
         import torch
 
-        logger.error(f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7")
+        logger.error(
+            f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7"
+        )
         exit(1)
 
     return sem_ver
```
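`parse_es_version` is an eland helper that returns a semantic-version tuple; the gate itself only needs the major and minor components. A rough standalone stand-in, with a hard-coded example version string:

```python
# Rough stand-in for the version gate above. parse_es_version is eland's own
# helper; here the dotted version string is split manually for illustration.
def version_tuple(version: str) -> tuple:
    # Only the first two components are needed for the checks below.
    return tuple(int(part) for part in version.split(".")[:2])


major_version, minor_version = version_tuple("8.6.2")  # example version string

if major_version < 8:
    raise SystemExit("NLP models require Elasticsearch 8.x")
if major_version == 8 and minor_version < 7:
    raise SystemExit("This eland's PyTorch build requires Elasticsearch >= 8.7")
```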
```diff
@@ -186,7 +193,7 @@ def check_cluster_version(es_client, logger):
 
 def main():
     # Configure logging
-    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
+    logging.basicConfig(format="%(asctime)s %(levelname)s : %(message)s")
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)
 
@@ -198,13 +205,17 @@ def main():
             TransformerModel,
         )
     except ModuleNotFoundError as e:
-        logger.error(textwrap.dedent(f"""\
+        logger.error(
+            textwrap.dedent(
+                f"""\
             \033[31mFailed to run because module '{e.name}' is not available.\033[0m
 
             This script requires PyTorch extras to run. You can install these by running:
 
             \033[1m{sys.executable} -m pip install 'eland[pytorch]'
-            \033[0m"""))
+            \033[0m"""
+            )
+        )
         exit(1)
     assert SUPPORTED_TASK_TYPES
 
```
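The `try/except ModuleNotFoundError` block reformatted above exists because the PyTorch extras are optional at install time. A generic sketch of the same lazy-import-with-hint pattern, using `torch` as a stand-in for the heavy import (the ANSI colour codes from the original are omitted):

```python
# Generic sketch of the lazy-import pattern above: import the optional heavy
# dependency late and print an actionable install hint if it is missing.
import sys
import textwrap

try:
    import torch  # stand-in for the optional PyTorch extras
except ModuleNotFoundError as e:
    print(
        textwrap.dedent(
            f"""\
            Failed to run because module '{e.name}' is not available.

            This script requires PyTorch extras to run. You can install these by running:

            {sys.executable} -m pip install 'eland[pytorch]'
            """
        )
    )
    raise SystemExit(1)
```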
```diff
@@ -218,17 +229,33 @@ def main():
 
     # Trace and save model, then upload it from temp file
     with tempfile.TemporaryDirectory() as tmp_dir:
-        logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")
+        logger.info(
+            f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'"
+        )
 
         try:
-            tm = TransformerModel(model_id=args.hub_model_id, task_type=args.task_type, es_version=cluster_version, quantize=args.quantize)
+            tm = TransformerModel(
+                model_id=args.hub_model_id,
+                task_type=args.task_type,
+                es_version=cluster_version,
+                quantize=args.quantize,
+            )
             model_path, config, vocab_path = tm.save(tmp_dir)
         except TaskTypeError as err:
-            logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}")
+            logger.error(
+                f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}"
+            )
             exit(1)
 
-        ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
-        model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200
+        ptm = PyTorchModel(
+            es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()
+        )
+        model_exists = (
+            es.options(ignore_status=404)
+            .ml.get_trained_models(model_id=ptm.model_id)
+            .meta.status
+            == 200
+        )
 
         if model_exists:
             if args.clear_previous:
```
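The reformatted existence check relies on elasticsearch-py 8.x response metadata: `options(ignore_status=404)` stops a missing model from raising, so the HTTP status on `.meta` can be compared with 200 directly. A short sketch with placeholder connection details and an example model id taken from the help text above:

```python
# Sketch of the existence check above (elasticsearch-py 8.x). ignore_status=404
# suppresses the not-found error, so .meta.status is 200 only if the model exists.
from elasticsearch import Elasticsearch

es_client = Elasticsearch("https://localhost:9200")  # placeholder URL
model_id = "bert-large-cased-finetuned-conll03-english"  # example id

resp = es_client.options(ignore_status=404).ml.get_trained_models(model_id=model_id)
model_exists = resp.meta.status == 200
print(model_exists)
```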
```diff
@@ -239,7 +266,9 @@ def main():
                 ptm.delete()
             else:
                 logger.error(f"Trained model with id '{ptm.model_id}' already exists")
-                logger.info("Run the script with the '--clear-previous' flag if you want to overwrite the existing model.")
+                logger.info(
+                    "Run the script with the '--clear-previous' flag if you want to overwrite the existing model."
+                )
                 exit(1)
 
         logger.info(f"Creating model with id '{ptm.model_id}'")
```