diff --git a/eland/ml/ml_model.py b/eland/ml/ml_model.py index e4f8d9e..42aa424 100644 --- a/eland/ml/ml_model.py +++ b/eland/ml/ml_model.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast import elasticsearch import numpy as np @@ -129,7 +129,7 @@ class MLModel: >>> # Delete model from Elasticsearch >>> es_model.delete_model() """ - docs = [] + docs: List[Mapping[str, Any]] = [] if isinstance(X, np.ndarray): def to_list_or_float(x: Any) -> Union[List[Any], float]: diff --git a/eland/ml/pytorch/_pytorch_model.py b/eland/ml/pytorch/_pytorch_model.py index 77d6eb7..4cf2a42 100644 --- a/eland/ml/pytorch/_pytorch_model.py +++ b/eland/ml/pytorch/_pytorch_model.py @@ -19,7 +19,7 @@ import base64 import json import math import os -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Set, Tuple, Union from tqdm.auto import tqdm # type: ignore @@ -96,7 +96,7 @@ class PyTorchModel: def infer( self, - docs: List[Dict[str, str]], + docs: List[Mapping[str, str]], timeout: str = DEFAULT_TIMEOUT, ) -> Any: return self._client.options(