From 081c8efaa02d22504538d43a2df227219cfda6ef Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 21 Jun 2022 13:43:18 +0100 Subject: [PATCH] Freeze the traced PyTorch model --- eland/ml/pytorch/traceable_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/eland/ml/pytorch/traceable_model.py b/eland/ml/pytorch/traceable_model.py index 0d6799c..fa6eb12 100644 --- a/eland/ml/pytorch/traceable_model.py +++ b/eland/ml/pytorch/traceable_model.py @@ -59,5 +59,6 @@ class TraceableModel(ABC): def save(self, path: str) -> str: model_path = os.path.join(path, "traced_pytorch_model.pt") trace_model = self.trace() + trace_model = torch.jit.freeze(trace_model) torch.jit.save(trace_model, model_path) return model_path