File tree Expand file tree Collapse file tree 1 file changed +28
-0
lines changed Expand file tree Collapse file tree 1 file changed +28
-0
lines changed Original file line number Diff line number Diff line change
1
+ import pytest
2
+ import torchdrift
3
+ import torch
4
+ import sklearn .decomposition
5
+
6
+ def test_pca ():
7
+ pca = torchdrift .reducers .PCAReducer (n_components = 2 )
8
+ assert 'n_components' in str (pca )
9
+ a = torch .randn (100 , 50 , dtype = torch .double )
10
+ red = pca .fit (a )
11
+ pca_ref = sklearn .decomposition .PCA (n_components = 2 )
12
+ red_ref = torch .from_numpy (pca_ref .fit_transform (a ))
13
+ # need to find a way to deal with signs
14
+ torch .testing .assert_allclose (red .abs (), red_ref .abs ())
15
+ b = torch .randn (25 , 50 , dtype = torch .double )
16
+ red2 = pca (b )
17
+ red2_ref = torch .from_numpy (pca_ref .transform (b ))
18
+
19
+ def test_reducer ():
20
+ x = torch .randn (5 , 5 )
21
+ r = torchdrift .reducers .Reducer ()
22
+ with pytest .raises (NotImplementedError ):
23
+ r .fit (x )
24
+ with pytest .raises (NotImplementedError ):
25
+ r (x )
26
+
27
+ if __name__ == "__main__" :
28
+ pytest .main ([__file__ ])
You can’t perform that action at this time.
0 commit comments