mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Minimize if main section (#554)
For migration from scripts to console_scripts in setup.py, the current long if __name__ == "__main__": section is a blocker because the console_scripts requires to specify a function as an entrypoint. Move the logic into a main() function.
This commit is contained in:
parent
bf3b092ed4
commit
55967a7324
@ -41,6 +41,8 @@ MODEL_HUB_URL = "https://huggingface.co"
|
||||
|
||||
|
||||
def get_arg_parser():
|
||||
from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
location_args = parser.add_mutually_exclusive_group(required=True)
|
||||
location_args.add_argument(
|
||||
@ -127,7 +129,7 @@ def get_arg_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def get_es_client(cli_args):
|
||||
def get_es_client(cli_args, logger):
|
||||
try:
|
||||
es_args = {
|
||||
'request_timeout': 300,
|
||||
@ -159,7 +161,7 @@ def get_es_client(cli_args):
|
||||
exit(1)
|
||||
|
||||
|
||||
def check_cluster_version(es_client):
|
||||
def check_cluster_version(es_client, logger):
|
||||
es_info = es_client.info()
|
||||
logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")
|
||||
|
||||
@ -180,7 +182,8 @@ def check_cluster_version(es_client):
|
||||
|
||||
return sem_ver
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
# Configure logging
|
||||
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -202,14 +205,15 @@ if __name__ == "__main__":
|
||||
\033[1m{sys.executable} -m pip install 'eland[pytorch]'
|
||||
\033[0m"""))
|
||||
exit(1)
|
||||
assert SUPPORTED_TASK_TYPES
|
||||
|
||||
# Parse arguments
|
||||
args = get_arg_parser().parse_args()
|
||||
|
||||
# Connect to ES
|
||||
logger.info("Establishing connection to Elasticsearch")
|
||||
es = get_es_client(args)
|
||||
cluster_version = check_cluster_version(es)
|
||||
es = get_es_client(args, logger)
|
||||
cluster_version = check_cluster_version(es, logger)
|
||||
|
||||
# Trace and save model, then upload it from temp file
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@ -254,3 +258,5 @@ if __name__ == "__main__":
|
||||
logger.info(f"Model successfully imported with id '{ptm.model_id}'")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -494,7 +494,7 @@ class _TransformerTraceableModel(TraceableModel):
|
||||
class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
|
||||
def classification_labels(self) -> Optional[List[str]]:
|
||||
id_label_items = self._model.config.id2label.items()
|
||||
labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])] # type: ignore
|
||||
labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
|
||||
|
||||
# Make classes like I-PER into I_PER which fits Java enumerations
|
||||
return [label.replace("-", "_") for label in labels]
|
||||
@ -636,7 +636,7 @@ class TransformerModel:
|
||||
|
||||
def _load_vocab(self) -> Dict[str, List[str]]:
|
||||
vocab_items = self._tokenizer.get_vocab().items()
|
||||
vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])] # type: ignore
|
||||
vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
|
||||
vocab_obj = {
|
||||
"vocabulary": vocabulary,
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user