File tree Expand file tree Collapse file tree 3 files changed +15
-14
lines changed Expand file tree Collapse file tree 3 files changed +15
-14
lines changed Original file line number Diff line number Diff line change @@ -183,7 +183,6 @@ def main(unused_argv):
183
183
gfile .makedirs (restore_dir ) # Only makes dirs if they don't exist.
184
184
parse_gin (restore_dir )
185
185
logging .info ('Operative Gin Config:\n %s' , gin .config .config_str ())
186
- train_util .gin_register_keras_layers ()
187
186
188
187
if FLAGS .allow_memory_growth :
189
188
allow_memory_growth ()
Original file line number Diff line number Diff line change 32
32
# pylint: disable=redundant-keyword-arg
33
33
34
34
35
+ def gin_register_keras_layers ():
36
+ """Registers all keras layers and Sequential to be referenceable in gin."""
37
+ # Register sequential model.
38
+ gin .external_configurable (tf .keras .Sequential , 'tf.keras.Sequential' )
39
+
40
+ # Register all the layers.
41
+ for k , v in inspect .getmembers (tf .keras .layers ):
42
+ # Duck typing for tf.keras.layers.Layer since keras uses metaclasses.
43
+ if hasattr (v , 'variables' ):
44
+ gin .external_configurable (v , f'tf.keras.layers.{ k } ' )
45
+
46
+
47
+ gin_register_keras_layers ()
48
+
49
+
35
50
class DictLayer (tfkl .Layer ):
36
51
"""Wrap a Keras Layer to take dictionary inputs and outputs.
37
52
Original file line number Diff line number Diff line change 15
15
# Lint as: python3
16
16
"""Library of training functions."""
17
17
18
- import inspect
19
18
import json
20
19
import os
21
20
import time
@@ -209,18 +208,6 @@ def format_for_tensorboard(line):
209
208
summary_writer .flush ()
210
209
211
210
212
- def gin_register_keras_layers ():
213
- """Registers all keras layers and Sequential to be referenceable in gin."""
214
- # Register sequential model.
215
- gin .external_configurable (tf .keras .Sequential , 'tf.keras.Sequential' )
216
-
217
- # Register all the layers.
218
- for k , v in inspect .getmembers (tf .keras .layers ):
219
- # Duck typing for tf.keras.layers.Layer since keras uses metaclasses.
220
- if hasattr (v , 'variables' ):
221
- gin .external_configurable (v , f'tf.keras.layers.{ k } ' )
222
-
223
-
224
211
# ------------------------ Training Loop ---------------------------------------
225
212
@gin .configurable
226
213
def train (data_provider ,
You can’t perform that action at this time.
0 commit comments