component does not have the ability to execute more than one search worker in
parallel, the
[Google Cloud AI Platform extension Tuner component](https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/tuner/component.py)
provides the ability to run parallel tuning, using an AI Platform Training
Job as a distributed worker flock manager.

[TuneArgs](https://github.com/tensorflow/tfx/blob/master/tfx/proto/tuner.proto)
is the configuration given to this component. This is a drop-in replacement of
the stock Tuner component.
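
For illustration, here is a minimal sketch of how the extension component might
be configured to run three trials in parallel on AI Platform Training. The
project and region values, the module file name, and the upstream `transform`
component are placeholders rather than a definitive setup; the remaining
arguments are the same ones accepted by the stock Tuner component.

```
from tfx import v1 as tfx
from tfx.v1 import proto
from tfx.v1.extensions.google_cloud_ai_platform import Tuner

# Run up to three trials at a time, each on its own worker in the flock.
tuner = Tuner(
    module_file='tuner_module.py',  # Placeholder: user module defining tuner_fn().
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=proto.TrainArgs(num_steps=1000),
    eval_args=proto.EvalArgs(num_steps=500),
    tune_args=proto.TuneArgs(num_parallel_trials=3),
    custom_config={
        # AI Platform Training job configuration; see
        # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput
        tfx.extensions.google_cloud_ai_platform.experimental.TUNING_ARGS_KEY: {
            'project': 'my-gcp-project',  # Placeholder project ID.
            'region': 'us-central1',      # Placeholder region.
        },
    },
)
```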
algorithm uses information from the results of prior trials, as the Google
Vizier algorithm implemented in the AI Platform Vizier does, an excessively
parallel search would negatively affect the efficacy of the search.

It is also possible to use the new Vertex AI API, as in the example shown below.
Values such as `GOOGLE_CLOUD_PROJECT`, `MACHINE_TYPE`, `accelerator_type`,
`TRAINING_STEPS`, `VAL_STEPS`, `USE_GPU`, `num_parallel_trials`, and
`PIPELINE_ROOT` are assumed to be defined earlier in the pipeline setup, and the
container image URI should be replaced with your own training image.
```
import os

from tfx import v1 as tfx
from tfx.v1 import proto
from tfx.v1.extensions.google_cloud_ai_platform import Tuner

# Job spec that runs each tuning worker on the default KFP image.
ai_platform_tuning_args = {
    'project': GOOGLE_CLOUD_PROJECT,
    'job_spec': {
        # 'service_account': ACCOUNT,  # Optional: run the job as a dedicated service account.
        'worker_pool_specs': [{
            'container_spec': {'image_uri': default_kfp_image},
            'machine_spec': {
                'machine_type': MACHINE_TYPE,
                'accelerator_type': accelerator_type,
                'accelerator_count': 1,
            },
            'replica_count': 1,
        }],
        # 'enable_web_access': True,  # In case you need to debug from within the container.
    },
}

# Job spec that runs each tuning worker on a custom training image.
vertex_job_spec = {
    'project': GOOGLE_CLOUD_PROJECT,
    'job_spec': {
        'worker_pool_specs': [{
            'machine_spec': {
                'machine_type': MACHINE_TYPE,
                'accelerator_type': accelerator_type,
                'accelerator_count': 1,
            },
            'replica_count': 1,
            'container_spec': {
                # Example custom image; replace with your own training container.
                'image_uri': 'us-east1-docker.pkg.dev/itp-ml-sndbx/intuitive-ml-docker-repo/beam260tf215tft151deep:v2.60',
            },
        }],
        'enable_web_access': True,
    },
}

tuner = Tuner(
    module_file=_tuner_module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=proto.TrainArgs(
        splits=['train'], num_steps=int(TRAINING_STEPS // 4)),
    eval_args=proto.EvalArgs(
        splits=['eval'], num_steps=int(VAL_STEPS // 4)),
    tune_args=proto.TuneArgs(num_parallel_trials=num_parallel_trials),
    custom_config={
        tfx.extensions.google_cloud_ai_platform.ENABLE_VERTEX_KEY:
            True,
        tfx.extensions.google_cloud_ai_platform.VERTEX_REGION_KEY:
            GOOGLE_CLOUD_REGION,
        # TUNING_ARGS_KEY is the string 'ai_platform_tuning_args', so supply
        # exactly one job spec under it; vertex_job_spec is used here, and
        # ai_platform_tuning_args could be passed instead.
        tfx.extensions.google_cloud_ai_platform.experimental.TUNING_ARGS_KEY:
            vertex_job_spec,
        'use_gpu':
            USE_GPU,
        tfx.extensions.google_cloud_ai_platform.experimental.REMOTE_TRIALS_WORKING_DIR_KEY:
            os.path.join(PIPELINE_ROOT, 'trials'),
    },
)
```
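
Once tuning finishes, the best hyperparameters found by the Tuner can be passed
to the Trainer through its `best_hyperparameters` output. A minimal sketch,
reusing the placeholder values above and assuming a `_trainer_module_file` user
module:

```
# Train the final model with the hyperparameters chosen by the Tuner.
trainer = tfx.components.Trainer(
    module_file=_trainer_module_file,  # Placeholder: user module defining run_fn().
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    hyperparameters=tuner.outputs['best_hyperparameters'],
    train_args=proto.TrainArgs(splits=['train'], num_steps=TRAINING_STEPS),
    eval_args=proto.EvalArgs(splits=['eval'], num_steps=VAL_STEPS),
)
```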
!!! Note
    Each trial in each parallel search is conducted on a single machine in the
    worker flock, i.e., each trial does not take advantage of multi-worker