NaN metrics when training under tf.distribute.MirroredStrategy #21467

@gianlucasama

Description

Hi!
I have stumbled upon a pretty annoying issue when training a model under tf.distribute.MirroredStrategy.

Essentially, after a certain number of steps in each epoch, all metrics become NaN, even though the model is actually training fine under the hood.
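
For reference, a generic setup along these lines is enough to exercise the code path (hypothetical toy model and data, not my actual training code):

```python
import numpy as np
import tensorflow as tf
import keras

# Hypothetical minimal setup, just to illustrate the kind of training run
# where the NaN metrics show up; not my actual model or data.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = keras.Sequential([
        keras.Input(shape=(32,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x = np.random.rand(65536, 32).astype("float32")
y = np.random.rand(65536, 1).astype("float32")

# After enough steps, the metrics reported in the progress bar turn into
# NaN even though the model keeps training under the hood.
model.fit(x, y, batch_size=64, epochs=50)
```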

After digging through and debugging Keras' metrics code, I found the cause.

keras.metrics.Mean keeps two mirrored variables, "total" and "count", which are reduced across replicas with SUM: "total" accumulates the metric's running state, and "count" is what "total" is divided by to produce the averaged result. When running on multiple GPUs, it seems that "total" can grow too drastically and overflow, which results in the NaN metrics.
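
To make the mechanism concrete, here is a rough sketch of the accumulation logic as I understand it (simplified; not the actual Keras implementation, which creates "total" and "count" as metric weights that the distribution strategy mirrors and reduces with SUM):

```python
import tensorflow as tf

# Simplified sketch of what keras.metrics.Mean does internally.
class NaiveMean:
    def __init__(self):
        self.total = tf.Variable(0.0)  # running sum of all values seen
        self.count = tf.Variable(0.0)  # number of values seen

    def update_state(self, values):
        values = tf.cast(values, tf.float32)
        self.total.assign_add(tf.reduce_sum(values))
        self.count.assign_add(tf.cast(tf.size(values), tf.float32))

    def result(self):
        # Once "total" stops being finite (float32 overflows around 3.4e38),
        # this division no longer returns a finite number and downstream ops
        # can produce NaN.
        return self.total / self.count
```

For comparison, a healthy keras.metrics.Mean updated with [1.0, 2.0, 3.0] ends up with total=6.0 and count=3.0 and returns 2.0; the NaN only appears once "total" is no longer finite.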

keras.metrics.Sum might have the same issue.

If you folks at Keras have any idea of which path to follow to fix this, I would be happy to contribute if necessary.

P.S.: this could be related to existing TensorFlow issues describing the same behaviour, e.g. tensorflow/tensorflow#90686.
