mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Freeze the traced PyTorch model
This commit is contained in:
parent
ec041ffdfd
commit
081c8efaa0
@ -59,5 +59,6 @@ class TraceableModel(ABC):
|
|||||||
def save(self, path: str) -> str:
|
def save(self, path: str) -> str:
|
||||||
model_path = os.path.join(path, "traced_pytorch_model.pt")
|
model_path = os.path.join(path, "traced_pytorch_model.pt")
|
||||||
trace_model = self.trace()
|
trace_model = self.trace()
|
||||||
|
trace_model = torch.jit.freeze(trace_model)
|
||||||
torch.jit.save(trace_model, model_path)
|
torch.jit.save(trace_model, model_path)
|
||||||
return model_path
|
return model_path
|
||||||
|
Loading…
x
Reference in New Issue
Block a user