Skip to content

Commit 0f6132f

Browse files
jesseengel, Magenta Team
authored and committed
Have model ddsp_export do renaming instead of colab notebook.
PiperOrigin-RevId: 447651496
1 parent 84f6df8 commit 0f6132f

File tree

4 files changed

+26
-20
lines changed

4 files changed

+26
-20
lines changed

ddsp/colab/demos/Train_VST.ipynb

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -273,17 +273,13 @@
273273
" export_path = os.path.join(model_dir, model_name)\n",
274274
"\n",
275275
" !ddsp_export \\\n",
276+
" --name=$model_name \\\n",
276277
" --model_path=$model_dir \\\n",
277278
" --save_dir=$export_path \\\n",
278279
" --inference_model=vst_stateless_predict_controls \\\n",
279280
" --tflite \\\n",
280281
" --notfjs\n",
281282
"\n",
282-
" # Rename tflite model.\n",
283-
" tflite_old_fp = os.path.join(export_path, 'tflite', 'model.tflite')\n",
284-
" tflite_new_fp = os.path.join(export_path, 'tflite', f'{model_name}.tflite')\n",
285-
" !mv $tflite_old_fp $tflite_new_fp\n",
286-
"\n",
287283
" # Zip the whole directory.\n",
288284
" zip_fname = f'{model_name}.zip'\n",
289285
" zip_fp = os.path.join(model_dir, zip_fname)\n",
@@ -361,7 +357,7 @@
361357
" # ------------------------------------------------------------------------------\n",
362358
" print('Installing DDSP...')\n",
363359
" print('This should take about 2 minutes...')\n",
364-
" !pip install -U ddsp[data_preparation]==3.4.1 \u0026\u003e /dev/null\n",
360+
" !pip install -U ddsp[data_preparation]==3.4.3 \u0026\u003e /dev/null\n",
365361
"\n",
366362
" # ------------------------------------------------------------------------------\n",
367363
" # Import DDSP\n",

ddsp/training/ddsp_export.py

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -48,10 +48,14 @@
4848
import tensorflow as tf
4949
from tensorflowjs.converters import converter
5050

51+
# pylint: disable=pointless-string-statement
5152

5253
from tflite_support import metadata as _metadata
54+
# pylint: enable=pointless-string-statement
5355

54-
56+
flags.DEFINE_string(
57+
'name', '', 'Name of your model to use as folder and filename on export. '
58+
'Defaults to "export/" and "model.tflite" if none is provided.')
5559
flags.DEFINE_string(
5660
'model_path', '', 'Path to checkpoint or SavedModel directory. If no '
5761
'SavedModel is found, will search for latest checkpoint '
@@ -236,7 +240,10 @@ def saved_model_to_tfjs(input_dir, save_dir):
236240
print('TFJS Conversion Success!')
237241

238242

239-
def saved_model_to_tflite(input_dir, save_dir, metadata_file=None):
243+
def saved_model_to_tflite(input_dir,
244+
save_dir,
245+
metadata_file=None,
246+
name=''):
240247
"""Convert SavedModel to TFLite model."""
241248
print(f'\nConverting to TFLite:\nInput:{input_dir}\nOutput:{save_dir}\n')
242249
# Convert the model.
@@ -247,7 +254,8 @@ def saved_model_to_tflite(input_dir, save_dir, metadata_file=None):
247254
]
248255
tflite_model = tflite_converter.convert() # Byte string.
249256
# Save the model.
250-
save_path = os.path.join(save_dir, 'model.tflite')
257+
name = name if name else 'model'
258+
save_path = os.path.join(save_dir, f'{name}.tflite')
251259
with tf.io.gfile.GFile(save_path, 'wb') as f:
252260
f.write(tflite_model)
253261

@@ -302,6 +310,7 @@ def main(unused_argv):
302310
is_ckpt = not tf.io.gfile.isdir(model_path)
303311

304312
# Infer save directory path.
313+
export_name = FLAGS.name if FLAGS.name else 'export'
305314
if FLAGS.save_dir:
306315
save_dir = FLAGS.save_dir
307316
else:
@@ -310,10 +319,10 @@ def main(unused_argv):
310319
save_dir = model_path
311320
elif is_ckpt:
312321
# If model_path is a checkpoint file, use the directory of the file.
313-
save_dir = os.path.join(os.path.dirname(model_path), 'export')
322+
save_dir = os.path.join(os.path.dirname(model_path), export_name)
314323
else:
315324
# If model_path is a checkpoint directory, use child export directory.
316-
save_dir = os.path.join(model_path, 'export')
325+
save_dir = os.path.join(model_path, export_name)
317326

318327
# Make a new save directory.
319328
save_dir = train_util.expand_path(save_dir)
@@ -344,8 +353,10 @@ def main(unused_argv):
344353
if FLAGS.tflite:
345354
tflite_dir = os.path.join(save_dir, 'tflite')
346355
ensure_exits(tflite_dir)
347-
saved_model_to_tflite(save_dir, tflite_dir,
348-
metadata_path if FLAGS.metadata else '')
356+
saved_model_to_tflite(save_dir,
357+
tflite_dir,
358+
metadata_path if FLAGS.metadata else '',
359+
name=FLAGS.name)
349360

350361

351362
def console_entry_point():

ddsp/training/postprocessing.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -254,24 +254,22 @@ def fit_transform(self, x):
254254
def compute_dataset_statistics(data_provider,
255255
batch_size=1,
256256
power_frame_size=1024,
257-
power_frame_rate=50,
258-
legacy=False):
257+
power_frame_rate=50):
259258
"""Calculate dataset stats.
260259
261260
Args:
262261
data_provider: A DataProvider from ddsp.training.data.
263262
batch_size: Iterate over dataset with this batch size.
264263
power_frame_size: Calculate power features on the fly with this frame size.
265264
power_frame_rate: Calculate power features on the fly with this frame rate.
266-
legacy: Use the 'audio' key instead of 'audio_16k'.
267265
268266
Returns:
269267
Dictionary of dataset statistics. This is an overcomplete set of statistics,
270268
as there are now several different tone transfer implementations (js, colab,
271269
vst) that need different statistics for normalization.
272270
"""
273271
print('Calculating dataset statistics for', data_provider)
274-
data_iter = iter(data_provider.get_batch(batch_size, repeats=1))
272+
ds = data_provider.get_batch(batch_size, repeats=1)
275273

276274
# Unpack dataset.
277275
i = 0
@@ -281,9 +279,10 @@ def compute_dataset_statistics(data_provider,
281279
f0_conf = []
282280
audio = []
283281

284-
audio_key = 'audio' if legacy else 'audio_16k'
282+
batch = next(iter(ds))
283+
audio_key = 'audio_16k' if 'audio_16k' in batch.keys() else 'audio'
285284

286-
for batch in data_iter:
285+
for batch in iter(ds):
287286
loudness.append(batch['loudness_db'])
288287
power.append(
289288
spectral_ops.compute_power(batch[audio_key],

ddsp/version.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,4 +18,4 @@
1818
pulling in all the dependencies in __init__.py.
1919
"""
2020

21-
__version__ = '3.4.2'
21+
__version__ = '3.4.3'

0 commit comments

Comments
 (0)