Skip to content

Commit 934ea8d

Browse files
committed
added option to save the weights of the AE and the Siamese
1 parent bc32476 commit 934ea8d

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = spectralnet
3-
version = 0.1.0
3+
version = 0.1.1
44
author = Amitai
55
description = Spectral Clustering Using Deep Neural Networks
66
long_description = file: README.md

src/spectralnet/_trainers/_ae_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ def __init__(self, config: dict, device: torch.device):
2020
self.patience = self.ae_config["patience"]
2121
self.architecture = self.ae_config["hiddens"]
2222
self.batch_size = self.ae_config["batch_size"]
23+
self.weights_dir = "spectralnet/_trainers/weights"
2324
self.weights_path = "spectralnet/_trainers/weights/ae_weights.pth"
25+
if not os.path.exists(self.weights_dir):
26+
os.makedirs(self.weights_dir)
2427

2528
def train(self, X: torch.Tensor) -> AEModel:
2629
self.X = X.view(X.size(0), -1)

0 commit comments

Comments
 (0)