Skip to content

Commit 761d6e3

Browse files
jesseengelMagenta Team
authored andcommitted
Simplified colab for training VST models.
PiperOrigin-RevId: 442655762
1 parent 4386596 commit 761d6e3

File tree

3 files changed

+322
-1
lines changed

3 files changed

+322
-1
lines changed

ddsp/colab/demos/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ Here are colab notebooks for demonstrating neat things you can do with DDSP.
1010

1111
* [pitch_detection](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/pitch_detection.ipynb):
1212
Demonstration of self-supervised pitch detection models from [2020 ICML Workshop paper](https://openreview.net/forum?id=RlVTYWhsky7).
13+
14+
* [Train_VST](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/Train_VST.ipynb):
15+
Simplified training colab for the real-time audio plugin (WIP).

ddsp/colab/demos/Train_VST.ipynb

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"cellView": "form",
8+
"id": "VxPuPR0j5Gs7"
9+
},
10+
"outputs": [],
11+
"source": [
12+
"# ------------------------------------------------------------------------------\n",
13+
"# Copyright 2022 Google LLC. All Rights Reserved.\n",
14+
"#\n",
15+
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
16+
"# you may not use this file except in compliance with the License.\n",
17+
"# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
18+
"#\n",
19+
"# Unless required by applicable law or agreed to in writing, software\n",
20+
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
21+
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
22+
"# See the License for the specific language governing permissions and\n",
23+
"# limitations under the License.\n",
24+
"# ------------------------------------------------------------------------------\n",
25+
"\n",
26+
"#@title Train your own DDSP-VST Model\n",
27+
"#@markdown Just press the ▶️ button!\n",
28+
"\n",
29+
"#@markdown \u003cbr/\u003e Custom models can train on as little as 10 minutes of audio (`.wav` or `.mp3`). Best results from \"monophonic\" (only one note at a time) audio from a single recording session (same mic, same reverb).\n",
30+
"\n",
31+
"#@markdown Training typically takes ~2-3 hours with free Colab, and less than an hour with ColabPro.\n",
32+
"\n",
33+
"\n",
34+
"#@markdown We recommend using Google Drive for training to load faster and save your model during training. Just create a folder on your drive with your audio files in it, and select the folder. If you don't use drive, you can still upload audio through the browser (slower) and download the final trained model.\n",
35+
"\n",
36+
"#@markdown Colab often kicks people off after ~12 hours, but hopefully that shouldn't be a problem.\n",
37+
"\n",
38+
"#@markdown After training, it should automatically export and download your model as `{my_name}.tflite` that you can use by dropping in the VST custom models folder. If it doesn't automatically download, you can find the file in the `ddsp-training-{date-time}/export` folder either on this page (click the 📁 icon on the left), or in the folder you selected from your drive.\n",
39+
"\n",
40+
"\n",
41+
"#@markdown \u003cbr/\u003e \u003cbr/\u003e\n",
42+
"#@markdown Name your model!\n",
43+
"Name = 'MyInstrument' #@param {type:\"string\"}\n",
44+
"Name = Name.replace(' ', '_')\n",
45+
"\n",
46+
"#@markdown \u003cbr/\u003e\n",
47+
"#@markdown Use Google Drive for training?\n",
48+
"Google_Drive = True #@param {type:\"boolean\"}\n",
49+
"\n",
50+
"\n",
51+
"\n",
52+
"#@markdown \u003cbr/\u003e \u003cbr/\u003e\n",
53+
"#@markdown ### Advanced Options\n",
54+
"\n",
55+
"#@markdown \u003ca href=\"https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/Train_VST.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
56+
"\n",
57+
"Training_Steps = 30000 #@param {type:\"integer\"}\n",
58+
"\n",
59+
"#@markdown \u003cbr/\u003e\n",
60+
"#@markdown Ignore previous checkpoints and start a fresh run\n",
61+
"\n",
62+
"Ignore_Previous = False #@param {type:\"boolean\"}\n",
63+
"\n",
64+
"\n",
65+
"# Sample_Rate = '16kHz' #@param ['16kHz', '32kHz', '48kHz']\n",
66+
"# Sample_Rate = {'16kHz': 16000, '32kHz': 32000, '48kHz': 48000}[Sample_Rate]\n",
67+
"# Model_Gin_File = 'models/vst/vst.gin'\n",
68+
"\n",
69+
"\n",
70+
"\n",
71+
"\n",
72+
"# ------------------------------------------------------------------------------\n",
73+
"# Install\n",
74+
"# ------------------------------------------------------------------------------\n",
75+
"print('Installing DDSP...')\n",
76+
"print('This should take about 2 minutes...')\n",
77+
"!pip install -U ddsp[data_preparation]==3.3.4 \u0026\u003e /dev/null\n",
78+
"!pip install ipyfilechooser \u0026\u003e /dev/null\n",
79+
"\n",
80+
"\n",
81+
"# ------------------------------------------------------------------------------\n",
82+
"# Imports\n",
83+
"# ------------------------------------------------------------------------------\n",
84+
"print('Importing Libraries...')\n",
85+
"print()\n",
86+
"import datetime\n",
87+
"import glob\n",
88+
"import os\n",
89+
"import shutil\n",
90+
"\n",
91+
"from ddsp import spectral_ops\n",
92+
"from ddsp.colab import colab_utils\n",
93+
"import ddsp.training\n",
94+
"import gin\n",
95+
"from google.colab import drive\n",
96+
"from ipyfilechooser import FileChooser\n",
97+
"import pydub\n",
98+
"from matplotlib import pyplot as plt\n",
99+
"import numpy as np\n",
100+
"import tensorflow as tf\n",
101+
"\n",
102+
"from ddsp.training.data_preparation.prepare_tfrecord_lib import _load_audio_as_array as load_audio\n",
103+
"\n",
104+
"\n",
105+
"# ------------------------------------------------------------------------------\n",
106+
"# Functions\n",
107+
"# ------------------------------------------------------------------------------\n",
108+
"def directory_has_files(target_dir):\n",
109+
" n_files = len(glob.glob(os.path.join(target_dir, '*')))\n",
110+
" return n_files \u003e 0\n",
111+
"\n",
112+
"\n",
113+
"def get_audio_files(drive_dir, audio_dir):\n",
114+
" if drive_dir:\n",
115+
" mp3_files = glob.glob(os.path.join(drive_dir, '*.mp3'))\n",
116+
" wav_files = glob.glob(os.path.join(drive_dir, '*.wav'))\n",
117+
" audio_paths = mp3_files + wav_files\n",
118+
" if len(audio_paths) \u003c 1:\n",
119+
" raise FileNotFoundError(\"Sorry, it seems that there aren't any MP3 or \"\n",
120+
" f\"WAV files in your folder ({drive_dir}). Try \"\n",
121+
" \"running again and choose a different folder.\")\n",
122+
" else:\n",
123+
" audio_paths, _ = colab_utils.upload()\n",
124+
"\n",
125+
" # Copy Audio.\n",
126+
" for src in audio_paths:\n",
127+
" target = os.path.join(audio_dir, \n",
128+
" os.path.basename(src).replace(' ', '_'))\n",
129+
" print('Copying {} to {}'.format(src, target))\n",
130+
" shutil.copy(src, target)\n",
131+
" # !cp $src $target\n",
132+
"\n",
133+
"\n",
134+
"def prepare_dataset(audio_dir, \n",
135+
" data_dir,\n",
136+
" sample_rate=16000, \n",
137+
" frame_rate=50, \n",
138+
" example_secs=4.0, \n",
139+
" hop_secs=1.0, \n",
140+
" viterbi=True, \n",
141+
" center=True):\n",
142+
" if directory_has_files(data_dir):\n",
143+
" print(f'Dataset already exists in `{data_dir}`')\n",
144+
" return\n",
145+
" else:\n",
146+
" # Otherwise prepare new dataset locally.\n",
147+
" print(f'Preparing new dataset from `{audio_dir}`')\n",
148+
"\n",
149+
" print()\n",
150+
" print('Creating dataset...')\n",
151+
" print('This usually takes around 2-3 minutes for each minute of audio')\n",
152+
" print('(10 minutes of training audio -\u003e 20-30 minutes)')\n",
153+
"\n",
154+
" audio_filepattern = os.path.join(audio_dir, '*')\n",
155+
" !ddsp_prepare_tfrecord \\\n",
156+
" --input_audio_filepatterns=$audio_filepattern \\\n",
157+
" --output_tfrecord_path=$data_dir/train.tfrecord \\\n",
158+
" --num_shards=10 \\\n",
159+
" --sample_rate=$sample_rate \\\n",
160+
" --frame_rate=$frame_rate \\\n",
161+
" --example_secs=$example_secs \\\n",
162+
" --hop_secs=$hop_secs \\\n",
163+
" --viterbi=$viterbi \\\n",
164+
" --center=$center \\\n",
165+
" --alsologtostderr \u0026\u003e /dev/null\n",
166+
"\n",
167+
"\n",
168+
"def train(model_dir, data_dir, steps=30000):\n",
169+
" file_pattern = os.path.join(data_dir, 'train.tfrecord*')\n",
170+
" !ddsp_run \\\n",
171+
" --mode=train \\\n",
172+
" --save_dir=\"$model_dir\" \\\n",
173+
" --gin_file=models/vst/vst.gin \\\n",
174+
" --gin_file=datasets/tfrecord.gin \\\n",
175+
" --gin_param=\"TFRecordProvider.file_pattern='$file_pattern'\" \\\n",
176+
" --gin_param=\"TFRecordProvider.centered=True\" \\\n",
177+
" --gin_param=\"TFRecordProvider.frame_rate=50\" \\\n",
178+
" --gin_param=\"batch_size=16\" \\\n",
179+
" --gin_param=\"train_util.train.num_steps=$steps\" \\\n",
180+
" --gin_param=\"train_util.train.steps_per_save=300\" \\\n",
181+
" --gin_param=\"trainers.Trainer.checkpoints_to_keep=3\"\n",
182+
"\n",
183+
" # --gin_param=\"train.data_provider=@ExperimentalDataProvider()\" \\\n",
184+
" # --gin_param=\"ExperimentalRecordProvider.data_dir='$data_dir'\" \\\n",
185+
" # --gin_param=\"ExperimentalRecordProvider.sample_rate=16000\" \\\n",
186+
" # --gin_param=\"ExperimentalRecordProvider.frame_rate=50\" \\\n",
187+
"\n",
188+
"\n",
189+
"def launch_tensorboard(save_dir):\n",
190+
" %reload_ext tensorboard\n",
191+
" import tensorboard as tb\n",
192+
" tb.notebook.start('--logdir \"{}\"'.format(save_dir))\n",
193+
"\n",
194+
"\n",
195+
"def reset_state(data_dir, audio_dir, model_dir):\n",
196+
" if tf.io.gfile.exists(data_dir):\n",
197+
" !rm -r $data_dir\n",
198+
" !rm -r $audio_dir\n",
199+
" !mkdir -p $data_dir\n",
200+
" !mkdir -p $audio_dir\n",
201+
" !mkdir -p $model_dir\n",
202+
"\n",
203+
"\n",
204+
"def export_and_download(model_dir, model_name=Name):\n",
205+
" export_path = os.path.join(model_dir, 'export')\n",
206+
"\n",
207+
" !ddsp_export \\\n",
208+
" --model_path=$model_dir \\\n",
209+
" --save_dir=$export_path \\\n",
210+
" --inference_model=vst_stateless_predict_controls \\\n",
211+
" --tflite \\\n",
212+
" --notfjs\n",
213+
"\n",
214+
" # Just copy the tflite model.\n",
215+
" tflite_fp = os.path.join(export_path, 'tflite', 'model.tflite')\n",
216+
" my_model = os.path.join(model_dir, f'{model_name}.tflite')\n",
217+
" !cp $tflite_fp $my_model\n",
218+
" print('Export Complete! Downloading...')\n",
219+
" print(f'You can also find your model at {my_model}')\n",
220+
" colab_utils.download(my_model)\n",
221+
"\n",
222+
" # Copy the whole directory.\n",
223+
" # my_model = f'{model_name}.zip'\n",
224+
" # !zip -r $my_model $export_path\n",
225+
" # colab_utils.download(my_model)\n",
226+
"\n",
227+
"\n",
228+
"def get_model_dir(base_dir):\n",
229+
" base_str = 'ddsp-training'\n",
230+
" dirs = tf.io.gfile.glob(os.path.join(base_dir, f'{base_str}-*'))\n",
231+
" if dirs and not Ignore_Previous:\n",
232+
" model_dir = dirs[-1] # Sorted, so last is most recent.\n",
233+
" else:\n",
234+
" now = datetime.datetime.now().strftime('%Y-%m-%d-%H%M')\n",
235+
" model_dir = os.path.join(base_dir, f'{base_str}-{now}')\n",
236+
" return model_dir\n",
237+
"\n",
238+
"\n",
239+
"\n",
240+
"\n",
241+
"def run_training(drive_dir=''):\n",
242+
"\n",
243+
" # ------------------------------------------------------------------------------\n",
244+
" # Setup\n",
245+
" # ------------------------------------------------------------------------------\n",
246+
" # Save data locally, but model on drive.\n",
247+
" data_dir = 'data/'\n",
248+
" audio_dir = 'audio/'\n",
249+
" model_dir = get_model_dir(drive_dir)\n",
250+
"\n",
251+
" reset_state(data_dir, audio_dir, model_dir)\n",
252+
"\n",
253+
" # ------------------------------------------------------------------------------\n",
254+
" # Dataset\n",
255+
" # ------------------------------------------------------------------------------\n",
256+
" get_audio_files(drive_dir, audio_dir)\n",
257+
" prepare_dataset(audio_dir, data_dir)\n",
258+
"\n",
259+
" # ------------------------------------------------------------------------------\n",
260+
" # Train\n",
261+
" # ------------------------------------------------------------------------------\n",
262+
" print()\n",
263+
" print('Training...')\n",
264+
" train(model_dir, data_dir, steps=Training_Steps)\n",
265+
"\n",
266+
" # ------------------------------------------------------------------------------\n",
267+
" # Export\n",
268+
" # ------------------------------------------------------------------------------\n",
269+
" print()\n",
270+
" print('Exporting model...')\n",
271+
" export_and_download(model_dir)\n",
272+
"\n",
273+
"\n",
274+
"def run(Google_Drive=True):\n",
275+
" \"\"\"Create and display a FileChooser widget.\"\"\"\n",
276+
"\n",
277+
" if Google_Drive:\n",
278+
" print('Mounting Google Drive...')\n",
279+
" drive.mount('gdrive', force_remount=True, timeout_ms=10000) \n",
280+
" initial_dir = 'gdrive/MyDrive'\n",
281+
"\n",
282+
" def run_after_select(chooser):\n",
283+
" drive_dir = chooser.selected_path\n",
284+
" run_training(drive_dir=drive_dir)\n",
285+
"\n",
286+
" fc = FileChooser(initial_dir)\n",
287+
" fc.show_only_dirs = True\n",
288+
" fc.title = '\u003cb\u003ePick a folder with the audio files for training...\u003c/b\u003e'\n",
289+
" fc.register_callback(run_after_select)\n",
290+
" display(fc)\n",
291+
"\n",
292+
"\n",
293+
" else:\n",
294+
" print('Skipping Drive Setup...')\n",
295+
" print('Upload Audio Manually...')\n",
296+
" run_training(drive_dir='')\n",
297+
"\n",
298+
"\n",
299+
"run(Google_Drive)\n"
300+
]
301+
}
302+
],
303+
"metadata": {
304+
"accelerator": "GPU",
305+
"colab": {
306+
"collapsed_sections": [],
307+
"name": "Train_VST.ipynb",
308+
"private_outputs": true,
309+
"provenance": []
310+
},
311+
"kernelspec": {
312+
"display_name": "Python 3",
313+
"name": "python3"
314+
}
315+
},
316+
"nbformat": 4,
317+
"nbformat_minor": 0
318+
}

ddsp/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
pulling in all the dependencies in __init__.py.
1919
"""
2020

21-
__version__ = '3.3.4'
21+
__version__ = '3.3.5'

0 commit comments

Comments
 (0)