Support typed 'elasticsearch-py' and add 'py.typed'

This commit is contained in:
Seth Michael Larson 2020-10-20 16:26:58 -05:00 committed by GitHub
parent 05a24cbe0b
commit bd7956ea72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 55 additions and 34 deletions

View File

@ -1,2 +1,3 @@
include LICENSE.txt include LICENSE.txt
include README.md include README.md
recursive-include eland py.typed

View File

@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import numpy as np # type: ignore import numpy as np # type: ignore
import pandas as pd # type: ignore import pandas as pd # type: ignore
from elasticsearch import Elasticsearch # type: ignore from elasticsearch import Elasticsearch
# Default number of rows displayed (different to pandas where ALL could be displayed) # Default number of rows displayed (different to pandas where ALL could be displayed)
DEFAULT_NUM_ROWS_DISPLAYED = 60 DEFAULT_NUM_ROWS_DISPLAYED = 60
@ -86,7 +86,7 @@ class SortOrder(Enum):
def elasticsearch_date_to_pandas_date( def elasticsearch_date_to_pandas_date(
value: Union[int, str], date_format: str value: Union[int, str], date_format: Optional[str]
) -> pd.Timestamp: ) -> pd.Timestamp:
""" """
Given a specific Elasticsearch format for a date datatype, returns the Given a specific Elasticsearch format for a date datatype, returns the
@ -298,6 +298,7 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
"""Tags the current ES client with a cached '_eland_es_version' """Tags the current ES client with a cached '_eland_es_version'
property if one doesn't exist yet for the current Elasticsearch version. property if one doesn't exist yet for the current Elasticsearch version.
""" """
eland_es_version: Tuple[int, int, int]
if not hasattr(es_client, "_eland_es_version"): if not hasattr(es_client, "_eland_es_version"):
version_info = es_client.info()["version"]["number"] version_info = es_client.info()["version"]["number"]
match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info) match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info)
@ -306,6 +307,10 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
f"Unable to determine Elasticsearch version. " f"Unable to determine Elasticsearch version. "
f"Received: {version_info}" f"Received: {version_info}"
) )
major, minor, patch = [int(x) for x in match.groups()] eland_es_version = cast(
es_client._eland_es_version = (major, minor, patch) Tuple[int, int, int], tuple([int(x) for x in match.groups()])
return cast(Tuple[int, int, int], es_client._eland_es_version) )
es_client._eland_es_version = eland_es_version # type: ignore
else:
eland_es_version = es_client._eland_es_version # type: ignore
return eland_es_version

View File

@ -20,8 +20,8 @@ from collections import deque
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
import pandas as pd # type: ignore import pandas as pd # type: ignore
from elasticsearch import Elasticsearch # type: ignore from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk # type: ignore from elasticsearch.helpers import parallel_bulk
from pandas.io.parsers import _c_parser_defaults # type: ignore from pandas.io.parsers import _c_parser_defaults # type: ignore
from eland import DataFrame from eland import DataFrame
@ -240,7 +240,7 @@ def pandas_to_eland(
pd_df, es_dropna, use_pandas_index_for_es_ids, es_dest_index pd_df, es_dropna, use_pandas_index_for_es_ids, es_dest_index
), ),
thread_count=thread_count, thread_count=thread_count,
chunk_size=chunksize / thread_count, chunk_size=int(chunksize / thread_count),
), ),
maxlen=0, maxlen=0,
) )

View File

@ -18,7 +18,7 @@
import warnings import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import elasticsearch # type: ignore import elasticsearch
import numpy as np # type: ignore import numpy as np # type: ignore
from eland.common import ensure_es_client, es_version from eland.common import ensure_es_client, es_version
@ -447,11 +447,13 @@ class MLModel:
# In Elasticsearch 7.7 and earlier you can't get # In Elasticsearch 7.7 and earlier you can't get
# target type without pulling the model definition # target type without pulling the model definition
# so we check the version first. # so we check the version first.
kwargs = {}
if es_version(self._client) < (7, 8): if es_version(self._client) < (7, 8):
kwargs["include_model_definition"] = True resp = self._client.ml.get_trained_models(
model_id=self._model_id, include_model_definition=True
)
else:
resp = self._client.ml.get_trained_models(model_id=self._model_id)
resp = self._client.ml.get_trained_models(model_id=self._model_id, **kwargs)
if resp["count"] > 1: if resp["count"] > 1:
raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous") raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous")
elif resp["count"] == 0: elif resp["count"] == 0:

View File

@ -17,13 +17,15 @@
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import pandas as pd import pandas as pd # type: ignore
from eland.query_compiler import QueryCompiler from eland.query_compiler import QueryCompiler
if TYPE_CHECKING: if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from eland.index import Index from eland.index import Index
""" """
@ -55,12 +57,14 @@ only Elasticsearch aggregatable fields can be aggregated or grouped.
class NDFrame(ABC): class NDFrame(ABC):
def __init__( def __init__(
self, self,
es_client=None, es_client: Optional[
es_index_pattern=None, Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
columns=None, ] = None,
es_index_field=None, es_index_pattern: Optional[str] = None,
_query_compiler=None, columns: Optional[List[str]] = None,
): es_index_field: Optional[str] = None,
_query_compiler: Optional[QueryCompiler] = None,
) -> None:
""" """
pandas.DataFrame/Series like API that proxies into Elasticsearch index(es). pandas.DataFrame/Series like API that proxies into Elasticsearch index(es).
@ -134,7 +138,7 @@ class NDFrame(ABC):
return self._query_compiler.dtypes return self._query_compiler.dtypes
@property @property
def es_dtypes(self): def es_dtypes(self) -> pd.Series:
""" """
Return the Elasticsearch dtypes in the index Return the Elasticsearch dtypes in the index
@ -155,7 +159,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.es_dtypes return self._query_compiler.es_dtypes
def _build_repr(self, num_rows) -> pd.DataFrame: def _build_repr(self, num_rows: int) -> pd.DataFrame:
# self could be Series or DataFrame # self could be Series or DataFrame
if len(self.index) <= num_rows: if len(self.index) <= num_rows:
return self.to_pandas() return self.to_pandas()
@ -639,20 +643,25 @@ class NDFrame(ABC):
return self._query_compiler.describe() return self._query_compiler.describe()
@abstractmethod @abstractmethod
def to_pandas(self, show_progress=False): def to_pandas(self, show_progress: bool = False) -> pd.DataFrame:
pass raise NotImplementedError
@abstractmethod @abstractmethod
def head(self, n=5): def head(self, n: int = 5) -> "NDFrame":
pass raise NotImplementedError
@abstractmethod @abstractmethod
def tail(self, n=5): def tail(self, n: int = 5) -> "NDFrame":
pass raise NotImplementedError
@abstractmethod @abstractmethod
def sample(self, n=None, frac=None, random_state=None): def sample(
pass self,
n: Optional[int] = None,
frac: Optional[float] = None,
random_state: Optional[int] = None,
) -> "NDFrame":
raise NotImplementedError
@property @property
def shape(self) -> Tuple[int, ...]: def shape(self) -> Tuple[int, ...]:

0
eland/py.typed Normal file
View File

View File

@ -94,11 +94,11 @@ class QueryCompiler:
self._operations = Operations() self._operations = Operations()
@property @property
def index(self): def index(self) -> Index:
return self._index return self._index
@property @property
def columns(self): def columns(self) -> pd.Index:
columns = self._mappings.display_names columns = self._mappings.display_names
return pd.Index(columns) return pd.Index(columns)
@ -120,11 +120,11 @@ class QueryCompiler:
return result return result
@property @property
def dtypes(self): def dtypes(self) -> pd.Series:
return self._mappings.dtypes() return self._mappings.dtypes()
@property @property
def es_dtypes(self): def es_dtypes(self) -> pd.Series:
return self._mappings.es_dtypes() return self._mappings.es_dtypes()
# END Index, columns, and dtypes objects # END Index, columns, and dtypes objects

View File

@ -68,6 +68,7 @@ def format(session):
@nox.session(reuse_venv=True) @nox.session(reuse_venv=True)
def lint(session): def lint(session):
session.install("black", "flake8", "mypy", "isort") session.install("black", "flake8", "mypy", "isort")
session.install("--pre", "elasticsearch")
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py36", *SOURCE_FILES) session.run("black", "--check", "--target-version=py36", *SOURCE_FILES)
session.run("isort", "--check", *SOURCE_FILES) session.run("isort", "--check", *SOURCE_FILES)

View File

@ -72,6 +72,9 @@ setup(
packages=find_packages(include=["eland", "eland.*"]), packages=find_packages(include=["eland", "eland.*"]),
install_requires=["elasticsearch>=7.7", "pandas>=1", "matplotlib", "numpy"], install_requires=["elasticsearch>=7.7", "pandas>=1", "matplotlib", "numpy"],
python_requires=">=3.6", python_requires=">=3.6",
package_data={"eland": ["py.typed"]},
include_package_data=True,
zip_safe=False,
extras_require={ extras_require={
"xgboost": ["xgboost>=0.90,<2"], "xgboost": ["xgboost>=0.90,<2"],
"scikit-learn": ["scikit-learn>=0.22.1,<1"], "scikit-learn": ["scikit-learn>=0.22.1,<1"],