Skip to content

Commit 4386596

Browse files
jesseengelMagenta Team
authored andcommitted
Change training to only run train.num_steps total, regardless of restarting. Always save a final checkpoint.
PiperOrigin-RevId: 442584734
1 parent 220bdff commit 4386596

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

ddsp/training/train_util.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -268,18 +268,21 @@ def train(data_provider,
268268
with summary_writer.as_default():
269269
tick = time.time()
270270

271-
for iteration in range(num_steps):
272-
step = trainer.step # Step is not iteration if restarting a model.
271+
first_step = True
272+
273+
while trainer.step < num_steps:
274+
step = trainer.step
273275

274276
# Take a step.
275277
losses = trainer.train_step(dataset_iter)
276278

277279
# Create training loss metrics when starting/restarting training.
278-
if iteration == 0:
280+
if first_step:
279281
loss_names = list(losses.keys())
280282
logging.info('Creating metrics for %s', loss_names)
281283
avg_losses = {name: tf.keras.metrics.Mean(name=name, dtype=tf.float32)
282284
for name in loss_names}
285+
first_step = False
283286

284287
# Update metrics.
285288
for k, v in losses.items():
@@ -312,16 +315,16 @@ def train(data_provider,
312315
losses['total_loss'] <= early_stop_loss_value):
313316
logging.info('Total loss reached early stopping value of %s',
314317
early_stop_loss_value)
315-
316-
# Write a final checkpoint.
317-
if save_dir:
318-
trainer.save(save_dir)
319-
summary_writer.flush()
320318
break
321319

322320
# Save Model.
323321
if step % steps_per_save == 0 and save_dir:
324322
trainer.save(save_dir)
325323
summary_writer.flush()
326324

325+
# Write a final checkpoint.
326+
if save_dir:
327+
trainer.save(save_dir)
328+
summary_writer.flush()
329+
327330
logging.info('Training Finished!')

ddsp/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
pulling in all the dependencies in __init__.py.
1919
"""
2020

21-
__version__ = '3.3.3'
21+
__version__ = '3.3.4'

0 commit comments

Comments
 (0)