Skip to content

Commit 350b771

Browse files
Fix model pickling (sdv-dev#271)
* Fix model pickling
1 parent a526820 commit 350b771

File tree

3 files changed

+100
-5
lines changed

3 files changed

+100
-5
lines changed

ctgan/synthesizers/base.py

Lines changed: 51 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -57,13 +57,60 @@ def wrapper(self, *args, **kwargs):
5757

5858

5959
class BaseSynthesizer:
60-
"""Base class for all default synthesizers of ``CTGAN``.
61-
62-
This should contain the save/load methods.
63-
"""
60+
"""Base class for all default synthesizers of ``CTGAN``."""
6461

6562
random_states = None
6663

64+
def __getstate__(self):
    """Build the picklable state for a ``BaseSynthesizer``.

    The instance is switched to the ``cpu`` device while its ``__dict__`` is copied, in
    order to be able to load the model even when used from an external tool such as
    ``SDV``, and it is switched back to its original device afterwards. If
    ``random_states`` holds a ``(numpy.random.RandomState, torch.Generator)`` pair, the
    generators are replaced in the returned state by their raw states, stored under
    ``_numpy_random_state`` and ``_torch_random_state``.

    Returns:
        dict:
            Python dict representing the object.
    """
    device_backup = self._device
    self.set_device(torch.device('cpu'))
    try:
        # NOTE(review): this is a shallow copy, so the values in ``state`` are shared
        # with the live instance -- confirm ``set_device`` does not move them back to
        # the original device in place before pickling actually serializes them.
        state = self.__dict__.copy()
    finally:
        # Always put the instance back on its original device, even on failure.
        self.set_device(device_backup)

    random_states = self.random_states
    if (
        isinstance(random_states, tuple) and
        # Guard the indexing below against empty or one-element tuples.
        len(random_states) >= 2 and
        isinstance(random_states[0], np.random.RandomState) and
        isinstance(random_states[1], torch.Generator)
    ):
        state['_numpy_random_state'] = random_states[0].get_state()
        state['_torch_random_state'] = random_states[1].get_state()
        # ``random_states`` may be defined on the class rather than on the instance,
        # in which case it is not a key of the copied ``__dict__``.
        state.pop('random_states', None)

    return state
89+
90+
def __setstate__(self, state):
    """Rebuild a ``BaseSynthesizer`` from its pickled state.

    When both ``_numpy_random_state`` and ``_torch_random_state`` are present in
    ``state``, they are removed from it and turned back into a
    ``(numpy.random.RandomState, torch.Generator)`` tuple stored under
    ``random_states``. The resulting dict becomes the instance ``__dict__`` and the
    device is then set to ``cuda:0`` if CUDA is available, ``cpu`` otherwise.

    Args:
        state (dict):
            State dict to restore from; it is modified in place.
    """
    has_saved_generators = (
        '_numpy_random_state' in state and '_torch_random_state' in state
    )
    if has_saved_generators:
        saved_numpy = state.pop('_numpy_random_state')
        saved_torch = state.pop('_torch_random_state')

        torch_generator = torch.Generator()
        torch_generator.set_state(saved_torch)

        numpy_generator = np.random.RandomState()
        numpy_generator.set_state(saved_numpy)

        state['random_states'] = (numpy_generator, torch_generator)

    self.__dict__ = state

    # Choose the device from the hardware available at load time.
    device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    self.set_device(torch.device(device_name))
113+
67114
def save(self, path):
68115
"""Save the model in the passed `path`."""
69116
device_backup = self._device

tests/integration/synthesizer/test_ctgan.py

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ def test_fixed_random_seed():
210210
})
211211
discrete_columns = ['discrete']
212212

213-
ctgan = CTGAN(epochs=1)
213+
ctgan = CTGAN(epochs=1, cuda=False)
214214

215215
# Run
216216
ctgan.fit(data, discrete_columns)
@@ -273,3 +273,27 @@ def test_conditional():
273273

274274
def test_batch_size_pack_size():
275275
"""Test that if batch size is not a multiple of pack size, it raises a sane error."""
276+
277+
278+
def test_ctgan_save_and_load(tmpdir):
    """Check that a fitted ``CTGAN`` can be saved, loaded back and sampled from."""
    # Setup
    continuous_values = np.random.random(100)
    discrete_values = np.random.choice(['a', 'b', 'c'], 100)
    frame = pd.DataFrame({
        'continuous': continuous_values,
        'discrete': discrete_values
    })
    discrete_indices = [1]
    model_file = str(tmpdir / 'model.pkl')

    model = CTGAN(epochs=1)
    model.fit(frame.to_numpy(), discrete_indices)
    model.set_random_state(0)

    # Sample once before saving, so the model is pickled after its random state was used.
    model.sample(100)

    # Save
    model.save(model_file)

    # Load
    restored = CTGAN.load(model_file)
    restored.sample(100)

tests/integration/synthesizer/test_tvae.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,3 +129,27 @@ def test_fixed_random_seed():
129129
assert not np.array_equal(sampled_random, sampled_0_1)
130130
np.testing.assert_array_equal(sampled_0_0, sampled_1_0)
131131
np.testing.assert_array_equal(sampled_0_1, sampled_1_1)
132+
133+
134+
def test_tvae_save(tmpdir):
    """Check that a fitted ``TVAE`` can be saved, loaded back and sampled from."""
    # Setup
    continuous_values = np.random.random(100)
    discrete_values = np.random.choice(['a', 'b', 'c'], 100)
    frame = pd.DataFrame({
        'continuous': continuous_values,
        'discrete': discrete_values
    })
    discrete_indices = [1]
    model_file = str(tmpdir / 'model.pkl')

    model = TVAE(epochs=1)
    model.fit(frame.to_numpy(), discrete_indices)
    model.set_random_state(0)

    # Sample once before saving, so the model is pickled after its random state was used.
    model.sample(100)

    # Save
    model.save(model_file)

    # Load
    restored = TVAE.load(model_file)
    restored.sample(100)

0 commit comments

Comments
 (0)