Skip to content

Commit a20fad0

Browse files
authored
Fixing: Torch 2.0 fails with cuda=False (sdv-dev#291)
* Making changes to get torch 2.0 to pass demo * updating branch * handling nans from gumbel_softmax
1 parent 8d63e5d commit a20fad0

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

ctgan/synthesizers/ctgan.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import pandas as pd
77
import torch
8-
from packaging import version
98
from torch import optim
109
from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional
1110

@@ -197,15 +196,12 @@ def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
197196
Returns:
198197
Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
199198
"""
200-
if version.parse(torch.__version__) < version.parse('1.2.0'):
201-
for i in range(10):
202-
transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard,
203-
eps=eps, dim=dim)
204-
if not torch.isnan(transformed).any():
205-
return transformed
206-
raise ValueError('gumbel_softmax returning NaN.')
199+
for _ in range(10):
200+
transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
201+
if not torch.isnan(transformed).any():
202+
return transformed
207203

208-
return functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
204+
raise ValueError('gumbel_softmax returning NaN.')
209205

210206
def _apply_activate(self, data):
211207
"""Apply proper activation function to the output of the generator."""
@@ -381,7 +377,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
381377
real_cat, fake_cat, self._device, self.pac)
382378
loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
383379

384-
optimizerD.zero_grad()
380+
optimizerD.zero_grad(set_to_none=False)
385381
pen.backward(retain_graph=True)
386382
loss_d.backward()
387383
optimizerD.step()
@@ -412,7 +408,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
412408

413409
loss_g = -torch.mean(y_fake) + cross_entropy
414410

415-
optimizerG.zero_grad()
411+
optimizerG.zero_grad(set_to_none=False)
416412
loss_g.backward()
417413
optimizerG.step()
418414

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
"pandas>=1.1.3;python_version<'3.10'",
1919
"pandas>=1.3.4;python_version>='3.10'",
2020
"scikit-learn>=1.1.3,<2;python_version>='3.10'",
21-
"torch>=1.8.0,<2;python_version<'3.10'",
22-
"torch>=1.11.0,<2;python_version>='3.10'",
21+
"torch>=1.8.0;python_version<'3.10'",
22+
"torch>=1.11.0;python_version>='3.10'",
2323
'rdt>=1.3.0,<2.0',
2424
]
2525

0 commit comments

Comments
 (0)