Skip to content

Commit d828e41

Browse files
authored
Misc Fixes (#14)
- Python Optimal Transport now wants numpy arrays and does not like tensors. - Saving and Loading when buffers are None...
1 parent bda529a commit d828e41

File tree

5 files changed

+34
-1
lines changed

5 files changed

+34
-1
lines changed

test/test_reducers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ def test_pca():
1717
red2 = pca(b)
1818
red2_ref = torch.from_numpy(pca_ref.transform(b))
1919

20+
def test_reducer_load_save():
21+
pca = torchdrift.reducers.PCAReducer(n_components=2)
22+
a = torch.randn(100, 50, dtype=torch.double)
23+
red = pca.fit(a)
24+
pca2 = torchdrift.reducers.PCAReducer(n_components=2)
25+
pca2.load_state_dict(pca.state_dict())
26+
red2 = pca2(a)
2027

2128
def test_reducer():
2229
x = torch.randn(5, 5)

torchdrift/detectors/detector.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,11 @@ def forward(
6363
if self.return_p_value:
6464
return p_value
6565
return ood_score
66+
67+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
68+
missing_keys, unexpected_keys, error_msgs):
69+
for bname, b in self._buffers.items():
70+
if prefix + bname in state_dict and b is None:
71+
setattr(self, bname, state_dict[prefix + bname])
72+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
73+
missing_keys, unexpected_keys, error_msgs)

torchdrift/detectors/wasserstein.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def wasserstein(x, y, p=2.0, fraction_to_match=1.0, n_perm=1000, return_coupling
4747
if fraction_to_match < 1.0:
4848
weights_y[:, -1] = 1.0 - fraction_to_match
4949

50-
coupling = torch.from_numpy(ot_emd(weights_x[0], weights_y[0], dists_p.cpu()))
50+
coupling = torch.from_numpy(ot_emd(weights_x[0].numpy(), weights_y[0].numpy(), dists_p.cpu().numpy()))
5151

5252
if (coupling[:, :num_y].sum() / fraction_to_match - 1).abs().item() > 1e-5: # pragma: no cover
5353
raise RuntimeError("Numerical stability failed")

torchdrift/reducers/pca.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class PCAReducer(Reducer):
1616
def __init__(self, n_components: int = 2):
1717
super().__init__()
1818
self.n_components = n_components
19+
self.register_buffer("mean", None)
20+
self.register_buffer("comp", None)
1921

2022
def extra_repr(self) -> str:
2123
return f"n_components={self.n_components}"
@@ -37,3 +39,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
3739
x = x - self.mean
3840
reduced = x @ self.comp
3941
return reduced
42+
43+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
44+
missing_keys, unexpected_keys, error_msgs):
45+
for bname in ("mean", "comp"):
46+
if prefix + bname in state_dict:
47+
setattr(self, bname, state_dict[prefix + bname])
48+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
49+
missing_keys, unexpected_keys, error_msgs)

torchdrift/reducers/reducer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
1919
2020
Override this in your reducer implementation."""
2121
raise NotImplementedError("Override forward in subclass")
22+
23+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
24+
missing_keys, unexpected_keys, error_msgs):
25+
for bname, b in self._buffers.items():
26+
if prefix + bname in state_dict and b is None:
27+
setattr(self, bname, state_dict[prefix + bname])
28+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
29+
missing_keys, unexpected_keys, error_msgs)

0 commit comments

Comments
 (0)