@@ -5,7 +5,6 @@
 import numpy as np
 import pandas as pd
 import torch
-from packaging import version
 from torch import optim
 from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional

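The packaging import existed only to gate on the installed torch version; once that gate is removed in the next hunk, the dependency can be dropped as well. For reference, the guard idiom it supported looks roughly like this (a generic sketch, not code that remains in CTGAN):

import torch
from packaging import version

# Branch on the installed torch version, e.g. to work around a bug in old releases.
if version.parse(torch.__version__) < version.parse('1.2.0'):
    print('old torch: apply workaround')
else:
    print('recent torch: use the standard code path')
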
@@ -197,15 +196,12 @@ def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
         Returns:
             Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
         """
-        if version.parse(torch.__version__) < version.parse('1.2.0'):
-            for i in range(10):
-                transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard,
-                                                        eps=eps, dim=dim)
-                if not torch.isnan(transformed).any():
-                    return transformed
-            raise ValueError('gumbel_softmax returning NaN.')
+        for _ in range(10):
+            transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
+            if not torch.isnan(transformed).any():
+                return transformed

-        return functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
+        raise ValueError('gumbel_softmax returning NaN.')

     def _apply_activate(self, data):
         """Apply proper activation function to the output of the generator."""
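Taken together, this hunk drops the torch-version guard and always retries gumbel_softmax when it produces NaN, raising only after ten failed attempts. A minimal standalone sketch of that retry pattern, using only plain torch (the name gumbel_softmax_with_retry is illustrative, not part of the library):

import torch
from torch.nn import functional


def gumbel_softmax_with_retry(logits, tau=1.0, hard=False, dim=-1, retries=10):
    """Sample from the Gumbel-Softmax distribution, resampling if the result contains NaN."""
    for _ in range(retries):
        sampled = functional.gumbel_softmax(logits, tau=tau, hard=hard, dim=dim)
        if not torch.isnan(sampled).any():
            return sampled
    raise ValueError('gumbel_softmax returning NaN.')


# Example: unnormalized log-probabilities for 4 rows over 3 categories.
logits = torch.randn(4, 3)
print(gumbel_softmax_with_retry(logits, tau=0.2, hard=True))  # one-hot rows when hard=True
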
@@ -381,7 +377,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                         real_cat, fake_cat, self._device, self.pac)
                     loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

-                    optimizerD.zero_grad()
+                    optimizerD.zero_grad(set_to_none=False)
                     pen.backward(retain_graph=True)
                     loss_d.backward()
                     optimizerD.step()
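Both optimizer updates now pass set_to_none=False explicitly. Since PyTorch 1.7, Optimizer.zero_grad accepts this flag, and newer releases default to set_to_none=True, which replaces each .grad with None instead of filling it with zeros; pinning False keeps the old zero-filled tensors. Whether the training loop relies on that difference is not stated in the diff; the explicit flag mainly protects against the default changing underneath it. A small plain-torch illustration of the two behaviours:

import torch

param = torch.nn.Parameter(torch.ones(3))
optimizer = torch.optim.Adam([param], lr=1e-3)

loss = (param ** 2).sum()
loss.backward()
optimizer.zero_grad(set_to_none=False)
print(param.grad)  # tensor([0., 0., 0.]) -- grad kept as a zero-filled tensor

loss = (param ** 2).sum()
loss.backward()
optimizer.zero_grad(set_to_none=True)
print(param.grad)  # None -- grad released entirely
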
@@ -412,7 +408,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):

                 loss_g = -torch.mean(y_fake) + cross_entropy

-                optimizerG.zero_grad()
+                optimizerG.zero_grad(set_to_none=False)
                 loss_g.backward()
                 optimizerG.step()
