diff --git a/eland/ml/pytorch/traceable_model.py b/eland/ml/pytorch/traceable_model.py index 0d6799c..fa6eb12 100644 --- a/eland/ml/pytorch/traceable_model.py +++ b/eland/ml/pytorch/traceable_model.py @@ -59,5 +59,6 @@ class TraceableModel(ABC): def save(self, path: str) -> str: model_path = os.path.join(path, "traced_pytorch_model.pt") trace_model = self.trace() + trace_model = torch.jit.freeze(trace_model) torch.jit.save(trace_model, model_path) return model_path