Skip to content

Commit 2625bed

Browse files
jesseengel authored and Magenta Team committed
Export and append metadata in ddsp_export for tflite models.
PiperOrigin-RevId: 437823948
1 parent b9b68b3 commit 2625bed

File tree

2 files changed

+99
-2
lines changed

2 files changed

+99
-2
lines changed

ddsp/training/ddsp_export.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,17 @@
3030
--tflite=false --tfjs=false
3131
"""
3232

33+
import datetime
34+
import json
3335
import os
3436

3537
from absl import app
3638
from absl import flags
3739

40+
import ddsp
41+
from ddsp.training import data
3842
from ddsp.training import inference
43+
from ddsp.training import postprocessing
3944
from ddsp.training import train_util
4045
import gin
4146
import tensorflow as tf
@@ -86,6 +91,89 @@
8691

8792
FLAGS = flags.FLAGS
8893

94+
# Metadata.
# `--metadata` toggles writing a metadata.json next to the exported model.
flags.DEFINE_boolean('metadata', True, 'Save metadata for model as a json.')
# `--dataset_path` is consulted first when computing dataset statistics; the
# gin `train.data_provider` binding is only the fallback when it is unset.
flags.DEFINE_string(
    'dataset_path', None,
    'Only used if FLAGS.metadata=True. Path to TF Records containing '
    'training examples. If unset, the train.data_provider binding from the '
    'gin config is used instead.')
103+
104+
105+
def get_data_provider(dataset_path, model_path):
  """Get the data provider for the dataset used to compute statistics.

  Reads TF examples from `dataset_path` if it is provided, otherwise uses the
  data provider bound to `train.data_provider` in the gin config.

  Args:
    dataset_path: Path to TFRecords of TF Examples, or None to fall back to
      the gin config.
    model_path: Path to the model checkpoint dir containing the gin config.

  Returns:
    Data provider to calculate statistics over.

  Raises:
    ValueError: If `dataset_path` is None and no `train.data_provider`
      binding can be found in the gin config.
  """
  # An explicitly supplied dataset path takes precedence over the gin config.
  if dataset_path is not None:
    dataset_path = train_util.expand_path(dataset_path)
    return data.TFRecordProvider(dataset_path)

  # Otherwise fall back to the provider bound in the model's gin config.
  inference.parse_operative_config(model_path)
  try:
    # Only the lookup is guarded, so that errors raised while constructing
    # the provider itself are not misreported as a missing binding.
    dp_binding = gin.query_parameter('train.data_provider')
  except ValueError as e:
    raise ValueError(
        'Failed to parse dataset from gin. Either --dataset_path '
        'or train.data_provider gin param must be set.') from e
  return dp_binding.scoped_configurable_fn()
130+
131+
132+
def get_metadata_dict(data_provider, model_path):
  """Build the lightweight metadata dict saved alongside an exported model.

  Combines dataset statistics from `postprocessing.compute_dataset_statistics`
  with the ddsp version, the export timestamp, and the decoder output sizes
  read from the gin config.

  Args:
    data_provider: Data provider passed to compute_dataset_statistics.
    model_path: Path handed to inference.parse_operative_config to load the
      gin config.

  Returns:
    Dict with float pitch/power statistics (pitch also converted to Hz),
    'version', 'export_time', 'num_harmonics' and 'num_noise_amps'.
  """
  # The gin config supplies the decoder output sizes and the frame settings.
  inference.parse_operative_config(model_path)

  # Resolve the decoder bound to the Autoencoder and read its output splits.
  # NOTE(review): presumably the last dotted component of the config key is
  # the decoder's configurable name -- confirm against gin.
  decoder_ref = gin.query_parameter('Autoencoder.decoder')
  decoder_name = decoder_ref.config_key[-1].split('.')[-1]
  splits = dict(gin.query_parameter(f'{decoder_name}.output_splits'))

  # Frame size and rate used for the power statistics.
  power_frame_size = gin.query_parameter('%frame_size')
  power_frame_rate = gin.query_parameter('%frame_rate')

  stats = postprocessing.compute_dataset_statistics(
      data_provider,
      power_frame_size=power_frame_size,
      power_frame_rate=power_frame_rate)

  # Insertion order is kept stable because the dict is serialized to json.
  metadata = {}
  metadata['mean_min_pitch_note'] = float(stats['mean_min_pitch_note'])
  metadata['mean_max_pitch_note'] = float(stats['mean_max_pitch_note'])
  metadata['mean_min_pitch_note_hz'] = float(
      ddsp.core.midi_to_hz(stats['mean_min_pitch_note']))
  metadata['mean_max_pitch_note_hz'] = float(
      ddsp.core.midi_to_hz(stats['mean_max_pitch_note']))
  metadata['mean_min_power_note'] = float(stats['mean_min_power_note'])
  metadata['mean_max_power_note'] = float(stats['mean_max_power_note'])
  metadata['version'] = ddsp.__version__
  metadata['export_time'] = datetime.datetime.now().isoformat()
  metadata['num_harmonics'] = splits['harmonic_distribution']
  metadata['num_noise_amps'] = splits['noise_magnitudes']
  return metadata
176+
89177

90178
def get_inference_model(ckpt):
91179
"""Restore model from checkpoint using global FLAGS.
@@ -187,6 +275,14 @@ def main(unused_argv):
187275
save_dir = train_util.expand_path(save_dir)
188276
ensure_exits(save_dir)
189277

278+
# Save metadata.
279+
if FLAGS.metadata:
280+
metadata_path = os.path.join(save_dir, 'metadata.json')
281+
data_provider = get_data_provider(FLAGS.dataset_path, model_path)
282+
metadata = get_metadata_dict(data_provider, model_path)
283+
with tf.io.gfile.GFile(metadata_path, 'w') as f:
284+
f.write(json.dumps(metadata))
285+
190286
# Create SavedModel if none already exists.
191287
if not is_saved_model:
192288
ckpt_to_saved_model(model_path, save_dir)
@@ -200,7 +296,8 @@ def main(unused_argv):
200296
if FLAGS.tflite:
201297
tflite_dir = os.path.join(save_dir, 'tflite')
202298
ensure_exits(tflite_dir)
203-
saved_model_to_tflite(save_dir, tflite_dir, FLAGS.metadata_file)
299+
saved_model_to_tflite(save_dir, tflite_dir,
300+
metadata_path if FLAGS.metadata else '')
204301

205302

206303
def console_entry_point():

ddsp/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
pulling in all the dependencies in __init__.py.
2020
"""
2121

22-
__version__ = '3.3.0'
22+
__version__ = '3.3.1'

0 commit comments

Comments
 (0)