@@ -9,7 +9,7 @@ class Kernel:
9
9
"""Base class for kernels
10
10
11
11
Unless otherwise noted, all kernels implementing lengthscale detection
12
- use the median of the first 100 pairwise distances as the lengthscale."""
12
+ use the median of pairwise distances as the lengthscale."""
13
13
pass
14
14
15
15
@@ -30,7 +30,7 @@ def __call__(self, dists):
30
30
if self .lengthscale is not None :
31
31
lengthscale = self .lengthscale
32
32
else :
33
- lengthscale = dists [: 100 , : 100 ] .median ()
33
+ lengthscale = dists .median ()
34
34
return torch .exp ((- 0.5 / lengthscale ** 2 ) * dists ** 2 )
35
35
36
36
@@ -51,7 +51,7 @@ def __call__(self, dists):
51
51
if self .lengthscale is not None :
52
52
lengthscale = self .lengthscale
53
53
else :
54
- lengthscale = dists [: 100 , : 100 ] .median ()
54
+ lengthscale = dists .median ()
55
55
return torch .exp ((- 1 / lengthscale ) * dists )
56
56
57
57
@@ -73,7 +73,7 @@ def __call__(self, dists):
73
73
if self .lengthscale is not None :
74
74
lengthscale = self .lengthscale
75
75
else :
76
- lengthscale = dists [: 100 , : 100 ] .median ()
76
+ lengthscale = dists .median ()
77
77
return torch .pow (
78
78
1 + (1 / (2 * self .alpha * lengthscale ** 2 )) * dists ** 2 , - self .alpha
79
79
)
0 commit comments