RematScope with mixed precision: autocast on kernels doesn't seem to trigger on the backwards pass #21448

Open
@pumpnineteen

Description

I created this layer to monitor what's happening:

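from keras.layers import Conv1D

class ReportingConv1D(Conv1D):
    def call(self, inputs):
        # Print input dtype/shape and kernel dtype/shape/path on every call.
        print(f"{inputs.dtype} {inputs.shape} {self.kernel.dtype} {self.kernel.shape} {self.kernel.path}")
        return super().call(inputs)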

forward pass:
<dtype: 'float16'> (2, 147456, 16) float16 (5, 16, 16) reporting_conv1d_25/kernel

backward pass:
<dtype: 'float16'> (2, 147456, 16) float32 (5, 16, 16) reporting_conv1d_25/kernel
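
To illustrate the kind of mismatch shown in the output above, here is a toy model in plain Python. It is a sketch of one possible explanation, not a description of how Keras implements autocasting or rematerialization: it assumes the cast is controlled by a scope that is active only during the original forward call, so a later recomputation sees the stored float32 dtype. `autocast_scope`, `ToyVariable`, and `toy_remat` are hypothetical names invented for this example.

```python
import contextlib

# Hypothetical stand-in for an "autocast is active" setting (not Keras internals).
_autocast_dtype = None


@contextlib.contextmanager
def autocast_scope(dtype):
    """Activate autocasting to `dtype` for the duration of the block."""
    global _autocast_dtype
    previous, _autocast_dtype = _autocast_dtype, dtype
    try:
        yield
    finally:
        _autocast_dtype = previous


class ToyVariable:
    """Stored as float32; reports the autocast dtype while a scope is active."""

    def __init__(self, path):
        self.path = path
        self.stored_dtype = "float32"

    @property
    def dtype(self):
        return _autocast_dtype or self.stored_dtype


def toy_remat(fn):
    """Run fn now and return a closure that re-runs it later,
    the way a rematerialized backward pass recomputes the forward."""
    def wrapped(*args):
        result = fn(*args)
        return result, (lambda: fn(*args))
    return wrapped


kernel = ToyVariable("reporting_conv1d_25/kernel")
seen = []  # (input dtype, kernel dtype) observed on each call


def call(inputs_dtype):
    seen.append((inputs_dtype, kernel.dtype))
    return inputs_dtype


with autocast_scope("float16"):
    _, recompute = toy_remat(call)("float16")  # forward pass, scope active

recompute()  # recomputation happens after the scope has exited

for inputs_dtype, kernel_dtype in seen:
    print(inputs_dtype, kernel_dtype)
```

In this toy model the first call records a float16 kernel and the recomputed call records a float32 kernel, which matches the pattern in the report. Whether the real cause is the same would need to be confirmed against the Keras source.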
