Skip to content

Commit 1cca34d

Browse files
jesseengel and Magenta Team
authored and committed
Train models with adjustable sample rate. Update dataset creation, data provider, gin configs, and inference models / export to support 32kHz and 48kHz models.
PiperOrigin-RevId: 443692298
1 parent def2c6b commit 1cca34d

File tree

9 files changed

+383
-64
lines changed

9 files changed

+383
-64
lines changed

ddsp/training/data.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import os
1717

1818
from absl import logging
19+
from ddsp.spectral_ops import CREPE_FRAME_SIZE
20+
from ddsp.spectral_ops import CREPE_SAMPLE_RATE
1921
from ddsp.spectral_ops import get_framed_lengths
2022
import gin
2123
import tensorflow.compat.v2 as tf
@@ -199,31 +201,30 @@ def preprocess_ex(ex):
199201
return dataset
200202

201203

202-
class RecordProvider(DataProvider):
203-
"""Class for reading records and returning a dataset."""
204+
@gin.register
205+
class TFRecordProvider(DataProvider):
206+
"""Class for reading TFRecords and returning a dataset."""
204207

205208
def __init__(self,
206-
file_pattern,
207-
example_secs,
208-
sample_rate,
209-
frame_rate,
210-
data_format_map_fn,
209+
file_pattern=None,
210+
example_secs=4,
211+
sample_rate=16000,
212+
frame_rate=250,
211213
centered=False):
212214
"""RecordProvider constructor."""
215+
super().__init__(sample_rate, frame_rate)
213216
self._file_pattern = file_pattern or self.default_file_pattern
214217
self._audio_length = example_secs * sample_rate
215-
super().__init__(sample_rate, frame_rate)
218+
self._audio_16k_length = example_secs * CREPE_SAMPLE_RATE
216219
self._feature_length = self.get_feature_length(centered)
217-
self._data_format_map_fn = data_format_map_fn
218220

219221
def get_feature_length(self, centered):
220222
"""Take into account center padding to get number of frames."""
221223
# Number of frames is independent of frame size for "center/same" padding.
222-
frame_size = 1024
223-
hop_size = self.sample_rate / self.frame_rate
224+
hop_size = CREPE_SAMPLE_RATE / self.frame_rate
224225
padding = 'center' if centered else 'same'
225226
return get_framed_lengths(
226-
self._audio_length, frame_size, hop_size, padding)[0]
227+
self._audio_16k_length, CREPE_FRAME_SIZE, hop_size, padding)[0]
227228

228229
@property
229230
def default_file_pattern(self):
@@ -246,7 +247,7 @@ def parse_tfexample(record):
246247

247248
filenames = tf.data.Dataset.list_files(self._file_pattern, shuffle=shuffle)
248249
dataset = filenames.interleave(
249-
map_func=self._data_format_map_fn,
250+
map_func=tf.data.TFRecordDataset,
250251
cycle_length=40,
251252
num_parallel_calls=_AUTOTUNE)
252253
dataset = dataset.map(parse_tfexample, num_parallel_calls=_AUTOTUNE)
@@ -258,6 +259,8 @@ def features_dict(self):
258259
return {
259260
'audio':
260261
tf.io.FixedLenFeature([self._audio_length], dtype=tf.float32),
262+
'audio_16k':
263+
tf.io.FixedLenFeature([self._audio_16k_length], dtype=tf.float32),
261264
'f0_hz':
262265
tf.io.FixedLenFeature([self._feature_length], dtype=tf.float32),
263266
'f0_confidence':
@@ -268,18 +271,22 @@ def features_dict(self):
268271

269272

270273
@gin.register
271-
class TFRecordProvider(RecordProvider):
274+
class LegacyTFRecordProvider(TFRecordProvider):
272275
"""Class for reading TFRecords and returning a dataset."""
273276

274-
def __init__(self,
275-
file_pattern=None,
276-
example_secs=4,
277-
sample_rate=16000,
278-
frame_rate=250,
279-
centered=False):
280-
"""TFRecordProvider constructor."""
281-
super().__init__(file_pattern, example_secs, sample_rate,
282-
frame_rate, tf.data.TFRecordDataset, centered=centered)
277+
@property
278+
def features_dict(self):
279+
"""Dictionary of features to read from dataset."""
280+
return {
281+
'audio':
282+
tf.io.FixedLenFeature([self._audio_length], dtype=tf.float32),
283+
'f0_hz':
284+
tf.io.FixedLenFeature([self._feature_length], dtype=tf.float32),
285+
'f0_confidence':
286+
tf.io.FixedLenFeature([self._feature_length], dtype=tf.float32),
287+
'loudness_db':
288+
tf.io.FixedLenFeature([self._feature_length], dtype=tf.float32),
289+
}
283290

284291

285292
# ------------------------------------------------------------------------------
@@ -397,7 +404,7 @@ def get_dataset(self, shuffle=True):
397404
# Synthetic Data for InverseSynthesis
398405
# ------------------------------------------------------------------------------
399406
@gin.register
400-
class SyntheticNotes(TFRecordProvider):
407+
class SyntheticNotes(LegacyTFRecordProvider):
401408
"""Create self-supervised control signal.
402409
403410
EXPERIMENTAL
@@ -440,7 +447,7 @@ def features_dict(self):
440447

441448

442449
@gin.register
443-
class Urmp(TFRecordProvider):
450+
class Urmp(LegacyTFRecordProvider):
444451
"""Urmp training set."""
445452

446453
def __init__(self,

ddsp/training/data_preparation/prepare_tfrecord_lib.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pydub
2222
import tensorflow.compat.v2 as tf
2323

24+
CREPE_SAMPLE_RATE = spectral_ops.CREPE_SAMPLE_RATE # 16kHz.
2425

2526

2627
def _load_audio_as_array(audio_path, sample_rate):
@@ -57,24 +58,32 @@ def _load_audio(audio_path, sample_rate):
5758
logging.info("Loading '%s'.", audio_path)
5859
beam.metrics.Metrics.counter('prepare-tfrecord', 'load-audio').inc()
5960
audio = _load_audio_as_array(audio_path, sample_rate)
60-
return {'audio': audio}
61+
if sample_rate != CREPE_SAMPLE_RATE:
62+
audio_16k = _load_audio_as_array(audio_path, CREPE_SAMPLE_RATE)
63+
else:
64+
audio_16k = audio
65+
return {'audio': audio, 'audio_16k': audio_16k}
6166

6267

6368
def _chunk_audio(ex, sample_rate, chunk_secs):
6469
"""Pad audio and split into chunks."""
6570
beam.metrics.Metrics.counter('prepare-tfrecord', 'load-audio').inc()
66-
audio = ex['audio']
67-
chunk_size = int(chunk_secs * sample_rate)
68-
chunks = tf.signal.frame(audio, chunk_size, chunk_size, pad_end=True)
71+
def get_chunks(audio, sample_rate):
72+
chunk_size = int(chunk_secs * sample_rate)
73+
return tf.signal.frame(audio, chunk_size, chunk_size, pad_end=True).numpy()
74+
75+
chunks = get_chunks(ex['audio'], sample_rate)
76+
chunks_16k = get_chunks(ex['audio_16k'], CREPE_SAMPLE_RATE)
77+
assert chunks.shape[0] == chunks_16k.shape[0]
6978
n_chunks = chunks.shape[0]
7079
for i in range(n_chunks):
71-
yield {'audio': chunks[i].numpy()}
80+
yield {'audio': chunks[i], 'audio_16k': chunks_16k[i]}
7281

7382

7483
def _add_f0_estimate(ex, frame_rate, center, viterbi):
7584
"""Add fundamental frequency (f0) estimate using CREPE."""
7685
beam.metrics.Metrics.counter('prepare-tfrecord', 'estimate-f0').inc()
77-
audio = ex['audio']
86+
audio = ex['audio_16k']
7887
padding = 'center' if center else 'same'
7988
f0_hz, f0_confidence = spectral_ops.compute_f0(
8089
audio, frame_rate, viterbi=viterbi, padding=padding)
@@ -86,13 +95,13 @@ def _add_f0_estimate(ex, frame_rate, center, viterbi):
8695
return ex
8796

8897

89-
def _add_loudness(ex, sample_rate, frame_rate, n_fft, center):
98+
def _add_loudness(ex, frame_rate, n_fft, center):
9099
"""Add loudness in dB."""
91100
beam.metrics.Metrics.counter('prepare-tfrecord', 'compute-loudness').inc()
92-
audio = ex['audio']
101+
audio = ex['audio_16k']
93102
padding = 'center' if center else 'same'
94103
loudness_db = spectral_ops.compute_loudness(
95-
audio, sample_rate, frame_rate, n_fft, padding=padding)
104+
audio, CREPE_SAMPLE_RATE, frame_rate, n_fft, padding=padding)
96105
ex = dict(ex)
97106
ex['loudness_db'] = loudness_db.numpy().astype(np.float32)
98107
return ex
@@ -113,14 +122,16 @@ def get_windows(sequence, rate, center):
113122
end = start + window_size
114123
yield sequence[start:end]
115124

116-
for audio, loudness_db, f0_hz, f0_confidence in zip(
125+
for audio, audio_16k, loudness_db, f0_hz, f0_confidence in zip(
117126
get_windows(ex['audio'], sample_rate, center=False),
127+
get_windows(ex['audio_16k'], CREPE_SAMPLE_RATE, center=False),
118128
get_windows(ex['loudness_db'], frame_rate, center),
119129
get_windows(ex['f0_hz'], frame_rate, center),
120130
get_windows(ex['f0_confidence'], frame_rate, center)):
121131
beam.metrics.Metrics.counter('prepare-tfrecord', 'split-example').inc()
122132
yield {
123133
'audio': audio,
134+
'audio_16k': audio_16k,
124135
'loudness_db': loudness_db,
125136
'f0_hz': f0_hz,
126137
'f0_confidence': f0_confidence
@@ -238,7 +249,6 @@ def postprocess_pipeline(examples, output_path, stage_name=''):
238249
center=center,
239250
viterbi=viterbi)
240251
| beam.Map(_add_loudness,
241-
sample_rate=sample_rate,
242252
frame_rate=frame_rate,
243253
n_fft=512,
244254
center=center))

ddsp/training/data_preparation/prepare_tfrecord_lib_test.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import scipy.io.wavfile
2727
import tensorflow.compat.v2 as tf
2828

29+
CREPE_SAMPLE_RATE = spectral_ops.CREPE_SAMPLE_RATE
30+
2931

3032
class PrepareTFRecordBeamTest(parameterized.TestCase):
3133

@@ -70,7 +72,7 @@ def validate_outputs(self, expected_num_examples, expected_feature_lengths):
7072
try:
7173
self.assertLen(arr, expected_len)
7274
except AssertionError as e:
73-
raise AssertionError('%s feature: %s' % (e, feat))
75+
raise AssertionError('feature: %s' % feat) from e
7476
self.assertFalse(any(np.isinf(arr)))
7577

7678
def get_expected_length(self, input_length, frame_rate, center=False):
@@ -139,6 +141,7 @@ def test_prepare_tfrecord(self, chunk_secs, example_secs):
139141
expected_n_batch,
140142
{
141143
'audio': expected_n_t,
144+
'audio_16k': expected_n_t,
142145
'f0_hz': expected_n_frames,
143146
'f0_confidence': expected_n_frames,
144147
'loudness_db': expected_n_frames,
@@ -169,13 +172,49 @@ def test_centering(self, center):
169172
self.validate_outputs(
170173
n_batch, {
171174
'audio': n_t,
175+
'audio_16k': n_t,
176+
'f0_hz': n_frames,
177+
'f0_confidence': n_frames,
178+
'loudness_db': n_frames,
179+
})
180+
181+
@parameterized.named_parameters(
182+
('16kHz', 16000),
183+
('32kHz', 32000),
184+
('48kHz', 48000))
185+
def test_sample_rate(self, sample_rate):
186+
frame_rate = 250
187+
example_secs = 0.3
188+
hop_secs = 0.1
189+
center = True
190+
n_batch = self.get_n_per_chunk(self.wav_secs, example_secs, hop_secs)
191+
prepare_tfrecord_lib.prepare_tfrecord(
192+
[self.wav_path],
193+
os.path.join(self.test_dir, 'output.tfrecord'),
194+
num_shards=2,
195+
sample_rate=sample_rate,
196+
frame_rate=frame_rate,
197+
example_secs=example_secs,
198+
hop_secs=hop_secs,
199+
center=center,
200+
chunk_secs=None)
201+
202+
n_t = int(example_secs * sample_rate)
203+
n_t_16k = int(example_secs * CREPE_SAMPLE_RATE)
204+
n_frames = self.get_expected_length(n_t_16k, frame_rate, center)
205+
n_expected_frames = 76 # (250 * 0.3) + 1.
206+
self.assertEqual(n_frames, n_expected_frames)
207+
self.validate_outputs(
208+
n_batch, {
209+
'audio': n_t,
210+
'audio_16k': n_t_16k,
172211
'f0_hz': n_frames,
173212
'f0_confidence': n_frames,
174213
'loudness_db': n_frames,
175214
})
176215

177-
@parameterized.named_parameters(('16k', 16000), ('24k', 24000),
178-
('48k', 48000))
216+
@parameterized.named_parameters(('16kHz', 16000), ('44.1kHz', 44100),
217+
('48kHz', 48000))
179218
def test_audio_only(self, sample_rate):
180219
prepare_tfrecord_lib.prepare_tfrecord(
181220
[self.wav_path],
@@ -189,6 +228,7 @@ def test_audio_only(self, sample_rate):
189228
self.validate_outputs(
190229
1, {
191230
'audio': int(self.wav_secs * sample_rate),
231+
'audio_16k': int(self.wav_secs * CREPE_SAMPLE_RATE),
192232
})
193233

194234

0 commit comments

Comments (0)