Support the v8.0 Elasticsearch client

This commit is contained in:
Seth Michael Larson 2021-12-09 15:01:26 -06:00 committed by GitHub
parent 1ffbe002c4
commit 109387184a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 2599 additions and 464 deletions

View File

@ -11,7 +11,7 @@ without overloading your machine.
------------------------------------- -------------------------------------
>>> import eland as ed >>> import eland as ed
>>> # Connect to 'flights' index via localhost Elasticsearch node >>> # Connect to 'flights' index via localhost Elasticsearch node
>>> df = ed.DataFrame('localhost:9200', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
# eland.DataFrame instance has the same API as pandas.DataFrame # eland.DataFrame instance has the same API as pandas.DataFrame
# except all data is in Elasticsearch. See .info() memory usage. # except all data is in Elasticsearch. See .info() memory usage.

View File

@ -19,7 +19,7 @@ model in Elasticsearch
# Import the model into Elasticsearch # Import the model into Elasticsearch
>>> es_model = MLModel.import_model( >>> es_model = MLModel.import_model(
es_client="localhost:9200", es_client="http://localhost:9200",
model_id="xgb-classifier", model_id="xgb-classifier",
model=xgb_model, model=xgb_model,
feature_names=["f0", "f1", "f2", "f3", "f4"], feature_names=["f0", "f1", "f2", "f3", "f4"],

View File

@ -20,13 +20,13 @@ The recommended way to set your requirements in your `setup.py` or
[discrete] [discrete]
=== Getting Started === Getting Started
Create a `DataFrame` object connected to an {es} cluster running on `localhost:9200`: Create a `DataFrame` object connected to an {es} cluster running on `http://localhost:9200`:
[source,python] [source,python]
------------------------------------ ------------------------------------
>>> import eland as ed >>> import eland as ed
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... es_client="localhost:9200", ... es_client="http://localhost:9200",
... es_index_pattern="flights", ... es_index_pattern="flights",
... ) ... )
>>> df >>> df

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -309,8 +309,10 @@ def elasticsearch_date_to_pandas_date(
def ensure_es_client( def ensure_es_client(
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch] es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch]
) -> Elasticsearch: ) -> Elasticsearch:
if isinstance(es_client, tuple):
es_client = list(es_client)
if not isinstance(es_client, Elasticsearch): if not isinstance(es_client, Elasticsearch):
es_client = Elasticsearch(es_client) es_client = Elasticsearch(es_client) # type: ignore[arg-type]
return es_client return es_client
@ -334,16 +336,3 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
else: else:
eland_es_version = es_client._eland_es_version # type: ignore eland_es_version = es_client._eland_es_version # type: ignore
return eland_es_version return eland_es_version
def es_api_compat(
method: Callable[..., Dict[str, Any]], **kwargs: Any
) -> Dict[str, Any]:
"""Expands the 'body' parameter to top-level parameters
on clients that would raise DeprecationWarnings if used.
"""
if ES_CLIENT_HAS_V8_0_DEPRECATIONS:
body = kwargs.pop("body", None)
if body:
kwargs.update(body)
return method(**kwargs)

View File

@ -56,7 +56,7 @@ class DataFrame(NDFrame):
Parameters Parameters
---------- ----------
es_client: Elasticsearch client argument(s) (e.g. 'localhost:9200') es_client: Elasticsearch client argument(s) (e.g. 'http://localhost:9200')
- elasticsearch-py parameters or - elasticsearch-py parameters or
- elasticsearch-py instance - elasticsearch-py instance
es_index_pattern: str es_index_pattern: str
@ -74,7 +74,7 @@ class DataFrame(NDFrame):
-------- --------
Constructing DataFrame from an Elasticsearch configuration arguments and an Elasticsearch index Constructing DataFrame from an Elasticsearch configuration arguments and an Elasticsearch index
>>> df = ed.DataFrame('localhost:9200', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> df.head() >>> df.head()
AvgTicketPrice Cancelled ... dayOfWeek timestamp AvgTicketPrice Cancelled ... dayOfWeek timestamp
0 841.265642 False ... 0 2018-01-01 00:00:00 0 841.265642 False ... 0 2018-01-01 00:00:00
@ -89,7 +89,7 @@ class DataFrame(NDFrame):
Constructing DataFrame from an Elasticsearch client and an Elasticsearch index Constructing DataFrame from an Elasticsearch client and an Elasticsearch index
>>> from elasticsearch import Elasticsearch >>> from elasticsearch import Elasticsearch
>>> es = Elasticsearch("localhost:9200") >>> es = Elasticsearch("http://localhost:9200")
>>> df = ed.DataFrame(es_client=es, es_index_pattern='flights', columns=['AvgTicketPrice', 'Cancelled']) >>> df = ed.DataFrame(es_client=es, es_index_pattern='flights', columns=['AvgTicketPrice', 'Cancelled'])
>>> df.head() >>> df.head()
AvgTicketPrice Cancelled AvgTicketPrice Cancelled
@ -106,7 +106,7 @@ class DataFrame(NDFrame):
(TODO - currently index_field must also be a field if not _id) (TODO - currently index_field must also be a field if not _id)
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... es_client='localhost', ... es_client='http://localhost:9200',
... es_index_pattern='flights', ... es_index_pattern='flights',
... columns=['AvgTicketPrice', 'timestamp'], ... columns=['AvgTicketPrice', 'timestamp'],
... es_index_field='timestamp' ... es_index_field='timestamp'
@ -170,7 +170,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> assert isinstance(df.columns, pd.Index) >>> assert isinstance(df.columns, pd.Index)
>>> df.columns >>> df.columns
Index(['AvgTicketPrice', 'Cancelled', 'Carrier', 'Dest', 'DestAirportID', 'DestCityName', Index(['AvgTicketPrice', 'Cancelled', 'Carrier', 'Dest', 'DestAirportID', 'DestCityName',
@ -198,7 +198,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> df.empty >>> df.empty
False False
""" """
@ -228,7 +228,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=['Origin', 'Dest']) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=['Origin', 'Dest'])
>>> df.head(3) >>> df.head(3)
Origin Dest Origin Dest
0 Frankfurt am Main Airport Sydney Kingsford Smith International Airport 0 Frankfurt am Main Airport Sydney Kingsford Smith International Airport
@ -263,7 +263,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=['Origin', 'Dest']) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=['Origin', 'Dest'])
>>> df.tail() >>> df.tail()
Origin \\ Origin \\
13054 Pisa International Airport... 13054 Pisa International Airport...
@ -365,7 +365,7 @@ class DataFrame(NDFrame):
-------- --------
Drop a column Drop a column
>>> df = ed.DataFrame('localhost', 'ecommerce', columns=['customer_first_name', 'email', 'user']) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce', columns=['customer_first_name', 'email', 'user'])
>>> df.drop(columns=['user']) >>> df.drop(columns=['user'])
customer_first_name email customer_first_name email
0 Eddie eddie@underwood-family.zzz 0 Eddie eddie@underwood-family.zzz
@ -575,7 +575,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce', columns=['customer_first_name', 'geoip.city_name']) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce', columns=['customer_first_name', 'geoip.city_name'])
>>> df.count() >>> df.count()
customer_first_name 4675 customer_first_name 4675
geoip.city_name 4094 geoip.city_name 4094
@ -597,7 +597,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> df = df[(df.OriginAirportID == 'AMS') & (df.FlightDelayMin > 60)] >>> df = df[(df.OriginAirportID == 'AMS') & (df.FlightDelayMin > 60)]
>>> df = df[['timestamp', 'OriginAirportID', 'DestAirportID', 'FlightDelayMin']] >>> df = df[['timestamp', 'OriginAirportID', 'DestAirportID', 'FlightDelayMin']]
>>> df = df.tail() >>> df = df.tail()
@ -692,7 +692,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame("localhost:9200", "ecommerce") >>> df = ed.DataFrame("http://localhost:9200", "ecommerce")
>>> df.es_match("Men's", columns=["category"]) >>> df.es_match("Men's", columns=["category"])
category currency ... type user category currency ... type user
0 [Men's Clothing] EUR ... order eddie 0 [Men's Clothing] EUR ... order eddie
@ -754,7 +754,7 @@ class DataFrame(NDFrame):
.. _geo-distance query: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-geo-distance-query.html .. _geo-distance query: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-geo-distance-query.html
>>> df = ed.DataFrame('localhost', 'ecommerce', columns=['customer_first_name', 'geoip.city_name']) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce', columns=['customer_first_name', 'geoip.city_name'])
>>> df.es_query({"bool": {"filter": {"geo_distance": {"distance": "1km", "geoip.location": [55.3, 25.3]}}}}).head() >>> df.es_query({"bool": {"filter": {"geo_distance": {"distance": "1km", "geoip.location": [55.3, 25.3]}}}}).head()
customer_first_name geoip.city_name customer_first_name geoip.city_name
1 Mary Dubai 1 Mary Dubai
@ -830,7 +830,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce', columns=['customer_first_name', 'geoip.city_name']) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce', columns=['customer_first_name', 'geoip.city_name'])
>>> df.info() >>> df.info()
<class 'eland.dataframe.DataFrame'> <class 'eland.dataframe.DataFrame'>
Index: 4675 entries, 0 to 4674 Index: 4675 entries, 0 to 4674
@ -1366,7 +1366,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', >>> df = ed.DataFrame('http://localhost:9200', 'flights',
... columns=['AvgTicketPrice', 'Dest', 'Cancelled', 'timestamp', 'dayOfWeek']) ... columns=['AvgTicketPrice', 'Dest', 'Cancelled', 'timestamp', 'dayOfWeek'])
>>> df.dtypes >>> df.dtypes
AvgTicketPrice float64 AvgTicketPrice float64
@ -1407,7 +1407,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce') >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce')
>>> df.shape >>> df.shape
(4675, 45) (4675, 45)
""" """
@ -1462,7 +1462,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost:9200', 'flights', columns=['AvgTicketPrice', 'Cancelled']).head() >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=['AvgTicketPrice', 'Cancelled']).head()
>>> df >>> df
AvgTicketPrice Cancelled AvgTicketPrice Cancelled
0 841.265642 False 0 841.265642 False
@ -1520,7 +1520,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost:9200', 'flights', columns=['AvgTicketPrice', 'Cancelled']).head() >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=['AvgTicketPrice', 'Cancelled']).head()
>>> df >>> df
AvgTicketPrice Cancelled AvgTicketPrice Cancelled
0 841.265642 False 0 841.265642 False
@ -1614,7 +1614,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=['AvgTicketPrice', 'DistanceKilometers', 'timestamp', 'DestCountry']) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=['AvgTicketPrice', 'DistanceKilometers', 'timestamp', 'DestCountry'])
>>> df.aggregate(['sum', 'min', 'std'], numeric_only=True).astype(int) >>> df.aggregate(['sum', 'min', 'std'], numeric_only=True).astype(int)
AvgTicketPrice DistanceKilometers AvgTicketPrice DistanceKilometers
sum 8204364 92616288 sum 8204364 92616288
@ -1689,7 +1689,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> ed_flights = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]) >>> ed_flights = ed.DataFrame('http://localhost:9200', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
>>> ed_flights.groupby(["DestCountry", "Cancelled"]).agg(["min", "max"], numeric_only=True) # doctest: +NORMALIZE_WHITESPACE >>> ed_flights.groupby(["DestCountry", "Cancelled"]).agg(["min", "max"], numeric_only=True) # doctest: +NORMALIZE_WHITESPACE
AvgTicketPrice dayOfWeek AvgTicketPrice dayOfWeek
min max min max min max min max
@ -1784,7 +1784,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> ed_ecommerce = ed.DataFrame('localhost', 'ecommerce') >>> ed_ecommerce = ed.DataFrame('http://localhost:9200', 'ecommerce')
>>> ed_df = ed_ecommerce.filter(["total_quantity", "geoip.city_name", "customer_birth_date", "day_of_week", "taxful_total_price"]) >>> ed_df = ed_ecommerce.filter(["total_quantity", "geoip.city_name", "customer_birth_date", "day_of_week", "taxful_total_price"])
>>> ed_df.mode(numeric_only=False) >>> ed_df.mode(numeric_only=False)
total_quantity geoip.city_name customer_birth_date day_of_week taxful_total_price total_quantity geoip.city_name customer_birth_date day_of_week taxful_total_price
@ -1849,7 +1849,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> ed_df = ed.DataFrame('localhost', 'flights') >>> ed_df = ed.DataFrame('http://localhost:9200', 'flights')
>>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"]) >>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"])
>>> ed_flights.quantile() # doctest: +SKIP >>> ed_flights.quantile() # doctest: +SKIP
AvgTicketPrice 640.387285 AvgTicketPrice 640.387285
@ -1892,7 +1892,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> ed_df = ed.DataFrame('localhost', 'flights') >>> ed_df = ed.DataFrame('http://localhost:9200', 'flights')
>>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"]) >>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"])
>>> ed_flights.idxmax() >>> ed_flights.idxmax()
AvgTicketPrice 1843 AvgTicketPrice 1843
@ -1924,7 +1924,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> ed_df = ed.DataFrame('localhost', 'flights') >>> ed_df = ed.DataFrame('http://localhost:9200', 'flights')
>>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"]) >>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"])
>>> ed_flights.idxmin() >>> ed_flights.idxmin()
AvgTicketPrice 5454 AvgTicketPrice 5454
@ -1960,7 +1960,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> df.shape >>> df.shape
(13059, 27) (13059, 27)
>>> df.query('FlightDelayMin > 60').shape >>> df.query('FlightDelayMin > 60').shape
@ -2004,7 +2004,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> df.get('Carrier') >>> df.get('Carrier')
0 Kibana Airlines 0 Kibana Airlines
1 Logstash Airways 1 Logstash Airways
@ -2135,7 +2135,7 @@ class DataFrame(NDFrame):
Examples Examples
-------- --------
>>> ed_df = ed.DataFrame('localhost', 'flights', columns=['AvgTicketPrice', 'Carrier']).head(5) >>> ed_df = ed.DataFrame('http://localhost:9200', 'flights', columns=['AvgTicketPrice', 'Carrier']).head(5)
>>> pd_df = ed.eland_to_pandas(ed_df) >>> pd_df = ed.eland_to_pandas(ed_df)
>>> print(f"type(ed_df)={type(ed_df)}\\ntype(pd_df)={type(pd_df)}") >>> print(f"type(ed_df)={type(ed_df)}\\ntype(pd_df)={type(pd_df)}")
type(ed_df)=<class 'eland.dataframe.DataFrame'> type(ed_df)=<class 'eland.dataframe.DataFrame'>

View File

@ -24,12 +24,7 @@ from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk from elasticsearch.helpers import parallel_bulk
from eland import DataFrame from eland import DataFrame
from eland.common import ( from eland.common import DEFAULT_CHUNK_SIZE, PANDAS_VERSION, ensure_es_client
DEFAULT_CHUNK_SIZE,
PANDAS_VERSION,
ensure_es_client,
es_api_compat,
)
from eland.field_mappings import FieldMappings, verify_mapping_compatibility from eland.field_mappings import FieldMappings, verify_mapping_compatibility
try: try:
@ -128,7 +123,7 @@ def pandas_to_eland(
>>> ed_df = ed.pandas_to_eland(pd_df, >>> ed_df = ed.pandas_to_eland(pd_df,
... 'localhost', ... 'http://localhost:9200',
... 'pandas_to_eland', ... 'pandas_to_eland',
... es_if_exists="replace", ... es_if_exists="replace",
... es_refresh=True, ... es_refresh=True,
@ -175,7 +170,7 @@ def pandas_to_eland(
elif es_if_exists == "replace": elif es_if_exists == "replace":
es_client.indices.delete(index=es_dest_index) es_client.indices.delete(index=es_dest_index)
es_api_compat(es_client.indices.create, index=es_dest_index, body=mapping) es_client.indices.create(index=es_dest_index, mappings=mapping["mappings"])
elif es_if_exists == "append": elif es_if_exists == "append":
dest_mapping = es_client.indices.get_mapping(index=es_dest_index)[ dest_mapping = es_client.indices.get_mapping(index=es_dest_index)[
@ -187,7 +182,7 @@ def pandas_to_eland(
es_type_overrides=es_type_overrides, es_type_overrides=es_type_overrides,
) )
else: else:
es_api_compat(es_client.indices.create, index=es_dest_index, body=mapping) es_client.indices.create(index=es_dest_index, mappings=mapping["mappings"])
def action_generator( def action_generator(
pd_df: pd.DataFrame, pd_df: pd.DataFrame,
@ -252,7 +247,7 @@ def eland_to_pandas(ed_df: DataFrame, show_progress: bool = False) -> pd.DataFra
Examples Examples
-------- --------
>>> ed_df = ed.DataFrame('localhost', 'flights').head() >>> ed_df = ed.DataFrame('http://localhost:9200', 'flights').head()
>>> type(ed_df) >>> type(ed_df)
<class 'eland.dataframe.DataFrame'> <class 'eland.dataframe.DataFrame'>
>>> ed_df >>> ed_df
@ -282,7 +277,7 @@ def eland_to_pandas(ed_df: DataFrame, show_progress: bool = False) -> pd.DataFra
Convert `eland.DataFrame` to `pandas.DataFrame` and show progress every 10000 rows Convert `eland.DataFrame` to `pandas.DataFrame` and show progress every 10000 rows
>>> pd_df = ed.eland_to_pandas(ed.DataFrame('localhost', 'flights'), show_progress=True) # doctest: +SKIP >>> pd_df = ed.eland_to_pandas(ed.DataFrame('http://localhost:9200', 'flights'), show_progress=True) # doctest: +SKIP
2020-01-29 12:43:36.572395: read 10000 rows 2020-01-29 12:43:36.572395: read 10000 rows
2020-01-29 12:43:37.309031: read 13059 rows 2020-01-29 12:43:37.309031: read 13059 rows
@ -420,7 +415,7 @@ def csv_to_eland( # type: ignore
>>> ed.csv_to_eland( >>> ed.csv_to_eland(
... "churn.csv", ... "churn.csv",
... es_client='localhost', ... es_client='http://localhost:9200',
... es_dest_index='churn', ... es_dest_index='churn',
... es_refresh=True, ... es_refresh=True,
... index_col=0 ... index_col=0

View File

@ -515,7 +515,7 @@ class FieldMappings:
@staticmethod @staticmethod
def _generate_es_mappings( def _generate_es_mappings(
dataframe: "pd.DataFrame", es_type_overrides: Optional[Mapping[str, str]] = None dataframe: "pd.DataFrame", es_type_overrides: Optional[Mapping[str, str]] = None
) -> Dict[str, Dict[str, Dict[str, Any]]]: ) -> Dict[str, Dict[str, Any]]:
"""Given a pandas dataframe, generate the associated Elasticsearch mapping """Given a pandas dataframe, generate the associated Elasticsearch mapping
Parameters Parameters
@ -894,20 +894,20 @@ def verify_mapping_compatibility(
problems = [] problems = []
es_type_overrides = es_type_overrides or {} es_type_overrides = es_type_overrides or {}
ed_mapping = ed_mapping["mappings"]["properties"] ed_props = ed_mapping["mappings"]["properties"]
es_mapping = es_mapping["mappings"]["properties"] es_props = es_mapping["mappings"]["properties"]
for key in sorted(es_mapping.keys()): for key in sorted(es_props.keys()):
if key not in ed_mapping: if key not in ed_props:
problems.append(f"- {key!r} is missing from DataFrame columns") problems.append(f"- {key!r} is missing from DataFrame columns")
for key, key_def in sorted(ed_mapping.items()): for key, key_def in sorted(ed_props.items()):
if key not in es_mapping: if key not in es_props:
problems.append(f"- {key!r} is missing from ES index mapping") problems.append(f"- {key!r} is missing from ES index mapping")
continue continue
key_type = es_type_overrides.get(key, key_def["type"]) key_type = es_type_overrides.get(key, key_def["type"])
es_key_type = es_mapping[key]["type"] es_key_type = es_props[key]["type"]
if key_type != es_key_type and es_key_type not in ES_COMPATIBLE_TYPES.get( if key_type != es_key_type and es_key_type not in ES_COMPATIBLE_TYPES.get(
key_type, () key_type, ()
): ):

View File

@ -68,7 +68,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").mean(numeric_only=False) # doctest: +SKIP >>> df.groupby("DestCountry").mean(numeric_only=False) # doctest: +SKIP
@ -119,7 +119,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").var() # doctest: +NORMALIZE_WHITESPACE >>> df.groupby("DestCountry").var() # doctest: +NORMALIZE_WHITESPACE
@ -170,7 +170,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").std() # doctest: +NORMALIZE_WHITESPACE >>> df.groupby("DestCountry").std() # doctest: +NORMALIZE_WHITESPACE
@ -221,7 +221,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").mad() # doctest: +SKIP >>> df.groupby("DestCountry").mad() # doctest: +SKIP
@ -272,7 +272,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").median(numeric_only=False) # doctest: +SKIP >>> df.groupby("DestCountry").median(numeric_only=False) # doctest: +SKIP
@ -323,7 +323,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").sum() # doctest: +NORMALIZE_WHITESPACE >>> df.groupby("DestCountry").sum() # doctest: +NORMALIZE_WHITESPACE
@ -374,7 +374,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").min(numeric_only=False) # doctest: +NORMALIZE_WHITESPACE >>> df.groupby("DestCountry").min(numeric_only=False) # doctest: +NORMALIZE_WHITESPACE
@ -425,7 +425,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").max(numeric_only=False) # doctest: +NORMALIZE_WHITESPACE >>> df.groupby("DestCountry").max(numeric_only=False) # doctest: +NORMALIZE_WHITESPACE
@ -476,7 +476,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").nunique() # doctest: +NORMALIZE_WHITESPACE >>> df.groupby("DestCountry").nunique() # doctest: +NORMALIZE_WHITESPACE
@ -526,7 +526,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> ed_df = ed.DataFrame('localhost', 'flights') >>> ed_df = ed.DataFrame('http://localhost:9200', 'flights')
>>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"]) >>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"])
>>> ed_flights.groupby(["dayOfWeek", "Cancelled"]).quantile() # doctest: +SKIP >>> ed_flights.groupby(["dayOfWeek", "Cancelled"]).quantile() # doctest: +SKIP
AvgTicketPrice FlightDelayMin AvgTicketPrice FlightDelayMin
@ -616,7 +616,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").aggregate(["min", "max"]) # doctest: +NORMALIZE_WHITESPACE >>> df.groupby("DestCountry").aggregate(["min", "max"]) # doctest: +NORMALIZE_WHITESPACE
@ -670,7 +670,7 @@ class DataFrameGroupBy(GroupBy):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost", "flights", ... "http://localhost:9200", "flights",
... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"] ... columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "DestCountry"]
... ) ... )
>>> df.groupby("DestCountry").count() # doctest: +NORMALIZE_WHITESPACE >>> df.groupby("DestCountry").count() # doctest: +NORMALIZE_WHITESPACE

View File

@ -79,7 +79,7 @@ class MLModel:
model_id: str model_id: str
The unique identifier of the trained inference model in Elasticsearch. The unique identifier of the trained inference model in Elasticsearch.
""" """
self._client = ensure_es_client(es_client) self._client: Elasticsearch = ensure_es_client(es_client)
self._model_id = model_id self._model_id = model_id
self._trained_model_config_cache: Optional[Dict[str, Any]] = None self._trained_model_config_cache: Optional[Dict[str, Any]] = None
@ -120,7 +120,7 @@ class MLModel:
>>> # Serialise the model to Elasticsearch >>> # Serialise the model to Elasticsearch
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"] >>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
>>> model_id = "test_xgb_regressor" >>> model_id = "test_xgb_regressor"
>>> es_model = MLModel.import_model('localhost', model_id, regressor, feature_names, es_if_exists='replace') >>> es_model = MLModel.import_model('http://localhost:9200', model_id, regressor, feature_names, es_if_exists='replace')
>>> # Get some test results from Elasticsearch model >>> # Get some test results from Elasticsearch model
>>> es_model.predict(test_data) # doctest: +SKIP >>> es_model.predict(test_data) # doctest: +SKIP
@ -167,20 +167,18 @@ class MLModel:
) )
results = self._client.ingest.simulate( results = self._client.ingest.simulate(
body={ pipeline={
"pipeline": { "processors": [
"processors": [ {
{ "inference": {
"inference": { "model_id": self._model_id,
"model_id": self._model_id, "inference_config": {self.model_type: {}},
"inference_config": {self.model_type: {}}, field_map_name: {},
field_map_name: {},
}
} }
] }
}, ]
"docs": docs, },
} docs=docs,
) )
# Unpack results into an array. Errors can be present # Unpack results into an array. Errors can be present
@ -342,7 +340,7 @@ class MLModel:
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"] >>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
>>> model_id = "test_decision_tree_classifier" >>> model_id = "test_decision_tree_classifier"
>>> es_model = MLModel.import_model( >>> es_model = MLModel.import_model(
... 'localhost', ... 'http://localhost:9200',
... model_id=model_id, ... model_id=model_id,
... model=classifier, ... model=classifier,
... feature_names=feature_names, ... feature_names=feature_names,
@ -383,22 +381,21 @@ class MLModel:
elif es_if_exists == "replace": elif es_if_exists == "replace":
ml_model.delete_model() ml_model.delete_model()
body: Dict[str, Any] = {
"input": {"field_names": feature_names},
}
# 'inference_config' is required in 7.8+ but isn't available in <=7.7
if es_version(es_client) >= (7, 8):
body["inference_config"] = {model_type: {}}
if es_compress_model_definition: if es_compress_model_definition:
body["compressed_definition"] = serializer.serialize_and_compress_model() ml_model._client.ml.put_trained_model(
model_id=model_id,
input={"field_names": feature_names},
inference_config={model_type: {}},
compressed_definition=serializer.serialize_and_compress_model(),
)
else: else:
body["definition"] = serializer.serialize_model() ml_model._client.ml.put_trained_model(
model_id=model_id,
input={"field_names": feature_names},
inference_config={model_type: {}},
definition=serializer.serialize_model(),
)
ml_model._client.ml.put_trained_model(
model_id=model_id,
body=body,
)
return ml_model return ml_model
def delete_model(self) -> None: def delete_model(self) -> None:
@ -408,7 +405,9 @@ class MLModel:
If model doesn't exist, ignore failure. If model doesn't exist, ignore failure.
""" """
try: try:
self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,)) self._client.options(ignore_status=404).ml.delete_trained_model(
model_id=self._model_id
)
except elasticsearch.NotFoundError: except elasticsearch.NotFoundError:
pass pass
@ -426,16 +425,7 @@ class MLModel:
def _trained_model_config(self) -> Dict[str, Any]: def _trained_model_config(self) -> Dict[str, Any]:
"""Lazily loads an ML models 'trained_model_config' information""" """Lazily loads an ML models 'trained_model_config' information"""
if self._trained_model_config_cache is None: if self._trained_model_config_cache is None:
resp = self._client.ml.get_trained_models(model_id=self._model_id)
# In Elasticsearch 7.7 and earlier you can't get
# target type without pulling the model definition
# so we check the version first.
if es_version(self._client) < (7, 8):
resp = self._client.ml.get_trained_models(
model_id=self._model_id, include_model_definition=True
)
else:
resp = self._client.ml.get_trained_models(model_id=self._model_id)
if resp["count"] > 1: if resp["count"] > 1:
raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous") raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous")

View File

@ -46,21 +46,19 @@ class PyTorchModel:
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"], es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
model_id: str, model_id: str,
): ):
self._client = ensure_es_client(es_client) self._client: Elasticsearch = ensure_es_client(es_client)
self.model_id = model_id self.model_id = model_id
def put_config(self, path: str) -> None: def put_config(self, path: str) -> None:
with open(path) as f: with open(path) as f:
config = json.load(f) config = json.load(f)
self._client.ml.put_trained_model(model_id=self.model_id, body=config) self._client.ml.put_trained_model(model_id=self.model_id, **config)
def put_vocab(self, path: str) -> None: def put_vocab(self, path: str) -> None:
with open(path) as f: with open(path) as f:
vocab = json.load(f) vocab = json.load(f)
self._client.transport.perform_request( self._client.ml.put_trained_model_vocabulary(
"PUT", model_id=self.model_id, vocabulary=vocab["vocabulary"]
f"/_ml/trained_models/{self.model_id}/vocabulary",
body=vocab,
) )
def put_model(self, model_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> None: def put_model(self, model_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> None:
@ -76,15 +74,12 @@ class PyTorchModel:
yield base64.b64encode(data).decode() yield base64.b64encode(data).decode()
for i, data in tqdm(enumerate(model_file_chunk_generator()), total=total_parts): for i, data in tqdm(enumerate(model_file_chunk_generator()), total=total_parts):
body = { self._client.ml.put_trained_model_definition_part(
"total_definition_length": model_size, model_id=self.model_id,
"total_parts": total_parts, part=i,
"definition": data, total_definition_length=model_size,
} total_parts=total_parts,
self._client.transport.perform_request( definition=data,
"PUT",
f"/_ml/trained_models/{self.model_id}/definition/{i}",
body=body,
) )
def import_model( def import_model(
@ -100,42 +95,41 @@ class PyTorchModel:
self.put_vocab(vocab_path) self.put_vocab(vocab_path)
def infer( def infer(
self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT self,
) -> Union[bool, Any]: docs: List[Dict[str, str]],
return self._client.transport.perform_request( timeout: str = DEFAULT_TIMEOUT,
"POST", ) -> Any:
f"/_ml/trained_models/{self.model_id}/deployment/_infer", return self._client.options(
body=body, request_timeout=60
params={"timeout": timeout, "request_timeout": 60}, ).ml.infer_trained_model_deployment(
model_id=self.model_id,
timeout=timeout,
docs=docs,
) )
def start(self, timeout: str = DEFAULT_TIMEOUT) -> None: def start(self, timeout: str = DEFAULT_TIMEOUT) -> None:
self._client.transport.perform_request( self._client.options(request_timeout=60).ml.start_trained_model_deployment(
"POST", model_id=self.model_id, timeout=timeout, wait_for="started"
f"/_ml/trained_models/{self.model_id}/deployment/_start",
params={"timeout": timeout, "request_timeout": 60, "wait_for": "started"},
) )
def stop(self) -> None: def stop(self) -> None:
self._client.transport.perform_request( self._client.ml.stop_trained_model_deployment(model_id=self.model_id)
"POST",
f"/_ml/trained_models/{self.model_id}/deployment/_stop",
params={"ignore": 404},
)
def delete(self) -> None: def delete(self) -> None:
self._client.ml.delete_trained_model(model_id=self.model_id, ignore=(404,)) self._client.options(ignore_status=404).ml.delete_trained_model(
model_id=self.model_id
)
@classmethod @classmethod
def list( def list(
cls, es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"] cls, es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
) -> Set[str]: ) -> Set[str]:
client = ensure_es_client(es_client) client = ensure_es_client(es_client)
res = client.ml.get_trained_models(model_id="*", allow_no_match=True) resp = client.ml.get_trained_models(model_id="*", allow_no_match=True)
return set( return set(
[ [
model["model_id"] model["model_id"]
for model in res["trained_model_configs"] for model in resp["trained_model_configs"]
if model["model_type"] == "pytorch" if model["model_type"] == "pytorch"
] ]
) )

View File

@ -99,7 +99,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> assert isinstance(df.index, ed.Index) >>> assert isinstance(df.index, ed.Index)
>>> df.index.es_index_field >>> df.index.es_index_field
'_id' '_id'
@ -127,7 +127,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=['Origin', 'AvgTicketPrice', 'timestamp', 'dayOfWeek']) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=['Origin', 'AvgTicketPrice', 'timestamp', 'dayOfWeek'])
>>> df.dtypes >>> df.dtypes
Origin object Origin object
AvgTicketPrice float64 AvgTicketPrice float64
@ -149,7 +149,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=['Origin', 'AvgTicketPrice', 'timestamp', 'dayOfWeek']) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=['Origin', 'AvgTicketPrice', 'timestamp', 'dayOfWeek'])
>>> df.es_dtypes >>> df.es_dtypes
Origin keyword Origin keyword
AvgTicketPrice float AvgTicketPrice float
@ -213,7 +213,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
>>> df.mean() # doctest: +SKIP >>> df.mean() # doctest: +SKIP
AvgTicketPrice 628.254 AvgTicketPrice 628.254
Cancelled 0.128494 Cancelled 0.128494
@ -262,7 +262,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
>>> df.sum() # doctest: +SKIP >>> df.sum() # doctest: +SKIP
AvgTicketPrice 8.20436e+06 AvgTicketPrice 8.20436e+06
Cancelled 1678 Cancelled 1678
@ -310,7 +310,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
>>> df.min() # doctest: +SKIP >>> df.min() # doctest: +SKIP
AvgTicketPrice 100.021 AvgTicketPrice 100.021
Cancelled False Cancelled False
@ -357,7 +357,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
>>> df.var() # doctest: +SKIP >>> df.var() # doctest: +SKIP
AvgTicketPrice 70964.570234 AvgTicketPrice 70964.570234
Cancelled 0.111987 Cancelled 0.111987
@ -403,7 +403,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
>>> df.std() # doctest: +SKIP >>> df.std() # doctest: +SKIP
AvgTicketPrice 266.407061 AvgTicketPrice 266.407061
Cancelled 0.334664 Cancelled 0.334664
@ -449,7 +449,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
>>> df.median() # doctest: +SKIP >>> df.median() # doctest: +SKIP
AvgTicketPrice 640.363 AvgTicketPrice 640.363
Cancelled False Cancelled False
@ -498,7 +498,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
>>> df.max() # doctest: +SKIP >>> df.max() # doctest: +SKIP
AvgTicketPrice 1199.73 AvgTicketPrice 1199.73
Cancelled True Cancelled True
@ -557,7 +557,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> columns = ['category', 'currency', 'customer_birth_date', 'customer_first_name', 'user'] >>> columns = ['category', 'currency', 'customer_birth_date', 'customer_first_name', 'user']
>>> df = ed.DataFrame('localhost', 'ecommerce', columns=columns) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce', columns=columns)
>>> df.nunique() >>> df.nunique()
category 6 category 6
currency 1 currency 1
@ -583,7 +583,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]) >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
>>> df.mad() # doctest: +SKIP >>> df.mad() # doctest: +SKIP
AvgTicketPrice 213.35497 AvgTicketPrice 213.35497
dayOfWeek 2.00000 dayOfWeek 2.00000
@ -628,7 +628,7 @@ class NDFrame(ABC):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights', columns=['AvgTicketPrice', 'FlightDelayMin']) # ignoring percentiles >>> df = ed.DataFrame('http://localhost:9200', 'flights', columns=['AvgTicketPrice', 'FlightDelayMin']) # ignoring percentiles
>>> df.describe() # doctest: +SKIP >>> df.describe() # doctest: +SKIP
AvgTicketPrice FlightDelayMin AvgTicketPrice FlightDelayMin
count 13059.000000 13059.000000 count 13059.000000 13059.000000

View File

@ -34,7 +34,6 @@ from typing import (
import numpy as np import numpy as np
import pandas as pd # type: ignore import pandas as pd # type: ignore
from elasticsearch.exceptions import NotFoundError
from eland.actions import PostProcessingAction from eland.actions import PostProcessingAction
from eland.common import ( from eland.common import (
@ -45,8 +44,6 @@ from eland.common import (
SortOrder, SortOrder,
build_pd_series, build_pd_series,
elasticsearch_date_to_pandas_date, elasticsearch_date_to_pandas_date,
es_api_compat,
es_version,
) )
from eland.index import Index from eland.index import Index
from eland.query import Query from eland.query import Query
@ -173,7 +170,7 @@ class Operations:
body.exists(field, must=True) body.exists(field, must=True)
field_exists_count = query_compiler._client.count( field_exists_count = query_compiler._client.count(
index=query_compiler._index_pattern, body=body.to_count_body() index=query_compiler._index_pattern, **body.to_count_body()
)["count"] )["count"]
counts[field] = field_exists_count counts[field] = field_exists_count
@ -240,7 +237,7 @@ class Operations:
# Fetch Response # Fetch Response
response = query_compiler._client.search( response = query_compiler._client.search(
index=query_compiler._index_pattern, size=0, body=body.to_search_body() index=query_compiler._index_pattern, size=0, **body.to_search_body()
) )
response = response["aggregations"] response = response["aggregations"]
@ -404,7 +401,7 @@ class Operations:
) )
response = query_compiler._client.search( response = query_compiler._client.search(
index=query_compiler._index_pattern, size=0, body=body.to_search_body() index=query_compiler._index_pattern, size=0, **body.to_search_body()
) )
""" """
@ -1275,7 +1272,7 @@ class Operations:
body.exists(field, must=True) body.exists(field, must=True)
count: int = query_compiler._client.count( count: int = query_compiler._client.count(
index=query_compiler._index_pattern, body=body.to_count_body() index=query_compiler._index_pattern, **body.to_count_body()
)["count"] )["count"]
return count return count
@ -1313,7 +1310,7 @@ class Operations:
body.terms(field, items, must=True) body.terms(field, items, must=True)
count: int = query_compiler._client.count( count: int = query_compiler._client.count(
index=query_compiler._index_pattern, body=body.to_count_body() index=query_compiler._index_pattern, **body.to_count_body()
)["count"] )["count"]
return count return count
@ -1488,99 +1485,16 @@ def _search_yield_hits(
[[{'_index': 'flights', '_type': '_doc', '_id': '0', '_score': None, '_source': {...}, 'sort': [...]}, [[{'_index': 'flights', '_type': '_doc', '_id': '0', '_score': None, '_source': {...}, 'sort': [...]},
{'_index': 'flights', '_type': '_doc', '_id': '1', '_score': None, '_source': {...}, 'sort': [...]}]] {'_index': 'flights', '_type': '_doc', '_id': '1', '_score': None, '_source': {...}, 'sort': [...]}]]
""" """
# No documents, no reason to send a search.
if max_number_of_hits == 0:
return
# Make a copy of 'body' to avoid mutating it outside this function. # Make a copy of 'body' to avoid mutating it outside this function.
body = body.copy() body = body.copy()
# Use the default search size # Use the default search size
body.setdefault("size", DEFAULT_SEARCH_SIZE) body.setdefault("size", DEFAULT_SEARCH_SIZE)
# Elasticsearch 7.12 added '_shard_doc' sort tiebreaker for PITs which
# means we're guaranteed to be safe on documents with a duplicate sort rank.
if es_version(query_compiler._client) >= (7, 12, 0):
yield from _search_with_pit_and_search_after(
query_compiler=query_compiler,
body=body,
max_number_of_hits=max_number_of_hits,
)
# Otherwise we use 'scroll' like we used to.
else:
yield from _search_with_scroll(
query_compiler=query_compiler,
body=body,
max_number_of_hits=max_number_of_hits,
)
def _search_with_scroll(
query_compiler: "QueryCompiler",
body: Dict[str, Any],
max_number_of_hits: Optional[int],
) -> Generator[List[Dict[str, Any]], None, None]:
# No documents, no reason to send a search.
if max_number_of_hits == 0:
return
client = query_compiler._client
hits_yielded = 0
# Make the initial search with 'scroll' set
resp = es_api_compat(
client.search,
index=query_compiler._index_pattern,
body=body,
scroll=DEFAULT_PIT_KEEP_ALIVE,
)
scroll_id: Optional[str] = resp.get("_scroll_id", None)
try:
while scroll_id and (
max_number_of_hits is None or hits_yielded < max_number_of_hits
):
hits: List[Dict[str, Any]] = resp["hits"]["hits"]
# If we didn't receive any hits it means we've reached the end.
if not hits:
break
# Calculate which hits should be yielded from this batch
if max_number_of_hits is None:
hits_to_yield = len(hits)
else:
hits_to_yield = min(len(hits), max_number_of_hits - hits_yielded)
# Yield the hits we need to and then track the total number.
# Never yield an empty list as that makes things simpler for
# downstream consumers.
if hits and hits_to_yield > 0:
yield hits[:hits_to_yield]
hits_yielded += hits_to_yield
# Retrieve the next set of results
resp = client.scroll(
body={"scroll_id": scroll_id, "scroll": DEFAULT_PIT_KEEP_ALIVE},
)
scroll_id = resp.get("_scroll_id", None) # Update the scroll ID.
finally:
# Close the scroll if we have one open
if scroll_id is not None:
try:
client.clear_scroll(body={"scroll_id": [scroll_id]})
except NotFoundError:
pass
def _search_with_pit_and_search_after(
query_compiler: "QueryCompiler",
body: Dict[str, Any],
max_number_of_hits: Optional[int],
) -> Generator[List[Dict[str, Any]], None, None]:
# No documents, no reason to send a search.
if max_number_of_hits == 0:
return
client = query_compiler._client client = query_compiler._client
hits_yielded = 0 # Track the total number of hits yielded. hits_yielded = 0 # Track the total number of hits yielded.
pit_id: Optional[str] = None pit_id: Optional[str] = None
@ -1603,7 +1517,7 @@ def _search_with_pit_and_search_after(
body["pit"] = {"id": pit_id, "keep_alive": DEFAULT_PIT_KEEP_ALIVE} body["pit"] = {"id": pit_id, "keep_alive": DEFAULT_PIT_KEEP_ALIVE}
while max_number_of_hits is None or hits_yielded < max_number_of_hits: while max_number_of_hits is None or hits_yielded < max_number_of_hits:
resp = es_api_compat(client.search, body=body) resp = client.search(**body)
hits: List[Dict[str, Any]] = resp["hits"]["hits"] hits: List[Dict[str, Any]] = resp["hits"]["hits"]
# The point in time ID can change between searches so we # The point in time ID can change between searches so we
@ -1636,8 +1550,4 @@ def _search_with_pit_and_search_after(
# We want to cleanup the point in time if we allocated one # We want to cleanup the point in time if we allocated one
# to keep our memory footprint low. # to keep our memory footprint low.
if pit_id is not None: if pit_id is not None:
try: client.options(ignore_status=404).close_point_in_time(id=pit_id)
client.close_point_in_time(body={"id": pit_id})
except NotFoundError:
# If a point in time is already closed Elasticsearch throws NotFoundError
pass

View File

@ -43,7 +43,7 @@ def ed_hist_series(
Examples Examples
-------- --------
>>> import matplotlib.pyplot as plt >>> import matplotlib.pyplot as plt
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> df[df.OriginWeather == 'Sunny']['FlightTimeMin'].hist(alpha=0.5, density=True) # doctest: +SKIP >>> df[df.OriginWeather == 'Sunny']['FlightTimeMin'].hist(alpha=0.5, density=True) # doctest: +SKIP
>>> df[df.OriginWeather != 'Sunny']['FlightTimeMin'].hist(alpha=0.5, density=True) # doctest: +SKIP >>> df[df.OriginWeather != 'Sunny']['FlightTimeMin'].hist(alpha=0.5, density=True) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP >>> plt.show() # doctest: +SKIP
@ -109,7 +109,7 @@ def ed_hist_frame(
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> hist = df.select_dtypes(include=[np.number]).hist(figsize=[10,10]) # doctest: +SKIP >>> hist = df.select_dtypes(include=[np.number]).hist(figsize=[10,10]) # doctest: +SKIP
""" """
return hist_frame( return hist_frame(

View File

@ -97,7 +97,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> ed.Series(es_client='localhost', es_index_pattern='flights', name='Carrier') >>> ed.Series(es_client='http://localhost:9200', es_index_pattern='flights', name='Carrier')
0 Kibana Airlines 0 Kibana Airlines
1 Logstash Airways 1 Logstash Airways
2 Logstash Airways 2 Logstash Airways
@ -165,7 +165,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.Series('localhost', 'ecommerce', name='total_quantity') >>> df = ed.Series('http://localhost:9200', 'ecommerce', name='total_quantity')
>>> df.shape >>> df.shape
(4675, 1) (4675, 1)
""" """
@ -214,7 +214,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> df.Carrier >>> df.Carrier
0 Kibana Airlines 0 Kibana Airlines
1 Logstash Airways 1 Logstash Airways
@ -290,7 +290,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') >>> df = ed.DataFrame('http://localhost:9200', 'flights')
>>> df['Carrier'].value_counts() >>> df['Carrier'].value_counts()
Logstash Airways 3331 Logstash Airways 3331
JetBeats 3274 JetBeats 3274
@ -587,7 +587,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> ed_flights = ed.DataFrame('localhost', 'flights') >>> ed_flights = ed.DataFrame('http://localhost:9200', 'flights')
>>> ed_flights["timestamp"].quantile([.2,.5,.75]) # doctest: +SKIP >>> ed_flights["timestamp"].quantile([.2,.5,.75]) # doctest: +SKIP
0.20 2018-01-09 04:30:57.289159912 0.20 2018-01-09 04:30:57.289159912
0.50 2018-01-21 23:39:27.031627441 0.50 2018-01-21 23:39:27.031627441
@ -691,7 +691,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> ed_ecommerce = ed.DataFrame('localhost', 'ecommerce') >>> ed_ecommerce = ed.DataFrame('http://localhost:9200', 'ecommerce')
>>> ed_ecommerce["day_of_week"].mode() >>> ed_ecommerce["day_of_week"].mode()
0 Thursday 0 Thursday
dtype: object dtype: object
@ -760,7 +760,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame( >>> df = ed.DataFrame(
... "localhost:9200", "ecommerce", ... "http://localhost:9200", "ecommerce",
... columns=["category", "taxful_total_price"] ... columns=["category", "taxful_total_price"]
... ) ... )
>>> df[ >>> df[
@ -807,7 +807,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -867,7 +867,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -906,7 +906,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -945,7 +945,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -984,7 +984,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -1023,7 +1023,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -1062,7 +1062,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -1101,7 +1101,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -1133,7 +1133,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -1165,7 +1165,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -1197,7 +1197,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -1229,7 +1229,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -1261,7 +1261,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.total_quantity >>> df.total_quantity
0 2 0 2
1 2 1 2
@ -1293,7 +1293,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'ecommerce').head(5) >>> df = ed.DataFrame('http://localhost:9200', 'ecommerce').head(5)
>>> df.taxful_total_price >>> df.taxful_total_price
0 36.98 0 36.98
1 53.98 1 53.98
@ -1415,7 +1415,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> s = ed.DataFrame('localhost', 'flights')['AvgTicketPrice'] >>> s = ed.DataFrame('http://localhost:9200', 'flights')['AvgTicketPrice']
>>> int(s.max()) >>> int(s.max())
1199 1199
""" """
@ -1439,7 +1439,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> s = ed.DataFrame('localhost', 'flights')['AvgTicketPrice'] >>> s = ed.DataFrame('http://localhost:9200', 'flights')['AvgTicketPrice']
>>> int(s.mean()) >>> int(s.mean())
628 628
""" """
@ -1463,7 +1463,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> s = ed.DataFrame('localhost', 'flights')['AvgTicketPrice'] >>> s = ed.DataFrame('http://localhost:9200', 'flights')['AvgTicketPrice']
>>> int(s.median()) >>> int(s.median())
640 640
""" """
@ -1487,7 +1487,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> s = ed.DataFrame('localhost', 'flights')['AvgTicketPrice'] >>> s = ed.DataFrame('http://localhost:9200', 'flights')['AvgTicketPrice']
>>> int(s.min()) >>> int(s.min())
100 100
""" """
@ -1511,7 +1511,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> s = ed.DataFrame('localhost', 'flights')['AvgTicketPrice'] >>> s = ed.DataFrame('http://localhost:9200', 'flights')['AvgTicketPrice']
>>> int(s.sum()) >>> int(s.sum())
8204364 8204364
""" """
@ -1533,7 +1533,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> s = ed.DataFrame('localhost', 'flights')['Carrier'] >>> s = ed.DataFrame('http://localhost:9200', 'flights')['Carrier']
>>> s.nunique() >>> s.nunique()
4 4
""" """
@ -1555,7 +1555,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> s = ed.DataFrame('localhost', 'flights')['AvgTicketPrice'] >>> s = ed.DataFrame('http://localhost:9200', 'flights')['AvgTicketPrice']
>>> int(s.var()) >>> int(s.var())
70964 70964
""" """
@ -1577,7 +1577,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> s = ed.DataFrame('localhost', 'flights')['AvgTicketPrice'] >>> s = ed.DataFrame('http://localhost:9200', 'flights')['AvgTicketPrice']
>>> int(s.std()) >>> int(s.std())
266 266
""" """
@ -1599,7 +1599,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> s = ed.DataFrame('localhost', 'flights')['AvgTicketPrice'] >>> s = ed.DataFrame('http://localhost:9200', 'flights')['AvgTicketPrice']
>>> int(s.mad()) >>> int(s.mad())
213 213
""" """
@ -1627,7 +1627,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> df = ed.DataFrame('localhost', 'flights') # ignoring percentiles as they don't generate consistent results >>> df = ed.DataFrame('http://localhost:9200', 'flights') # ignoring percentiles as they don't generate consistent results
>>> df.AvgTicketPrice.describe() # doctest: +SKIP >>> df.AvgTicketPrice.describe() # doctest: +SKIP
count 13059.000000 count 13059.000000
mean 628.253689 mean 628.253689
@ -1660,7 +1660,7 @@ class Series(NDFrame):
Examples Examples
-------- --------
>>> ed_s = ed.Series('localhost', 'flights', name='Carrier').head(5) >>> ed_s = ed.Series('http://localhost:9200', 'flights', name='Carrier').head(5)
>>> pd_s = ed.eland_to_pandas(ed_s) >>> pd_s = ed.eland_to_pandas(ed_s)
>>> print(f"type(ed_s)={type(ed_s)}\\ntype(pd_s)={type(pd_s)}") >>> print(f"type(ed_s)={type(ed_s)}\\ntype(pd_s)={type(pd_s)}")
type(ed_s)=<class 'eland.series.Series'> type(ed_s)=<class 'eland.series.Series'>

View File

@ -71,19 +71,19 @@ def lint(session):
# Install numpy to use its mypy plugin # Install numpy to use its mypy plugin
# https://numpy.org/devdocs/reference/typing.html#mypy-plugin # https://numpy.org/devdocs/reference/typing.html#mypy-plugin
session.install("black", "flake8", "mypy", "isort", "numpy") session.install("black", "flake8", "mypy", "isort", "numpy")
session.install("--pre", "elasticsearch") session.install("--pre", "elasticsearch>=8.0.0a1,<9")
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py37", *SOURCE_FILES) session.run("black", "--check", "--target-version=py37", *SOURCE_FILES)
session.run("isort", "--check", "--profile=black", *SOURCE_FILES) session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES) session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES)
# TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/") # TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
session.log("mypy --strict eland/") session.log("mypy --show-error-codes --strict eland/")
for typed_file in TYPED_FILES: for typed_file in TYPED_FILES:
if not os.path.isfile(typed_file): if not os.path.isfile(typed_file):
session.error(f"The file {typed_file!r} couldn't be found") session.error(f"The file {typed_file!r} couldn't be found")
process = subprocess.run( process = subprocess.run(
["mypy", "--strict", typed_file], ["mypy", "--show-error-codes", "--strict", typed_file],
env=session.env, env=session.env,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
@ -100,14 +100,15 @@ def lint(session):
session.error("\n" + "\n".join(sorted(set(errors)))) session.error("\n" + "\n".join(sorted(set(errors))))
@nox.session(python=["3.7", "3.8", "3.9"]) @nox.session(python=["3.7", "3.8", "3.9", "3.10"])
@nox.parametrize("pandas_version", ["1.2.0", "1.3.0"]) @nox.parametrize("pandas_version", ["1.2.0", "1.3.0"])
def test(session, pandas_version: str): def test(session, pandas_version: str):
session.install("-r", "requirements-dev.txt") session.install("-r", "requirements-dev.txt")
session.install(".") session.install(".")
session.run("python", "-m", "pip", "install", f"pandas~={pandas_version}") session.run("python", "-m", "pip", "install", f"pandas~={pandas_version}")
session.run("python", "-m", "tests.setup_tests") session.run("python", "-m", "tests.setup_tests")
session.run(
pytest_args = (
"python", "python",
"-m", "-m",
"pytest", "pytest",
@ -116,6 +117,13 @@ def test(session, pandas_version: str):
"--cov-config=setup.cfg", "--cov-config=setup.cfg",
"--doctest-modules", "--doctest-modules",
"--nbval", "--nbval",
)
# PyTorch doesn't support Python 3.10 yet
if session.python == "3.10":
pytest_args += ("--ignore=eland/ml/pytorch",)
session.run(
*pytest_args,
*(session.posargs or ("eland/", "tests/")), *(session.posargs or ("eland/", "tests/")),
) )
@ -144,21 +152,25 @@ def docs(session):
# See if we have an Elasticsearch cluster active # See if we have an Elasticsearch cluster active
# to rebuild the Jupyter notebooks with. # to rebuild the Jupyter notebooks with.
es_active = False
try: try:
import elasticsearch from elasticsearch import ConnectionError, Elasticsearch
es = elasticsearch.Elasticsearch("localhost:9200") try:
es.info() es = Elasticsearch("http://localhost:9200")
if not es.indices.exists("flights"): es.info()
session.run("python", "-m", "tests.setup_tests") if not es.indices.exists(index="flights"):
es_active = True session.run("python", "-m", "tests.setup_tests")
except Exception: es_active = True
es_active = False except ConnectionError:
pass
except ImportError:
pass
# Rebuild all the example notebooks inplace # Rebuild all the example notebooks inplace
if es_active: if es_active:
session.install("jupyter-client", "ipykernel") session.install("jupyter-client", "ipykernel")
for filename in os.listdir(BASE_DIR / "docs/source/examples"): for filename in os.listdir(BASE_DIR / "docs/sphinx/examples"):
if ( if (
filename.endswith(".ipynb") filename.endswith(".ipynb")
and filename != "introduction_to_eland_webinar.ipynb" and filename != "introduction_to_eland_webinar.ipynb"
@ -170,7 +182,7 @@ def docs(session):
"notebook", "notebook",
"--inplace", "--inplace",
"--execute", "--execute",
str(BASE_DIR / "docs/source/examples" / filename), str(BASE_DIR / "docs/sphinx/examples" / filename),
) )
session.cd("docs") session.cd("docs")

View File

@ -1,4 +1,4 @@
elasticsearch>=7.7 elasticsearch>=8.0.0a1,<9
pandas>=1.2.0 pandas>=1.2.0
matplotlib matplotlib
pytest>=5.2.1 pytest>=5.2.1
@ -11,6 +11,9 @@ nox
lightgbm lightgbm
pytest-cov pytest-cov
mypy mypy
sentence-transformers>=2.1.0 huggingface-hub>=0.0.17
torch>=1.9.0
transformers[torch]>=4.12.0 # Torch doesn't support Python 3.10 yet (pytorch/pytorch#66424)
sentence-transformers>=2.1.0; python_version<'3.10'
torch>=1.9.0; python_version<'3.10'
transformers[torch]>=4.12.0; python_version<'3.10'

View File

@ -1,3 +1,3 @@
elasticsearch>=7.7 elasticsearch>=8.0.0a1,<9
pandas>=1 pandas>=1
matplotlib matplotlib

View File

@ -82,7 +82,7 @@ setup(
keywords="elastic eland pandas python", keywords="elastic eland pandas python",
packages=find_packages(include=["eland", "eland.*"]), packages=find_packages(include=["eland", "eland.*"]),
install_requires=[ install_requires=[
"elasticsearch>=7.11,<8", "elasticsearch>=8.0.0a1,<9",
"pandas>=1.2,<1.4", "pandas>=1.2,<1.4",
"matplotlib", "matplotlib",
"numpy", "numpy",

View File

@ -25,17 +25,12 @@ from eland.common import es_version
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# Define test files and indices # Define test files and indices
ELASTICSEARCH_HOST = os.environ.get("ELASTICSEARCH_HOST") or "localhost" ELASTICSEARCH_HOST = os.environ.get(
"ELASTICSEARCH_URL", os.environ.get("ELASTICSEARCH_HOST", "http://localhost:9200")
)
# Define client to use in tests # Define client to use in tests
TEST_SUITE = os.environ.get("TEST_SUITE", "xpack") ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST)
if TEST_SUITE == "xpack":
ES_TEST_CLIENT = Elasticsearch(
ELASTICSEARCH_HOST,
http_auth=("elastic", "changeme"),
)
else:
ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST)
ES_VERSION = es_version(ES_TEST_CLIENT) ES_VERSION = es_version(ES_TEST_CLIENT)

View File

@ -42,7 +42,7 @@ class TestDataFrameDateTime(TestData):
usually contains tests). usually contains tests).
""" """
es = ES_TEST_CLIENT es = ES_TEST_CLIENT
if es.indices.exists(cls.time_index_name): if es.indices.exists(index=cls.time_index_name):
es.indices.delete(index=cls.time_index_name) es.indices.delete(index=cls.time_index_name)
dts = [datetime.strptime(time, "%Y-%m-%dT%H:%M:%S.%f%z") for time in cls.times] dts = [datetime.strptime(time, "%Y-%m-%dT%H:%M:%S.%f%z") for time in cls.times]
@ -58,11 +58,11 @@ class TestDataFrameDateTime(TestData):
body = {"mappings": mappings} body = {"mappings": mappings}
index = "test_time_formats" index = "test_time_formats"
es.indices.delete(index=index, ignore=[400, 404]) es.options(ignore_status=[400, 404]).indices.delete(index=index)
es.indices.create(index=index, body=body) es.indices.create(index=index, body=body)
for i, time_formats in enumerate(time_formats_docs): for i, time_formats in enumerate(time_formats_docs):
es.index(index=index, body=time_formats, id=i) es.index(index=index, id=i, document=time_formats)
es.indices.refresh(index=index) es.indices.refresh(index=index)
@classmethod @classmethod

View File

@ -69,7 +69,7 @@ class TestDataFrameQuery(TestData):
assert_pandas_eland_frame_equal(pd_q4, ed_q4) assert_pandas_eland_frame_equal(pd_q4, ed_q4)
ES_TEST_CLIENT.indices.delete(index_name) ES_TEST_CLIENT.indices.delete(index=index_name)
def test_simple_query(self): def test_simple_query(self):
ed_flights = self.ed_flights() ed_flights = self.ed_flights()
@ -141,4 +141,4 @@ class TestDataFrameQuery(TestData):
assert_pandas_eland_frame_equal(pd_q4, ed_q4) assert_pandas_eland_frame_equal(pd_q4, ed_q4)
ES_TEST_CLIENT.indices.delete(index_name) ES_TEST_CLIENT.indices.delete(index=index_name)

View File

@ -99,7 +99,7 @@ class TestDataFrameToCSV(TestData):
print(pd_flights_from_csv.head()) print(pd_flights_from_csv.head())
# clean up index # clean up index
ES_TEST_CLIENT.indices.delete(test_index) ES_TEST_CLIENT.indices.delete(index=test_index)
def test_pd_to_csv_without_filepath(self): def test_pd_to_csv_without_filepath(self):

View File

@ -122,7 +122,7 @@ class TestDataFrameUtils(TestData):
} }
} }
mapping = ES_TEST_CLIENT.indices.get_mapping(index_name) mapping = ES_TEST_CLIENT.indices.get_mapping(index=index_name)
assert expected_mapping == mapping assert expected_mapping == mapping

View File

@ -195,7 +195,7 @@ class TestPandasToEland:
) )
# Assert that the value 128 caused the index error # Assert that the value 128 caused the index error
assert "Value [128] is out of range for a byte" in str(e.value) assert "Value [128] is out of range for a byte" in str(e.value.errors)
def test_pandas_to_eland_text_inserts_keyword(self): def test_pandas_to_eland_text_inserts_keyword(self):
es = ES_TEST_CLIENT es = ES_TEST_CLIENT

View File

@ -32,7 +32,7 @@ class TestDateTime(TestData):
usually contains tests). usually contains tests).
""" """
es = ES_TEST_CLIENT es = ES_TEST_CLIENT
if es.indices.exists(cls.time_index_name): if es.indices.exists(index=cls.time_index_name):
es.indices.delete(index=cls.time_index_name) es.indices.delete(index=cls.time_index_name)
dts = [datetime.strptime(time, "%Y-%m-%dT%H:%M:%S.%f%z") for time in cls.times] dts = [datetime.strptime(time, "%Y-%m-%dT%H:%M:%S.%f%z") for time in cls.times]
@ -46,13 +46,12 @@ class TestDateTime(TestData):
mappings["properties"][field_name]["type"] = "date" mappings["properties"][field_name]["type"] = "date"
mappings["properties"][field_name]["format"] = field_name mappings["properties"][field_name]["format"] = field_name
body = {"mappings": mappings}
index = "test_time_formats" index = "test_time_formats"
es.indices.delete(index=index, ignore=[400, 404]) es.options(ignore_status=[400, 404]).indices.delete(index=index)
es.indices.create(index=index, body=body) es.indices.create(index=index, mappings=mappings)
for i, time_formats in enumerate(time_formats_docs): for i, time_formats in enumerate(time_formats_docs):
es.index(index=index, body=time_formats, id=i) es.index(index=index, id=i, document=time_formats)
es.indices.refresh(index=index) es.indices.refresh(index=index)
@classmethod @classmethod

View File

@ -90,5 +90,5 @@ class TestPytorchModel:
def test_text_classification(self, model_id, task, text_input, value): def test_text_classification(self, model_id, task, text_input, value):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task) ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
result = ptm.infer({"docs": [{"text_field": text_input}]}) result = ptm.infer(docs=[{"text_field": text_input}])
assert result["predicted_value"] == value assert result["predicted_value"] == value

File diff suppressed because one or more lines are too long

View File

@ -18,7 +18,9 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "False" "text/plain": [
"False"
]
}, },
"execution_count": 2, "execution_count": 2,
"metadata": {}, "metadata": {},
@ -27,7 +29,7 @@
], ],
"source": [ "source": [
"es = Elasticsearch()\n", "es = Elasticsearch()\n",
"ed_df = ed.DataFrame('localhost', 'flights', columns = [\"AvgTicketPrice\", \"Cancelled\", \"dayOfWeek\", \"timestamp\", \"DestCountry\"])\n", "ed_df = ed.DataFrame('http://localhost:9200', 'flights', columns = [\"AvgTicketPrice\", \"Cancelled\", \"dayOfWeek\", \"timestamp\", \"DestCountry\"])\n",
"es.indices.exists(index=\"churn\")" "es.indices.exists(index=\"churn\")"
] ]
}, },
@ -57,7 +59,9 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "pandas.core.frame.DataFrame" "text/plain": [
"pandas.core.frame.DataFrame"
]
}, },
"execution_count": 4, "execution_count": 4,
"metadata": {}, "metadata": {},
@ -75,8 +79,125 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": " account length area code churn customer service calls \\\n0 128 415 0 1 \n1 107 415 0 1 \n\n international plan number vmail messages phone number state \\\n0 no 25 382-4657 KS \n1 no 26 371-7191 OH \n\n total day calls total day charge ... total eve calls total eve charge \\\n0 110 45.07 ... 99 16.78 \n1 123 27.47 ... 103 16.62 \n\n total eve minutes total intl calls total intl charge total intl minutes \\\n0 197.4 3 2.7 10.0 \n1 195.5 3 3.7 13.7 \n\n total night calls total night charge total night minutes voice mail plan \n0 91 11.01 244.7 yes \n1 103 11.45 254.4 yes \n\n[2 rows x 21 columns]", "text/html": [
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>account length</th>\n <th>area code</th>\n <th>churn</th>\n <th>customer service calls</th>\n <th>international plan</th>\n <th>number vmail messages</th>\n <th>phone number</th>\n <th>state</th>\n <th>total day calls</th>\n <th>total day charge</th>\n <th>...</th>\n <th>total eve calls</th>\n <th>total eve charge</th>\n <th>total eve minutes</th>\n <th>total intl calls</th>\n <th>total intl charge</th>\n <th>total intl minutes</th>\n <th>total night calls</th>\n <th>total night charge</th>\n <th>total night minutes</th>\n <th>voice mail plan</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>128</td>\n <td>415</td>\n <td>0</td>\n <td>1</td>\n <td>no</td>\n <td>25</td>\n <td>382-4657</td>\n <td>KS</td>\n <td>110</td>\n <td>45.07</td>\n <td>...</td>\n <td>99</td>\n <td>16.78</td>\n <td>197.4</td>\n <td>3</td>\n <td>2.7</td>\n <td>10.0</td>\n <td>91</td>\n <td>11.01</td>\n <td>244.7</td>\n <td>yes</td>\n </tr>\n <tr>\n <th>1</th>\n <td>107</td>\n <td>415</td>\n <td>0</td>\n <td>1</td>\n <td>no</td>\n <td>26</td>\n <td>371-7191</td>\n <td>OH</td>\n <td>123</td>\n <td>27.47</td>\n <td>...</td>\n <td>103</td>\n <td>16.62</td>\n <td>195.5</td>\n <td>3</td>\n <td>3.7</td>\n <td>13.7</td>\n <td>103</td>\n <td>11.45</td>\n <td>254.4</td>\n <td>yes</td>\n </tr>\n </tbody>\n</table>\n</div>\n<p>2 rows × 21 columns</p>" "<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>account length</th>\n",
" <th>area code</th>\n",
" <th>churn</th>\n",
" <th>customer service calls</th>\n",
" <th>international plan</th>\n",
" <th>number vmail messages</th>\n",
" <th>phone number</th>\n",
" <th>state</th>\n",
" <th>total day calls</th>\n",
" <th>total day charge</th>\n",
" <th>...</th>\n",
" <th>total eve calls</th>\n",
" <th>total eve charge</th>\n",
" <th>total eve minutes</th>\n",
" <th>total intl calls</th>\n",
" <th>total intl charge</th>\n",
" <th>total intl minutes</th>\n",
" <th>total night calls</th>\n",
" <th>total night charge</th>\n",
" <th>total night minutes</th>\n",
" <th>voice mail plan</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>128</td>\n",
" <td>415</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>no</td>\n",
" <td>25</td>\n",
" <td>382-4657</td>\n",
" <td>KS</td>\n",
" <td>110</td>\n",
" <td>45.07</td>\n",
" <td>...</td>\n",
" <td>99</td>\n",
" <td>16.78</td>\n",
" <td>197.4</td>\n",
" <td>3</td>\n",
" <td>2.7</td>\n",
" <td>10.0</td>\n",
" <td>91</td>\n",
" <td>11.01</td>\n",
" <td>244.7</td>\n",
" <td>yes</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>107</td>\n",
" <td>415</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>no</td>\n",
" <td>26</td>\n",
" <td>371-7191</td>\n",
" <td>OH</td>\n",
" <td>123</td>\n",
" <td>27.47</td>\n",
" <td>...</td>\n",
" <td>103</td>\n",
" <td>16.62</td>\n",
" <td>195.5</td>\n",
" <td>3</td>\n",
" <td>3.7</td>\n",
" <td>13.7</td>\n",
" <td>103</td>\n",
" <td>11.45</td>\n",
" <td>254.4</td>\n",
" <td>yes</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
"<p>2 rows × 21 columns</p>"
],
"text/plain": [
" account length area code churn customer service calls \\\n",
"0 128 415 0 1 \n",
"1 107 415 0 1 \n",
"\n",
" international plan number vmail messages phone number state \\\n",
"0 no 25 382-4657 KS \n",
"1 no 26 371-7191 OH \n",
"\n",
" total day calls total day charge ... total eve calls total eve charge \\\n",
"0 110 45.07 ... 99 16.78 \n",
"1 123 27.47 ... 103 16.62 \n",
"\n",
" total eve minutes total intl calls total intl charge total intl minutes \\\n",
"0 197.4 3 2.7 10.0 \n",
"1 195.5 3 3.7 13.7 \n",
"\n",
" total night calls total night charge total night minutes voice mail plan \n",
"0 91 11.01 244.7 yes \n",
"1 103 11.45 254.4 yes \n",
"\n",
"[2 rows x 21 columns]"
]
}, },
"execution_count": 5, "execution_count": 5,
"metadata": {}, "metadata": {},
@ -85,7 +206,7 @@
], ],
"source": [ "source": [
"# NBVAL_IGNORE_OUTPUT\n", "# NBVAL_IGNORE_OUTPUT\n",
"ed.csv_to_eland(\"./test_churn.csv\", es_client='localhost', es_dest_index='churn', es_refresh=True, index_col=0)" "ed.csv_to_eland(\"./test_churn.csv\", es_client='http://localhost:9200', es_dest_index='churn', es_refresh=True, index_col=0)"
] ]
}, },
{ {
@ -95,7 +216,37 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "{'took': 0,\n 'timed_out': False,\n '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0},\n 'hits': {'total': {'value': 2, 'relation': 'eq'},\n 'max_score': 1.0,\n 'hits': [{'_index': 'churn',\n '_id': '0',\n '_score': 1.0,\n '_source': {'state': 'KS',\n 'account length': 128,\n 'area code': 415,\n 'phone number': '382-4657',\n 'international plan': 'no',\n 'voice mail plan': 'yes',\n 'number vmail messages': 25,\n 'total day minutes': 265.1,\n 'total day calls': 110,\n 'total day charge': 45.07,\n 'total eve minutes': 197.4,\n 'total eve calls': 99,\n 'total eve charge': 16.78,\n 'total night minutes': 244.7,\n 'total night calls': 91,\n 'total night charge': 11.01,\n 'total intl minutes': 10.0,\n 'total intl calls': 3,\n 'total intl charge': 2.7,\n 'customer service calls': 1,\n 'churn': 0}}]}}" "text/plain": [
"{'took': 0,\n",
" 'timed_out': False,\n",
" '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0},\n",
" 'hits': {'total': {'value': 2, 'relation': 'eq'},\n",
" 'max_score': 1.0,\n",
" 'hits': [{'_index': 'churn',\n",
" '_id': '0',\n",
" '_score': 1.0,\n",
" '_source': {'state': 'KS',\n",
" 'account length': 128,\n",
" 'area code': 415,\n",
" 'phone number': '382-4657',\n",
" 'international plan': 'no',\n",
" 'voice mail plan': 'yes',\n",
" 'number vmail messages': 25,\n",
" 'total day minutes': 265.1,\n",
" 'total day calls': 110,\n",
" 'total day charge': 45.07,\n",
" 'total eve minutes': 197.4,\n",
" 'total eve calls': 99,\n",
" 'total eve charge': 16.78,\n",
" 'total night minutes': 244.7,\n",
" 'total night calls': 91,\n",
" 'total night charge': 11.01,\n",
" 'total intl minutes': 10.0,\n",
" 'total intl calls': 3,\n",
" 'total intl charge': 2.7,\n",
" 'customer service calls': 1,\n",
" 'churn': 0}}]}}"
]
}, },
"execution_count": 6, "execution_count": 6,
"metadata": {}, "metadata": {},
@ -114,7 +265,9 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "{'acknowledged': True}" "text/plain": [
"{'acknowledged': True}"
]
}, },
"execution_count": 7, "execution_count": 7,
"metadata": {}, "metadata": {},
@ -147,4 +300,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 4 "nbformat_minor": 4
} }

View File

@ -22,7 +22,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"ed_df = ed.DataFrame('localhost', 'flights', columns=[\"AvgTicketPrice\", \"Cancelled\", \"dayOfWeek\", \"timestamp\", \"DestCountry\"])" "ed_df = ed.DataFrame('http://localhost:9200', 'flights', columns=[\"AvgTicketPrice\", \"Cancelled\", \"dayOfWeek\", \"timestamp\", \"DestCountry\"])"
] ]
}, },
{ {
@ -32,7 +32,13 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "AvgTicketPrice 640.387285\nCancelled False\ndayOfWeek 3\ntimestamp 2018-01-21 23:43:19.256498944\ndtype: object" "text/plain": [
"AvgTicketPrice 640.387285\n",
"Cancelled False\n",
"dayOfWeek 3\n",
"timestamp 2018-01-21 23:43:19.256498944\n",
"dtype: object"
]
}, },
"execution_count": 3, "execution_count": 3,
"metadata": {}, "metadata": {},
@ -51,7 +57,12 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "AvgTicketPrice 640.387285\nCancelled 0.000000\ndayOfWeek 3.000000\ndtype: float64" "text/plain": [
"AvgTicketPrice 640.387285\n",
"Cancelled 0.000000\n",
"dayOfWeek 3.000000\n",
"dtype: float64"
]
}, },
"execution_count": 4, "execution_count": 4,
"metadata": {}, "metadata": {},
@ -70,7 +81,14 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "AvgTicketPrice 640.387285\nCancelled False\ndayOfWeek 3\ntimestamp 2018-01-21 23:43:19.256498944\nDestCountry NaN\ndtype: object" "text/plain": [
"AvgTicketPrice 640.387285\n",
"Cancelled False\n",
"dayOfWeek 3\n",
"timestamp 2018-01-21 23:43:19.256498944\n",
"DestCountry NaN\n",
"dtype: object"
]
}, },
"execution_count": 5, "execution_count": 5,
"metadata": {}, "metadata": {},
@ -89,7 +107,11 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "AvgTicketPrice 213.430365\ndayOfWeek 2.000000\ndtype: float64" "text/plain": [
"AvgTicketPrice 213.430365\n",
"dayOfWeek 2.000000\n",
"dtype: float64"
]
}, },
"execution_count": 6, "execution_count": 6,
"metadata": {}, "metadata": {},
@ -108,7 +130,11 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "AvgTicketPrice 213.430365\ndayOfWeek 2.000000\ndtype: float64" "text/plain": [
"AvgTicketPrice 213.430365\n",
"dayOfWeek 2.000000\n",
"dtype: float64"
]
}, },
"execution_count": 7, "execution_count": 7,
"metadata": {}, "metadata": {},
@ -127,7 +153,14 @@
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": "AvgTicketPrice 213.430365\nCancelled NaN\ndayOfWeek 2.0\ntimestamp NaT\nDestCountry NaN\ndtype: object" "text/plain": [
"AvgTicketPrice 213.430365\n",
"Cancelled NaN\n",
"dayOfWeek 2.0\n",
"timestamp NaT\n",
"DestCountry NaN\n",
"dtype: object"
]
}, },
"execution_count": 8, "execution_count": 8,
"metadata": {}, "metadata": {},
@ -161,4 +194,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 4 "nbformat_minor": 4
} }

File diff suppressed because one or more lines are too long

View File

@ -17,8 +17,8 @@
import pandas as pd import pandas as pd
from elasticsearch import helpers from elasticsearch import helpers
from elasticsearch._sync.client import Elasticsearch
from eland.common import es_version
from tests import ( from tests import (
ECOMMERCE_FILE_NAME, ECOMMERCE_FILE_NAME,
ECOMMERCE_INDEX_NAME, ECOMMERCE_INDEX_NAME,
@ -53,9 +53,9 @@ def _setup_data(es):
# Delete index # Delete index
print("Deleting index:", index_name) print("Deleting index:", index_name)
es.indices.delete(index=index_name, ignore=[400, 404]) es.options(ignore_status=[400, 404]).indices.delete(index=index_name)
print("Creating index:", index_name) print("Creating index:", index_name)
es.indices.create(index=index_name, body=mapping) es.indices.create(index=index_name, **mapping)
df = pd.read_json(json_file_name, lines=True) df = pd.read_json(json_file_name, lines=True)
@ -85,30 +85,28 @@ def _setup_data(es):
print("Done", index_name) print("Done", index_name)
def _update_max_compilations_limit(es, limit="10000/1m"): def _update_max_compilations_limit(es: Elasticsearch, limit="10000/1m"):
print("Updating script.max_compilations_rate to ", limit) print("Updating script.max_compilations_rate to ", limit)
if es_version(es) < (7, 8): es.cluster.put_settings(
body = {"transient": {"script.max_compilations_rate": limit}} transient={
else: "script.max_compilations_rate": "use-context",
body = { "script.context.field.max_compilations_rate": limit,
"transient": {
"script.max_compilations_rate": "use-context",
"script.context.field.max_compilations_rate": limit,
}
} }
es.cluster.put_settings(body=body) )
def _setup_test_mappings(es): def _setup_test_mappings(es: Elasticsearch):
# Create a complex mapping containing many Elasticsearch features # Create a complex mapping containing many Elasticsearch features
es.indices.delete(index=TEST_MAPPING1_INDEX_NAME, ignore=[400, 404]) es.options(ignore_status=[400, 404]).indices.delete(index=TEST_MAPPING1_INDEX_NAME)
es.indices.create(index=TEST_MAPPING1_INDEX_NAME, body=TEST_MAPPING1) es.indices.create(index=TEST_MAPPING1_INDEX_NAME, **TEST_MAPPING1)
def _setup_test_nested(es): def _setup_test_nested(es):
es.indices.delete(index=TEST_NESTED_USER_GROUP_INDEX_NAME, ignore=[400, 404]) es.options(ignore_status=[400, 404]).indices.delete(
index=TEST_NESTED_USER_GROUP_INDEX_NAME
)
es.indices.create( es.indices.create(
index=TEST_NESTED_USER_GROUP_INDEX_NAME, body=TEST_NESTED_USER_GROUP_MAPPING index=TEST_NESTED_USER_GROUP_INDEX_NAME, **TEST_NESTED_USER_GROUP_MAPPING
) )
helpers.bulk(es, TEST_NESTED_USER_GROUP_DOCS) helpers.bulk(es, TEST_NESTED_USER_GROUP_DOCS)