mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] Better memory estimation for NLP models (#568)
This PR adds an ability to estimate per deployment and per allocation memory usage of NLP transformer models. It uses torch.profiler and performs logs the peak memory usage during the inference. This information is then used in Elasticsearch to provision models with sufficient memory (elastic/elasticsearch#98874).
This commit is contained in:
parent
28e6d92430
commit
6cecb454e3
@ -16,7 +16,12 @@ fi
|
||||
|
||||
set -euxo pipefail
|
||||
|
||||
SCRIPT_PATH=$(dirname $(realpath -s $0))
|
||||
# realpath on MacOS use different flags than on Linux
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
SCRIPT_PATH=$(dirname $(realpath $0))
|
||||
else
|
||||
SCRIPT_PATH=$(dirname $(realpath -s $0))
|
||||
fi
|
||||
|
||||
moniker=$(echo "$ELASTICSEARCH_VERSION" | tr -C "[:alnum:]" '-')
|
||||
suffix=rest-test
|
||||
@ -132,7 +137,7 @@ url="http://elastic:$ELASTIC_PASSWORD@$NODE_NAME"
|
||||
docker_pull_attempts=0
|
||||
until [ "$docker_pull_attempts" -ge 5 ]
|
||||
do
|
||||
docker pull docker.elastic.co/elasticsearch/"$ELASTICSEARCH_VERSION" && break
|
||||
docker pull docker.elastic.co/elasticsearch/$ELASTICSEARCH_VERSION && break
|
||||
docker_pull_attempts=$((docker_pull_attempts+1))
|
||||
sleep 10
|
||||
done
|
||||
|
@ -169,7 +169,7 @@ currently using a minimum version of PyCharm 2019.2.4.
|
||||
* Setup Elasticsearch instance with docker
|
||||
|
||||
``` bash
|
||||
> ELASTICSEARCH_VERSION=elasticsearch:7.x-SNAPSHOT .ci/run-elasticsearch.sh
|
||||
> ELASTICSEARCH_VERSION=elasticsearch:8.x-SNAPSHOT BUILDKITE=false .buildkite/run-elasticsearch.sh
|
||||
```
|
||||
|
||||
* Now check `http://localhost:9200`
|
||||
|
@ -66,3 +66,7 @@ class TraceableModel(ABC):
|
||||
trace_model = torch.jit.freeze(trace_model)
|
||||
torch.jit.save(trace_model, model_path)
|
||||
return model_path
|
||||
|
||||
@property
|
||||
def model(self) -> nn.Module:
|
||||
return self._model
|
||||
|
@ -22,6 +22,7 @@ libraries such as sentence-transformers.
|
||||
|
||||
import json
|
||||
import os.path
|
||||
import random
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
@ -30,6 +31,7 @@ import torch # type: ignore
|
||||
import transformers # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from torch import Tensor, nn
|
||||
from torch.profiler import profile # type: ignore
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
@ -270,8 +272,8 @@ class _DistilBertWrapper(nn.Module): # type: ignore
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
attention_mask: Tensor,
|
||||
_token_type_ids: Tensor,
|
||||
_position_ids: Tensor,
|
||||
_token_type_ids: Tensor = None,
|
||||
_position_ids: Tensor = None,
|
||||
) -> Tensor:
|
||||
"""Wrap the input and output to conform to the native process interface."""
|
||||
|
||||
@ -769,6 +771,18 @@ class TransformerModel:
|
||||
tokenization=tokenization_config
|
||||
)
|
||||
|
||||
# add static and dynamic memory state size to metadata
|
||||
per_deployment_memory_bytes = self._get_per_deployment_memory()
|
||||
|
||||
per_allocation_memory_bytes = self._get_per_allocation_memory(
|
||||
tokenization_config.max_sequence_length, 1
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"per_deployment_memory_bytes": per_deployment_memory_bytes,
|
||||
"per_allocation_memory_bytes": per_allocation_memory_bytes,
|
||||
}
|
||||
|
||||
return NlpTrainedModelConfig(
|
||||
description=f"Model {self._model_id} for task type '{self._task_type}'",
|
||||
model_type="pytorch",
|
||||
@ -776,6 +790,127 @@ class TransformerModel:
|
||||
input=TrainedModelInput(
|
||||
field_names=["text_field"],
|
||||
),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _get_per_deployment_memory(self) -> float:
|
||||
"""
|
||||
Returns the static memory size of the model in bytes.
|
||||
"""
|
||||
psize: float = sum(
|
||||
param.nelement() * param.element_size()
|
||||
for param in self._traceable_model.model.parameters()
|
||||
)
|
||||
bsize: float = sum(
|
||||
buffer.nelement() * buffer.element_size()
|
||||
for buffer in self._traceable_model.model.buffers()
|
||||
)
|
||||
return psize + bsize
|
||||
|
||||
def _get_per_allocation_memory(
|
||||
self, max_seq_length: Optional[int], batch_size: int
|
||||
) -> float:
|
||||
"""
|
||||
Returns the transient memory size of the model in bytes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_seq_length : Optional[int]
|
||||
Maximum sequence length to use for the model.
|
||||
batch_size : int
|
||||
Batch size to use for the model.
|
||||
"""
|
||||
activities = [torch.profiler.ProfilerActivity.CPU]
|
||||
|
||||
# Get the memory usage of the model with a batch size of 1.
|
||||
inputs_1 = self._get_model_inputs(max_seq_length, 1)
|
||||
with profile(activities=activities, profile_memory=True) as prof:
|
||||
self._traceable_model.model(*inputs_1)
|
||||
mem1: float = prof.key_averages().total_average().cpu_memory_usage
|
||||
|
||||
# This is measuring memory usage of the model with a batch size of 2 and
|
||||
# then linearly extrapolating it to get the memory usage of the model for
|
||||
# a batch size of batch_size.
|
||||
if batch_size == 1:
|
||||
return mem1
|
||||
inputs_2 = self._get_model_inputs(max_seq_length, 2)
|
||||
with profile(activities=activities, profile_memory=True) as prof:
|
||||
self._traceable_model.model(*inputs_2)
|
||||
mem2: float = prof.key_averages().total_average().cpu_memory_usage
|
||||
return mem1 + (mem2 - mem1) * (batch_size - 1)
|
||||
|
||||
def _get_model_inputs(
|
||||
self,
|
||||
max_length: Optional[int],
|
||||
batch_size: int,
|
||||
) -> Tuple[Tensor, ...]:
|
||||
"""
|
||||
Returns a random batch of inputs for the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_length : Optional[int]
|
||||
Maximum sequence length to use for the model. Default is 512.
|
||||
batch_size : int
|
||||
Batch size to use for the model.
|
||||
"""
|
||||
vocab: list[str] = list(self._tokenizer.get_vocab().keys())
|
||||
|
||||
# if optional max_length is not set, set it to 512
|
||||
if max_length is None:
|
||||
max_length = 512
|
||||
|
||||
# generate random text
|
||||
texts: list[str] = [
|
||||
" ".join(random.choices(vocab, k=max_length)) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
# tokenize text
|
||||
inputs: transformers.BatchEncoding = self._tokenizer(
|
||||
texts,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
return self._make_inputs_compatible(inputs)
|
||||
|
||||
def _make_inputs_compatible(
|
||||
self, inputs: transformers.BatchEncoding
|
||||
) -> Tuple[Tensor, ...]:
|
||||
""" "
|
||||
Make the input batch format compatible to the model's requirements.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs : transformers.BatchEncoding
|
||||
The input batch to make compatible.
|
||||
"""
|
||||
# Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
|
||||
if "token_type_ids" not in inputs:
|
||||
inputs["token_type_ids"] = torch.zeros(
|
||||
inputs["input_ids"].size(1), dtype=torch.long
|
||||
)
|
||||
if isinstance(
|
||||
self._tokenizer,
|
||||
(
|
||||
transformers.BartTokenizer,
|
||||
transformers.MPNetTokenizer,
|
||||
transformers.RobertaTokenizer,
|
||||
transformers.XLMRobertaTokenizer,
|
||||
),
|
||||
):
|
||||
del inputs["token_type_ids"]
|
||||
return (inputs["input_ids"], inputs["attention_mask"])
|
||||
|
||||
position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
|
||||
inputs["position_ids"] = position_ids
|
||||
return (
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["position_ids"],
|
||||
)
|
||||
|
||||
def _create_traceable_model(self) -> _TransformerTraceableModel:
|
||||
|
@ -173,6 +173,15 @@ class TestModelConfguration:
|
||||
assert ["text_field"] == config.input.field_names
|
||||
assert isinstance(config.inference_config, config_type)
|
||||
tokenization = config.inference_config.tokenization
|
||||
assert isinstance(config.metadata, dict)
|
||||
assert (
|
||||
"per_deployment_memory_bytes" in config.metadata
|
||||
and config.metadata["per_deployment_memory_bytes"] > 0
|
||||
)
|
||||
assert (
|
||||
"per_allocation_memory_bytes" in config.metadata
|
||||
and config.metadata["per_allocation_memory_bytes"] > 0
|
||||
)
|
||||
assert isinstance(tokenization, tokenizer_type)
|
||||
assert max_sequence_len == tokenization.max_sequence_length
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user