mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Support typed 'elasticsearch-py' and add 'py.typed'
This commit is contained in:
parent
05a24cbe0b
commit
bd7956ea72
@ -1,2 +1,3 @@
|
||||
include LICENSE.txt
|
||||
include README.md
|
||||
recursive-include eland py.typed
|
||||
|
@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np # type: ignore
|
||||
import pandas as pd # type: ignore
|
||||
from elasticsearch import Elasticsearch # type: ignore
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
# Default number of rows displayed (different to pandas where ALL could be displayed)
|
||||
DEFAULT_NUM_ROWS_DISPLAYED = 60
|
||||
@ -86,7 +86,7 @@ class SortOrder(Enum):
|
||||
|
||||
|
||||
def elasticsearch_date_to_pandas_date(
|
||||
value: Union[int, str], date_format: str
|
||||
value: Union[int, str], date_format: Optional[str]
|
||||
) -> pd.Timestamp:
|
||||
"""
|
||||
Given a specific Elasticsearch format for a date datatype, returns the
|
||||
@ -298,6 +298,7 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
|
||||
"""Tags the current ES client with a cached '_eland_es_version'
|
||||
property if one doesn't exist yet for the current Elasticsearch version.
|
||||
"""
|
||||
eland_es_version: Tuple[int, int, int]
|
||||
if not hasattr(es_client, "_eland_es_version"):
|
||||
version_info = es_client.info()["version"]["number"]
|
||||
match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info)
|
||||
@ -306,6 +307,10 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
|
||||
f"Unable to determine Elasticsearch version. "
|
||||
f"Received: {version_info}"
|
||||
)
|
||||
major, minor, patch = [int(x) for x in match.groups()]
|
||||
es_client._eland_es_version = (major, minor, patch)
|
||||
return cast(Tuple[int, int, int], es_client._eland_es_version)
|
||||
eland_es_version = cast(
|
||||
Tuple[int, int, int], tuple([int(x) for x in match.groups()])
|
||||
)
|
||||
es_client._eland_es_version = eland_es_version # type: ignore
|
||||
else:
|
||||
eland_es_version = es_client._eland_es_version # type: ignore
|
||||
return eland_es_version
|
||||
|
@ -20,8 +20,8 @@ from collections import deque
|
||||
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd # type: ignore
|
||||
from elasticsearch import Elasticsearch # type: ignore
|
||||
from elasticsearch.helpers import parallel_bulk # type: ignore
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.helpers import parallel_bulk
|
||||
from pandas.io.parsers import _c_parser_defaults # type: ignore
|
||||
|
||||
from eland import DataFrame
|
||||
@ -240,7 +240,7 @@ def pandas_to_eland(
|
||||
pd_df, es_dropna, use_pandas_index_for_es_ids, es_dest_index
|
||||
),
|
||||
thread_count=thread_count,
|
||||
chunk_size=chunksize / thread_count,
|
||||
chunk_size=int(chunksize / thread_count),
|
||||
),
|
||||
maxlen=0,
|
||||
)
|
||||
|
@ -18,7 +18,7 @@
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import elasticsearch # type: ignore
|
||||
import elasticsearch
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from eland.common import ensure_es_client, es_version
|
||||
@ -447,11 +447,13 @@ class MLModel:
|
||||
# In Elasticsearch 7.7 and earlier you can't get
|
||||
# target type without pulling the model definition
|
||||
# so we check the version first.
|
||||
kwargs = {}
|
||||
if es_version(self._client) < (7, 8):
|
||||
kwargs["include_model_definition"] = True
|
||||
resp = self._client.ml.get_trained_models(
|
||||
model_id=self._model_id, include_model_definition=True
|
||||
)
|
||||
else:
|
||||
resp = self._client.ml.get_trained_models(model_id=self._model_id)
|
||||
|
||||
resp = self._client.ml.get_trained_models(model_id=self._model_id, **kwargs)
|
||||
if resp["count"] > 1:
|
||||
raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous")
|
||||
elif resp["count"] == 0:
|
||||
|
@ -17,13 +17,15 @@
|
||||
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
import pandas as pd # type: ignore
|
||||
|
||||
from eland.query_compiler import QueryCompiler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from eland.index import Index
|
||||
|
||||
"""
|
||||
@ -55,12 +57,14 @@ only Elasticsearch aggregatable fields can be aggregated or grouped.
|
||||
class NDFrame(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
es_client=None,
|
||||
es_index_pattern=None,
|
||||
columns=None,
|
||||
es_index_field=None,
|
||||
_query_compiler=None,
|
||||
):
|
||||
es_client: Optional[
|
||||
Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
|
||||
] = None,
|
||||
es_index_pattern: Optional[str] = None,
|
||||
columns: Optional[List[str]] = None,
|
||||
es_index_field: Optional[str] = None,
|
||||
_query_compiler: Optional[QueryCompiler] = None,
|
||||
) -> None:
|
||||
"""
|
||||
pandas.DataFrame/Series like API that proxies into Elasticsearch index(es).
|
||||
|
||||
@ -134,7 +138,7 @@ class NDFrame(ABC):
|
||||
return self._query_compiler.dtypes
|
||||
|
||||
@property
|
||||
def es_dtypes(self):
|
||||
def es_dtypes(self) -> pd.Series:
|
||||
"""
|
||||
Return the Elasticsearch dtypes in the index
|
||||
|
||||
@ -155,7 +159,7 @@ class NDFrame(ABC):
|
||||
"""
|
||||
return self._query_compiler.es_dtypes
|
||||
|
||||
def _build_repr(self, num_rows) -> pd.DataFrame:
|
||||
def _build_repr(self, num_rows: int) -> pd.DataFrame:
|
||||
# self could be Series or DataFrame
|
||||
if len(self.index) <= num_rows:
|
||||
return self.to_pandas()
|
||||
@ -639,20 +643,25 @@ class NDFrame(ABC):
|
||||
return self._query_compiler.describe()
|
||||
|
||||
@abstractmethod
|
||||
def to_pandas(self, show_progress=False):
|
||||
pass
|
||||
def to_pandas(self, show_progress: bool = False) -> pd.DataFrame:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def head(self, n=5):
|
||||
pass
|
||||
def head(self, n: int = 5) -> "NDFrame":
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def tail(self, n=5):
|
||||
pass
|
||||
def tail(self, n: int = 5) -> "NDFrame":
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def sample(self, n=None, frac=None, random_state=None):
|
||||
pass
|
||||
def sample(
|
||||
self,
|
||||
n: Optional[int] = None,
|
||||
frac: Optional[float] = None,
|
||||
random_state: Optional[int] = None,
|
||||
) -> "NDFrame":
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
|
0
eland/py.typed
Normal file
0
eland/py.typed
Normal file
@ -94,11 +94,11 @@ class QueryCompiler:
|
||||
self._operations = Operations()
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
def index(self) -> Index:
|
||||
return self._index
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
def columns(self) -> pd.Index:
|
||||
columns = self._mappings.display_names
|
||||
|
||||
return pd.Index(columns)
|
||||
@ -120,11 +120,11 @@ class QueryCompiler:
|
||||
return result
|
||||
|
||||
@property
|
||||
def dtypes(self):
|
||||
def dtypes(self) -> pd.Series:
|
||||
return self._mappings.dtypes()
|
||||
|
||||
@property
|
||||
def es_dtypes(self):
|
||||
def es_dtypes(self) -> pd.Series:
|
||||
return self._mappings.es_dtypes()
|
||||
|
||||
# END Index, columns, and dtypes objects
|
||||
|
@ -68,6 +68,7 @@ def format(session):
|
||||
@nox.session(reuse_venv=True)
|
||||
def lint(session):
|
||||
session.install("black", "flake8", "mypy", "isort")
|
||||
session.install("--pre", "elasticsearch")
|
||||
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
|
||||
session.run("black", "--check", "--target-version=py36", *SOURCE_FILES)
|
||||
session.run("isort", "--check", *SOURCE_FILES)
|
||||
|
3
setup.py
3
setup.py
@ -72,6 +72,9 @@ setup(
|
||||
packages=find_packages(include=["eland", "eland.*"]),
|
||||
install_requires=["elasticsearch>=7.7", "pandas>=1", "matplotlib", "numpy"],
|
||||
python_requires=">=3.6",
|
||||
package_data={"eland": ["py.typed"]},
|
||||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
extras_require={
|
||||
"xgboost": ["xgboost>=0.90,<2"],
|
||||
"scikit-learn": ["scikit-learn>=0.22.1,<1"],
|
||||
|
Loading…
x
Reference in New Issue
Block a user