MASPlugin errors if SGD Loss is zero #1676

Open
@man2machine

Description

In the before_backward callback of MASPlugin, there is a check that the loss has been produced by the training_epoch function of the SGDUpdate class:

if not strategy.loss:
    raise ValueError("Loss is not available")

However, when an experience contains only a small amount of data and the classifier's logits are nearly perfect (the correct class dominates the others by a wide margin), nn.CrossEntropyLoss() can return an exact zero because of floating-point precision, e.g. tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>). Since a zero-valued single-element tensor evaluates to False when converted to a boolean, MASPlugin raises the error in this case even though the loss has in fact been computed correctly.
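A minimal reproduction of the failure, assuming only PyTorch on CPU: with a logit margin of 100 in float32, the cross-entropy rounds to exactly zero, and that genuine zero loss is just as falsy as an empty placeholder. The placeholder torch.zeros(1) is an assumption about what self._make_empty_loss() returns, and check_loss_available is a hypothetical helper that mirrors the check quoted above.

```python
import torch
import torch.nn as nn


def check_loss_available(loss):
    # Hypothetical helper mirroring the MASPlugin check `if not strategy.loss`.
    if not loss:
        raise ValueError("Loss is not available")


# Assumed stand-in for the placeholder returned by self._make_empty_loss().
placeholder = torch.zeros(1)

# Logits where the correct class (index 0) beats the others by a wide margin.
logits = torch.tensor([[100.0, 0.0, 0.0]], requires_grad=True)
target = torch.tensor([0])
loss = nn.CrossEntropyLoss()(logits, target)

# In float32, 1 + 2 * exp(-100) rounds to 1, so its log is 0 and the loss is
# exactly zero, even though it is still attached to the autograd graph.
print(loss, loss.requires_grad)

# Both the placeholder and the genuine zero loss fail the truthiness check.
failures = []
for name, value in [("placeholder", placeholder), ("real zero loss", loss)]:
    try:
        check_loss_available(value)
    except ValueError as err:
        failures.append(name)
        print(f"{name}: {err}")
```

The truthiness check cannot tell the two cases apart, because both reduce to bool() of a one-element tensor holding zero.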

Here are two possible solutions:

  • Use an if not strategy.loss.requires_grad check instead of the if not strategy.loss check.
  • Initialize strategy.loss to None instead of using self._make_empty_loss(), and have MASPlugin check whether strategy.loss is None.
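The two proposed checks can be sketched as standalone predicates. This is a sketch under assumptions, not Avalanche code: the function names are hypothetical stand-ins for the body of MASPlugin.before_backward, and the placeholder is again assumed to be torch.zeros(1).

```python
import torch
import torch.nn as nn


def loss_available_requires_grad(loss):
    # Option 1: a real training loss is attached to the autograd graph,
    # while the zero placeholder tensor is not.
    return loss.requires_grad


def loss_available_not_none(loss):
    # Option 2: the strategy starts with loss = None instead of a placeholder,
    # so any tensor, even one equal to zero, counts as available.
    return loss is not None


placeholder = torch.zeros(1)  # assumed placeholder; does not require grad
logits = torch.tensor([[100.0, 0.0, 0.0]], requires_grad=True)
zero_loss = nn.CrossEntropyLoss()(logits, torch.tensor([0]))

print(loss_available_requires_grad(placeholder))  # False
print(loss_available_requires_grad(zero_loss))    # True, despite being zero
print(loss_available_not_none(None))              # False
print(loss_available_not_none(zero_loss))         # True
```

Option 1 relies on grad mode being enabled when the loss is computed, which holds during normal training; a loss produced under torch.no_grad() would be rejected. Option 2 avoids that dependency, but any other code that reads strategy.loss before the forward pass would then have to handle None.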

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests