Skip to content

Commit f075629

Browse files
Ethan ManilowMagenta Team
authored andcommitted
Moving gin_register_keras_layers() to nn.py.
PiperOrigin-RevId: 441565076
1 parent 01e1839 commit f075629

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

ddsp/training/ddsp_run.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ def main(unused_argv):
183183
gfile.makedirs(restore_dir) # Only makes dirs if they don't exist.
184184
parse_gin(restore_dir)
185185
logging.info('Operative Gin Config:\n%s', gin.config.config_str())
186-
train_util.gin_register_keras_layers()
187186

188187
if FLAGS.allow_memory_growth:
189188
allow_memory_growth()

ddsp/training/nn.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@
3232
# pylint: disable=redundant-keyword-arg
3333

3434

35+
def gin_register_keras_layers():
36+
"""Registers all keras layers and Sequential to be referenceable in gin."""
37+
# Register sequential model.
38+
gin.external_configurable(tf.keras.Sequential, 'tf.keras.Sequential')
39+
40+
# Register all the layers.
41+
for k, v in inspect.getmembers(tf.keras.layers):
42+
# Duck typing for tf.keras.layers.Layer since keras uses metaclasses.
43+
if hasattr(v, 'variables'):
44+
gin.external_configurable(v, f'tf.keras.layers.{k}')
45+
46+
47+
gin_register_keras_layers()
48+
49+
3550
class DictLayer(tfkl.Layer):
3651
"""Wrap a Keras Layer to take dictionary inputs and outputs.
3752

ddsp/training/train_util.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# Lint as: python3
1616
"""Library of training functions."""
1717

18-
import inspect
1918
import json
2019
import os
2120
import time
@@ -209,18 +208,6 @@ def format_for_tensorboard(line):
209208
summary_writer.flush()
210209

211210

212-
def gin_register_keras_layers():
213-
"""Registers all keras layers and Sequential to be referenceable in gin."""
214-
# Register sequential model.
215-
gin.external_configurable(tf.keras.Sequential, 'tf.keras.Sequential')
216-
217-
# Register all the layers.
218-
for k, v in inspect.getmembers(tf.keras.layers):
219-
# Duck typing for tf.keras.layers.Layer since keras uses metaclasses.
220-
if hasattr(v, 'variables'):
221-
gin.external_configurable(v, f'tf.keras.layers.{k}')
222-
223-
224211
# ------------------------ Training Loop ---------------------------------------
225212
@gin.configurable
226213
def train(data_provider,

0 commit comments

Comments
 (0)