Minimize if main section (#554)

For migration from scripts to console_scripts in setup.py,
the current long if __name__ == "__main__": section is a 
blocker because the console_scripts requires to specify a
function as an entrypoint.
Move the logic into a main() function.
This commit is contained in:
Youhei Sakurai 2023-07-05 18:49:16 +09:00 committed by GitHub
parent bf3b092ed4
commit 55967a7324
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 7 deletions

View File

@ -41,6 +41,8 @@ MODEL_HUB_URL = "https://huggingface.co"
def get_arg_parser():
from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES
parser = argparse.ArgumentParser()
location_args = parser.add_mutually_exclusive_group(required=True)
location_args.add_argument(
@ -127,7 +129,7 @@ def get_arg_parser():
return parser
def get_es_client(cli_args):
def get_es_client(cli_args, logger):
try:
es_args = {
'request_timeout': 300,
@ -159,7 +161,7 @@ def get_es_client(cli_args):
exit(1)
def check_cluster_version(es_client):
def check_cluster_version(es_client, logger):
es_info = es_client.info()
logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")
@ -180,7 +182,8 @@ def check_cluster_version(es_client):
return sem_ver
if __name__ == "__main__":
def main():
# Configure logging
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
logger = logging.getLogger(__name__)
@ -202,14 +205,15 @@ if __name__ == "__main__":
\033[1m{sys.executable} -m pip install 'eland[pytorch]'
\033[0m"""))
exit(1)
assert SUPPORTED_TASK_TYPES
# Parse arguments
args = get_arg_parser().parse_args()
# Connect to ES
logger.info("Establishing connection to Elasticsearch")
es = get_es_client(args)
cluster_version = check_cluster_version(es)
es = get_es_client(args, logger)
cluster_version = check_cluster_version(es, logger)
# Trace and save model, then upload it from temp file
with tempfile.TemporaryDirectory() as tmp_dir:
@ -254,3 +258,5 @@ if __name__ == "__main__":
logger.info(f"Model successfully imported with id '{ptm.model_id}'")
if __name__ == "__main__":
main()

View File

@ -494,7 +494,7 @@ class _TransformerTraceableModel(TraceableModel):
class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
def classification_labels(self) -> Optional[List[str]]:
id_label_items = self._model.config.id2label.items()
labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])] # type: ignore
labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
# Make classes like I-PER into I_PER which fits Java enumerations
return [label.replace("-", "_") for label in labels]
@ -636,7 +636,7 @@ class TransformerModel:
def _load_vocab(self) -> Dict[str, List[str]]:
vocab_items = self._tokenizer.get_vocab().items()
vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])] # type: ignore
vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
vocab_obj = {
"vocabulary": vocabulary,
}