[ML] Better memory estimation for NLP models (#568)

This PR adds the ability to estimate the per-deployment and per-allocation memory usage of NLP transformer models. It uses torch.profiler to record the peak memory usage during inference.

This information is then used in Elasticsearch to provision models with sufficient memory (elastic/elasticsearch#98874).
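For illustration, here is a minimal sketch of that measurement approach. The tiny model and input shape are placeholders for this example, not anything taken from the PR; the profiler calls themselves match the ones used in the change below.

```python
# Sketch: profile the CPU memory allocated by a single forward pass.
import torch
from torch.profiler import ProfilerActivity, profile

model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.ReLU())  # placeholder model
inputs = torch.randn(1, 128)  # placeholder batch

with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
    model(inputs)

# Aggregate CPU memory allocated during the profiled inference, in bytes.
print(prof.key_averages().total_average().cpu_memory_usage)
```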
Valeriy Khakhutskyy, 2023-11-06 12:18:20 +01:00 (committed via GitHub)
parent 28e6d92430, commit 6cecb454e3
5 changed files with 158 additions and 5 deletions

View File

@@ -16,7 +16,12 @@ fi
 set -euxo pipefail
-SCRIPT_PATH=$(dirname $(realpath -s $0))
+# realpath on MacOS uses different flags than on Linux
+if [[ "$OSTYPE" == "darwin"* ]]; then
+    SCRIPT_PATH=$(dirname $(realpath $0))
+else
+    SCRIPT_PATH=$(dirname $(realpath -s $0))
+fi
 moniker=$(echo "$ELASTICSEARCH_VERSION" | tr -C "[:alnum:]" '-')
 suffix=rest-test
@@ -132,7 +137,7 @@ url="http://elastic:$ELASTIC_PASSWORD@$NODE_NAME"
 docker_pull_attempts=0
 until [ "$docker_pull_attempts" -ge 5 ]
 do
-   docker pull docker.elastic.co/elasticsearch/"$ELASTICSEARCH_VERSION" && break
+   docker pull docker.elastic.co/elasticsearch/$ELASTICSEARCH_VERSION && break
   docker_pull_attempts=$((docker_pull_attempts+1))
   sleep 10
 done

View File

@@ -169,7 +169,7 @@ currently using a minimum version of PyCharm 2019.2.4.
 * Setup Elasticsearch instance with docker
 ``` bash
-> ELASTICSEARCH_VERSION=elasticsearch:7.x-SNAPSHOT .ci/run-elasticsearch.sh
+> ELASTICSEARCH_VERSION=elasticsearch:8.x-SNAPSHOT BUILDKITE=false .buildkite/run-elasticsearch.sh
 ```
 * Now check `http://localhost:9200`

View File

@@ -66,3 +66,7 @@ class TraceableModel(ABC):
         trace_model = torch.jit.freeze(trace_model)
         torch.jit.save(trace_model, model_path)
         return model_path
+
+    @property
+    def model(self) -> nn.Module:
+        return self._model
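The new read-only `model` property exposes the wrapped `nn.Module`, which the memory estimation added to the transformers module below relies on. A minimal sketch with a hypothetical stand-in class (not the real `TraceableModel`):

```python
import torch
from torch import nn

class TinyTraceable:
    """Hypothetical stand-in for a TraceableModel subclass, for illustration only."""

    def __init__(self) -> None:
        self._model = nn.Linear(16, 4)

    @property
    def model(self) -> nn.Module:
        return self._model

traceable = TinyTraceable()
# The per-deployment estimate below sums parameter sizes through this property.
static_bytes = sum(p.nelement() * p.element_size() for p in traceable.model.parameters())
print(static_bytes)  # (16*4 weights + 4 biases) * 4 bytes = 272
```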

View File

@@ -22,6 +22,7 @@ libraries such as sentence-transformers.
 import json
 import os.path
+import random
 import re
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
@@ -30,6 +31,7 @@ import torch  # type: ignore
 import transformers  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
 from torch import Tensor, nn
+from torch.profiler import profile  # type: ignore
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -270,8 +272,8 @@ class _DistilBertWrapper(nn.Module):  # type: ignore
         self,
         input_ids: Tensor,
         attention_mask: Tensor,
-        _token_type_ids: Tensor,
-        _position_ids: Tensor,
+        _token_type_ids: Tensor = None,
+        _position_ids: Tensor = None,
     ) -> Tensor:
         """Wrap the input and output to conform to the native process interface."""
@@ -769,6 +771,18 @@ class TransformerModel:
             tokenization=tokenization_config
         )
+
+        # add static and dynamic memory state size to metadata
+        per_deployment_memory_bytes = self._get_per_deployment_memory()
+        per_allocation_memory_bytes = self._get_per_allocation_memory(
+            tokenization_config.max_sequence_length, 1
+        )
+        metadata = {
+            "per_deployment_memory_bytes": per_deployment_memory_bytes,
+            "per_allocation_memory_bytes": per_allocation_memory_bytes,
+        }
+
         return NlpTrainedModelConfig(
             description=f"Model {self._model_id} for task type '{self._task_type}'",
             model_type="pytorch",
@@ -776,6 +790,127 @@ class TransformerModel:
             input=TrainedModelInput(
                 field_names=["text_field"],
             ),
+            metadata=metadata,
         )
 
+    def _get_per_deployment_memory(self) -> float:
+        """
+        Returns the static memory size of the model in bytes.
+        """
+        psize: float = sum(
+            param.nelement() * param.element_size()
+            for param in self._traceable_model.model.parameters()
+        )
+        bsize: float = sum(
+            buffer.nelement() * buffer.element_size()
+            for buffer in self._traceable_model.model.buffers()
+        )
+        return psize + bsize
+
+    def _get_per_allocation_memory(
+        self, max_seq_length: Optional[int], batch_size: int
+    ) -> float:
+        """
+        Returns the transient memory size of the model in bytes.
+
+        Parameters
+        ----------
+        max_seq_length : Optional[int]
+            Maximum sequence length to use for the model.
+        batch_size : int
+            Batch size to use for the model.
+        """
+        activities = [torch.profiler.ProfilerActivity.CPU]
+
+        # Get the memory usage of the model with a batch size of 1.
+        inputs_1 = self._get_model_inputs(max_seq_length, 1)
+        with profile(activities=activities, profile_memory=True) as prof:
+            self._traceable_model.model(*inputs_1)
+        mem1: float = prof.key_averages().total_average().cpu_memory_usage
+
+        # This measures the memory usage of the model with a batch size of 2 and
+        # then linearly extrapolates it to get the memory usage of the model for
+        # a batch size of batch_size.
+        if batch_size == 1:
+            return mem1
+        inputs_2 = self._get_model_inputs(max_seq_length, 2)
+        with profile(activities=activities, profile_memory=True) as prof:
+            self._traceable_model.model(*inputs_2)
+        mem2: float = prof.key_averages().total_average().cpu_memory_usage
+        return mem1 + (mem2 - mem1) * (batch_size - 1)
+
+    def _get_model_inputs(
+        self,
+        max_length: Optional[int],
+        batch_size: int,
+    ) -> Tuple[Tensor, ...]:
+        """
+        Returns a random batch of inputs for the model.
+
+        Parameters
+        ----------
+        max_length : Optional[int]
+            Maximum sequence length to use for the model. Default is 512.
+        batch_size : int
+            Batch size to use for the model.
+        """
+        vocab: list[str] = list(self._tokenizer.get_vocab().keys())
+
+        # if optional max_length is not set, set it to 512
+        if max_length is None:
+            max_length = 512
+
+        # generate random text
+        texts: list[str] = [
+            " ".join(random.choices(vocab, k=max_length)) for _ in range(batch_size)
+        ]
+
+        # tokenize text
+        inputs: transformers.BatchEncoding = self._tokenizer(
+            texts,
+            padding="max_length",
+            return_tensors="pt",
+            truncation=True,
+            max_length=max_length,
+        )
+
+        return self._make_inputs_compatible(inputs)
+
+    def _make_inputs_compatible(
+        self, inputs: transformers.BatchEncoding
+    ) -> Tuple[Tensor, ...]:
+        """
+        Make the input batch format compatible with the model's requirements.
+
+        Parameters
+        ----------
+        inputs : transformers.BatchEncoding
+            The input batch to make compatible.
+        """
+        # Add params when not provided by the tokenizer (e.g. DistilBERT),
+        # to conform to the BERT interface.
+        if "token_type_ids" not in inputs:
+            inputs["token_type_ids"] = torch.zeros(
+                inputs["input_ids"].size(1), dtype=torch.long
+            )
+        if isinstance(
+            self._tokenizer,
+            (
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
+            ),
+        ):
+            del inputs["token_type_ids"]
+            return (inputs["input_ids"], inputs["attention_mask"])
+
+        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
+        inputs["position_ids"] = position_ids
+        return (
+            inputs["input_ids"],
+            inputs["attention_mask"],
+            inputs["token_type_ids"],
+            inputs["position_ids"],
+        )
+
     def _create_traceable_model(self) -> _TransformerTraceableModel:
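For illustration, the two estimates can be reproduced on a small standalone module. The sketch below mirrors `_get_per_deployment_memory` and `_get_per_allocation_memory`, but uses a plain `nn.Sequential` and a made-up target batch size instead of the real tokenizer-driven inputs:

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile

model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 64))  # placeholder model

# Static ("per deployment") size: bytes held by parameters and buffers.
static_bytes = sum(p.nelement() * p.element_size() for p in model.parameters()) + sum(
    b.nelement() * b.element_size() for b in model.buffers()
)

def transient_bytes(batch_size: int) -> float:
    """Profile one forward pass and return the CPU memory it allocated, in bytes."""
    x = torch.randn(batch_size, 256)
    with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
        model(x)
    return prof.key_averages().total_average().cpu_memory_usage

# Transient ("per allocation") size: measure batch sizes 1 and 2, then
# extrapolate linearly to the target batch size, as the new method does.
mem1, mem2 = transient_bytes(1), transient_bytes(2)
target_batch = 8  # illustrative value
estimate = mem1 + (mem2 - mem1) * (target_batch - 1)
print(static_bytes, mem1, mem2, estimate)
```

These two numbers end up in the trained model configuration's metadata, which the test change below exercises. Presumably, based on the field names and the linked elastic/elasticsearch#98874 rather than anything stated in this diff, Elasticsearch can then size a deployment as roughly per_deployment_memory_bytes plus the number of allocations times per_allocation_memory_bytes.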

View File

@@ -173,6 +173,15 @@ class TestModelConfguration:
         assert ["text_field"] == config.input.field_names
         assert isinstance(config.inference_config, config_type)
         tokenization = config.inference_config.tokenization
+        assert isinstance(config.metadata, dict)
+        assert (
+            "per_deployment_memory_bytes" in config.metadata
+            and config.metadata["per_deployment_memory_bytes"] > 0
+        )
+        assert (
+            "per_allocation_memory_bytes" in config.metadata
+            and config.metadata["per_allocation_memory_bytes"] > 0
+        )
         assert isinstance(tokenization, tokenizer_type)
         assert max_sequence_len == tokenization.max_sequence_length