Skip to content

Commit 51ad92f

Browse files
committed
test reducers
1 parent d20f1c0 commit 51ad92f

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

test/test_reducers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pytest
2+
import torchdrift
3+
import torch
4+
import sklearn.decomposition
5+
6+
def test_pca():
7+
pca = torchdrift.reducers.PCAReducer(n_components=2)
8+
assert 'n_components' in str(pca)
9+
a = torch.randn(100, 50, dtype=torch.double)
10+
red = pca.fit(a)
11+
pca_ref = sklearn.decomposition.PCA(n_components=2)
12+
red_ref = torch.from_numpy(pca_ref.fit_transform(a))
13+
# need to find a way to deal with signs
14+
torch.testing.assert_allclose(red.abs(), red_ref.abs())
15+
b = torch.randn(25, 50, dtype=torch.double)
16+
red2 = pca(b)
17+
red2_ref = torch.from_numpy(pca_ref.transform(b))
18+
19+
def test_reducer():
20+
x = torch.randn(5, 5)
21+
r = torchdrift.reducers.Reducer()
22+
with pytest.raises(NotImplementedError):
23+
r.fit(x)
24+
with pytest.raises(NotImplementedError):
25+
r(x)
26+
27+
if __name__ == "__main__":
28+
pytest.main([__file__])

0 commit comments

Comments
 (0)