File tree Expand file tree Collapse file tree 2 files changed +12
-9
lines changed Expand file tree Collapse file tree 2 files changed +12
-9
lines changed Original file line number Diff line number Diff line change @@ -268,18 +268,21 @@ def train(data_provider,
268
268
with summary_writer .as_default ():
269
269
tick = time .time ()
270
270
271
- for iteration in range (num_steps ):
272
- step = trainer .step # Step is not iteration if restarting a model.
271
+ first_step = True
272
+
273
+ while trainer .step < num_steps :
274
+ step = trainer .step
273
275
274
276
# Take a step.
275
277
losses = trainer .train_step (dataset_iter )
276
278
277
279
# Create training loss metrics when starting/restarting training.
278
- if iteration == 0 :
280
+ if first_step :
279
281
loss_names = list (losses .keys ())
280
282
logging .info ('Creating metrics for %s' , loss_names )
281
283
avg_losses = {name : tf .keras .metrics .Mean (name = name , dtype = tf .float32 )
282
284
for name in loss_names }
285
+ first_step = False
283
286
284
287
# Update metrics.
285
288
for k , v in losses .items ():
@@ -312,16 +315,16 @@ def train(data_provider,
312
315
losses ['total_loss' ] <= early_stop_loss_value ):
313
316
logging .info ('Total loss reached early stopping value of %s' ,
314
317
early_stop_loss_value )
315
-
316
- # Write a final checkpoint.
317
- if save_dir :
318
- trainer .save (save_dir )
319
- summary_writer .flush ()
320
318
break
321
319
322
320
# Save Model.
323
321
if step % steps_per_save == 0 and save_dir :
324
322
trainer .save (save_dir )
325
323
summary_writer .flush ()
326
324
325
+ # Write a final checkpoint.
326
+ if save_dir :
327
+ trainer .save (save_dir )
328
+ summary_writer .flush ()
329
+
327
330
logging .info ('Training Finished!' )
Original file line number Diff line number Diff line change 18
18
pulling in all the dependencies in __init__.py.
19
19
"""
20
20
21
- __version__ = '3.3.3 '
21
+ __version__ = '3.3.4 '
You can’t perform that action at this time.
0 commit comments