[ML] Make eland_import_hub_model an installable script

This commit is contained in:
Benjamin Trent 2021-10-19 12:29:58 -04:00 committed by GitHub
parent 704c8982bc
commit d39c1cd784
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 6 deletions

View File

@@ -1,3 +1,5 @@
#!/usr/bin/env python
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
@@ -35,7 +37,7 @@ MODEL_HUB_URL = "https://huggingface.co"
def main():
parser = argparse.ArgumentParser(prog="upload_hub_model.py")
parser = argparse.ArgumentParser(prog="upload_hub_model")
parser.add_argument(
"--url",
required=True,
@@ -59,7 +61,7 @@ def main():
"--task-type",
required=True,
choices=SUPPORTED_TASK_TYPES,
help="The task type that the model will be used for.",
help="The task type for the model usage.",
)
parser.add_argument(
"--quantize",
@@ -73,13 +75,19 @@ def main():
default=False,
help="Start the model deployment after uploading. Default: False",
)
parser.add_argument(
"--clear-previous",
action="store_true",
default=False,
help="Should the model previously stored with `elasticsearch-model-id` be deleted"
)
args = parser.parse_args()
es = elasticsearch.Elasticsearch(args.url, timeout=300) # 5 minute timeout
# trace and save model, then upload it from temp file
with tempfile.TemporaryDirectory() as tmp_dir:
print("Loading HuggingFace transformer tokenizer and model")
print(f"Loading HuggingFace transformer tokenizer and model {args.hub_model_id}")
tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
model_path, config_path, vocab_path = tm.save(tmp_dir)
@@ -90,14 +98,16 @@ def main():
)
ptm = PyTorchModel(es, es_model_id)
ptm.stop()
ptm.delete()
if args.clear_previous:
print(f"Stopping previous deployment and deleting model: {ptm.model_id}")
ptm.stop()
ptm.delete()
print(f"Importing model: {ptm.model_id}")
ptm.import_model(model_path, config_path, vocab_path)
# start the deployed model
if args.start:
print("Starting model deployment")
print(f"Starting model deployment: {ptm.model_id}")
ptm.start()

View File

@@ -76,6 +76,7 @@ setup(
"matplotlib",
"numpy",
],
scripts=["bin/eland_import_hub_model"],
python_requires=">=3.7",
package_data={"eland": ["py.typed"]},
include_package_data=True,