30
30
--tflite=false --tfjs=false
31
31
"""
32
32
33
+ import datetime
34
+ import json
33
35
import os
34
36
35
37
from absl import app
36
38
from absl import flags
37
39
40
+ import ddsp
41
+ from ddsp .training import data
38
42
from ddsp .training import inference
43
+ from ddsp .training import postprocessing
39
44
from ddsp .training import train_util
40
45
import gin
41
46
import tensorflow as tf
86
91
87
92
FLAGS = flags.FLAGS

# Metadata.
flags.DEFINE_boolean('metadata', True, 'Save metadata for model as a json.')
flags.DEFINE_string(
    'dataset_path', None,
    'Only required if FLAGS.metadata=True. Path to TF Records containing '
    'training examples. Only used if no binding to train.data_provider can '
    'be found.')
104
+
105
def get_data_provider(dataset_path, model_path):
  """Build the data provider used for computing dataset statistics.

  When an explicit dataset path is given, a TFRecord provider is constructed
  directly from it; otherwise the provider bound to `train.data_provider`
  in the model's gin operative config is instantiated.

  Args:
    dataset_path: Path to TF Records containing training examples, or None.
    model_path: Path to the model checkpoint dir containing the gin config.

  Returns:
    Data provider to calculate statistics over.
  """
  # An explicit dataset path takes precedence over the gin config.
  if dataset_path is not None:
    expanded_path = train_util.expand_path(dataset_path)
    return data.TFRecordProvider(expanded_path)

  # No path given: fall back to the provider bound in the operative config.
  inference.parse_operative_config(model_path)
  try:
    provider_binding = gin.query_parameter('train.data_provider')
    return provider_binding.scoped_configurable_fn()
  except ValueError as e:
    raise Exception(
        'Failed to parse dataset from gin. Either --dataset_path '
        'or train.data_provider gin param must be set.') from e
130
+
131
+
132
def get_metadata_dict(data_provider, model_path):
  """Compute metadata using compute_dataset_statistics and add version/date.

  Args:
    data_provider: Data provider yielding examples to compute statistics over.
    model_path: Path to the model checkpoint dir containing the gin config.

  Returns:
    A JSON-serializable dict with pitch/power statistics, output sizes,
    the ddsp version, and the export timestamp.
  """
  # Make the model's gin bindings queryable (decoder splits, frame params).
  inference.parse_operative_config(model_path)

  # Resolve the concrete decoder class name from the Autoencoder binding so
  # its output_splits (harmonics / noise magnitudes counts) can be read.
  decoder_ref = gin.query_parameter('Autoencoder.decoder')
  decoder_name = decoder_ref.config_key[-1].split('.')[-1]
  splits = dict(gin.query_parameter(f'{decoder_name}.output_splits'))

  # Dataset statistics computed at the model's power frame size/rate.
  stats = postprocessing.compute_dataset_statistics(
      data_provider,
      power_frame_size=gin.query_parameter('%frame_size'),
      power_frame_rate=gin.query_parameter('%frame_rate'))

  return {
      'mean_min_pitch_note': float(stats['mean_min_pitch_note']),
      'mean_max_pitch_note': float(stats['mean_max_pitch_note']),
      'mean_min_pitch_note_hz': float(
          ddsp.core.midi_to_hz(stats['mean_min_pitch_note'])),
      'mean_max_pitch_note_hz': float(
          ddsp.core.midi_to_hz(stats['mean_max_pitch_note'])),
      'mean_min_power_note': float(stats['mean_min_power_note']),
      'mean_max_power_note': float(stats['mean_max_power_note']),
      'version': ddsp.__version__,
      'export_time': datetime.datetime.now().isoformat(),
      'num_harmonics': splits['harmonic_distribution'],
      'num_noise_amps': splits['noise_magnitudes'],
  }
176
+
89
177
90
178
def get_inference_model (ckpt ):
91
179
"""Restore model from checkpoint using global FLAGS.
@@ -187,6 +275,14 @@ def main(unused_argv):
187
275
save_dir = train_util .expand_path (save_dir )
188
276
ensure_exits (save_dir )
189
277
278
+ # Save metadata.
279
+ if FLAGS .metadata :
280
+ metadata_path = os .path .join (save_dir , 'metadata.json' )
281
+ data_provider = get_data_provider (FLAGS .dataset_path , model_path )
282
+ metadata = get_metadata_dict (data_provider , model_path )
283
+ with tf .io .gfile .GFile (metadata_path , 'w' ) as f :
284
+ f .write (json .dumps (metadata ))
285
+
190
286
# Create SavedModel if none already exists.
191
287
if not is_saved_model :
192
288
ckpt_to_saved_model (model_path , save_dir )
@@ -200,7 +296,8 @@ def main(unused_argv):
200
296
if FLAGS .tflite :
201
297
tflite_dir = os .path .join (save_dir , 'tflite' )
202
298
ensure_exits (tflite_dir )
203
- saved_model_to_tflite (save_dir , tflite_dir , FLAGS .metadata_file )
299
+ saved_model_to_tflite (save_dir , tflite_dir ,
300
+ metadata_path if FLAGS .metadata else '' )
204
301
205
302
206
303
def console_entry_point ():
0 commit comments