diff --git a/eland/ml/pytorch/traceable_model.py b/eland/ml/pytorch/traceable_model.py index 77b670c..0d6799c 100644 --- a/eland/ml/pytorch/traceable_model.py +++ b/eland/ml/pytorch/traceable_model.py @@ -41,7 +41,7 @@ class TraceableModel(ABC): def quantize(self) -> None: torch.quantization.quantize_dynamic( - self._model, {torch.nn.Linear}, dtype=torch.qint8 + self._model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True ) def trace(self) -> TracedModelTypes: