Support typed 'elasticsearch-py' and add 'py.typed'

This commit is contained in:
Seth Michael Larson 2020-10-20 16:26:58 -05:00 committed by GitHub
parent 05a24cbe0b
commit bd7956ea72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 55 additions and 34 deletions

View File

@ -1,2 +1,3 @@
include LICENSE.txt
include README.md
recursive-include eland py.typed

View File

@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import numpy as np # type: ignore
import pandas as pd # type: ignore
from elasticsearch import Elasticsearch # type: ignore
from elasticsearch import Elasticsearch
# Default number of rows displayed (different to pandas where ALL could be displayed)
DEFAULT_NUM_ROWS_DISPLAYED = 60
@ -86,7 +86,7 @@ class SortOrder(Enum):
def elasticsearch_date_to_pandas_date(
value: Union[int, str], date_format: str
value: Union[int, str], date_format: Optional[str]
) -> pd.Timestamp:
"""
Given a specific Elasticsearch format for a date datatype, returns the
@ -298,6 +298,7 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
"""Tags the current ES client with a cached '_eland_es_version'
property if one doesn't exist yet for the current Elasticsearch version.
"""
eland_es_version: Tuple[int, int, int]
if not hasattr(es_client, "_eland_es_version"):
version_info = es_client.info()["version"]["number"]
match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info)
@ -306,6 +307,10 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
f"Unable to determine Elasticsearch version. "
f"Received: {version_info}"
)
major, minor, patch = [int(x) for x in match.groups()]
es_client._eland_es_version = (major, minor, patch)
return cast(Tuple[int, int, int], es_client._eland_es_version)
eland_es_version = cast(
Tuple[int, int, int], tuple([int(x) for x in match.groups()])
)
es_client._eland_es_version = eland_es_version # type: ignore
else:
eland_es_version = es_client._eland_es_version # type: ignore
return eland_es_version

View File

@ -20,8 +20,8 @@ from collections import deque
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
import pandas as pd # type: ignore
from elasticsearch import Elasticsearch # type: ignore
from elasticsearch.helpers import parallel_bulk # type: ignore
from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk
from pandas.io.parsers import _c_parser_defaults # type: ignore
from eland import DataFrame
@ -240,7 +240,7 @@ def pandas_to_eland(
pd_df, es_dropna, use_pandas_index_for_es_ids, es_dest_index
),
thread_count=thread_count,
chunk_size=chunksize / thread_count,
chunk_size=int(chunksize / thread_count),
),
maxlen=0,
)

View File

@ -18,7 +18,7 @@
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import elasticsearch # type: ignore
import elasticsearch
import numpy as np # type: ignore
from eland.common import ensure_es_client, es_version
@ -447,11 +447,13 @@ class MLModel:
# In Elasticsearch 7.7 and earlier you can't get
# target type without pulling the model definition
# so we check the version first.
kwargs = {}
if es_version(self._client) < (7, 8):
kwargs["include_model_definition"] = True
resp = self._client.ml.get_trained_models(
model_id=self._model_id, include_model_definition=True
)
else:
resp = self._client.ml.get_trained_models(model_id=self._model_id)
resp = self._client.ml.get_trained_models(model_id=self._model_id, **kwargs)
if resp["count"] > 1:
raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous")
elif resp["count"] == 0:

View File

@ -17,13 +17,15 @@
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import pandas as pd
import pandas as pd # type: ignore
from eland.query_compiler import QueryCompiler
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from eland.index import Index
"""
@ -55,12 +57,14 @@ only Elasticsearch aggregatable fields can be aggregated or grouped.
class NDFrame(ABC):
def __init__(
self,
es_client=None,
es_index_pattern=None,
columns=None,
es_index_field=None,
_query_compiler=None,
):
es_client: Optional[
Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
] = None,
es_index_pattern: Optional[str] = None,
columns: Optional[List[str]] = None,
es_index_field: Optional[str] = None,
_query_compiler: Optional[QueryCompiler] = None,
) -> None:
"""
pandas.DataFrame/Series like API that proxies into Elasticsearch index(es).
@ -134,7 +138,7 @@ class NDFrame(ABC):
return self._query_compiler.dtypes
@property
def es_dtypes(self):
def es_dtypes(self) -> pd.Series:
"""
Return the Elasticsearch dtypes in the index
@ -155,7 +159,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.es_dtypes
def _build_repr(self, num_rows) -> pd.DataFrame:
def _build_repr(self, num_rows: int) -> pd.DataFrame:
# self could be Series or DataFrame
if len(self.index) <= num_rows:
return self.to_pandas()
@ -639,20 +643,25 @@ class NDFrame(ABC):
return self._query_compiler.describe()
@abstractmethod
def to_pandas(self, show_progress=False):
pass
def to_pandas(self, show_progress: bool = False) -> pd.DataFrame:
raise NotImplementedError
@abstractmethod
def head(self, n=5):
pass
def head(self, n: int = 5) -> "NDFrame":
raise NotImplementedError
@abstractmethod
def tail(self, n=5):
pass
def tail(self, n: int = 5) -> "NDFrame":
raise NotImplementedError
@abstractmethod
def sample(self, n=None, frac=None, random_state=None):
pass
def sample(
self,
n: Optional[int] = None,
frac: Optional[float] = None,
random_state: Optional[int] = None,
) -> "NDFrame":
raise NotImplementedError
@property
def shape(self) -> Tuple[int, ...]:

0
eland/py.typed Normal file
View File

View File

@ -94,11 +94,11 @@ class QueryCompiler:
self._operations = Operations()
@property
def index(self):
def index(self) -> Index:
return self._index
@property
def columns(self):
def columns(self) -> pd.Index:
columns = self._mappings.display_names
return pd.Index(columns)
@ -120,11 +120,11 @@ class QueryCompiler:
return result
@property
def dtypes(self):
def dtypes(self) -> pd.Series:
return self._mappings.dtypes()
@property
def es_dtypes(self):
def es_dtypes(self) -> pd.Series:
return self._mappings.es_dtypes()
# END Index, columns, and dtypes objects

View File

@ -68,6 +68,7 @@ def format(session):
@nox.session(reuse_venv=True)
def lint(session):
session.install("black", "flake8", "mypy", "isort")
session.install("--pre", "elasticsearch")
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py36", *SOURCE_FILES)
session.run("isort", "--check", *SOURCE_FILES)

View File

@ -72,6 +72,9 @@ setup(
packages=find_packages(include=["eland", "eland.*"]),
install_requires=["elasticsearch>=7.7", "pandas>=1", "matplotlib", "numpy"],
python_requires=">=3.6",
package_data={"eland": ["py.typed"]},
include_package_data=True,
zip_safe=False,
extras_require={
"xgboost": ["xgboost>=0.90,<2"],
"scikit-learn": ["scikit-learn>=0.22.1,<1"],