Skip to content

Commit be1477f

Browse files
authored
Merge pull request #6953 from pritamdodeja/vertex_tuner_example
Vertex AI Tuner component example
2 parents 7615e5a + 38672b2 commit be1477f

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

docs/guide/tuner.md

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,66 @@ algorithm uses information from results of prior trials, such as Google Vizier
207207
algorithm implemented in the AI Platform Vizier does, an excessively parallel
208208
search would negatively affect the efficacy of the search.
209209

210+
It is also possible to use the new Vertex AI api as in the example shown below.
211+
```
212+
from tfx.v1.extensions.google_cloud_ai_platform import Tuner
213+
ai_platform_tuning_args = {
214+
'project': GOOGLE_CLOUD_PROJECT,
215+
'job_spec': {
216+
# 'service_account': ACCOUNT,
217+
'worker_pool_specs': [{'container_spec': {'image_uri': default_kfp_image},
218+
'machine_spec': {'machine_type': MACHINE_TYPE,
219+
'accelerator_type': accelerator_type,
220+
'accelerator_count': 1
221+
},
222+
'replica_count': 1}],
223+
224+
# "enable_web_access": True, #In case you need to debug from within the container
225+
}
226+
}
227+
vertex_job_spec = {
228+
'project': GOOGLE_CLOUD_PROJECT,
229+
'job_spec': {
230+
'worker_pool_specs': [{
231+
'machine_spec': {
232+
'machine_type': MACHINE_TYPE,
233+
'accelerator_type': accelerator_type,
234+
'accelerator_count': 1
235+
},
236+
'replica_count': 1,
237+
'container_spec': {
238+
'image_uri': default_kfp_image,
239+
},
240+
}],
241+
"enable_web_access": True,
242+
}
243+
}
244+
tuner = Tuner(
245+
module_file=_tuner_module_file,
246+
examples=transform.outputs['transformed_examples'],
247+
transform_graph=transform.outputs['transform_graph'],
248+
train_args=proto.TrainArgs(
249+
splits=['train'], num_steps=int(
250+
TRAINING_STEPS // 4)),
251+
eval_args=proto.EvalArgs(
252+
splits=['eval'], num_steps=int(
253+
VAL_STEPS // 4)),
254+
tune_args=proto.TuneArgs(num_parallel_trials=num_parallel_trials),
255+
custom_config={
256+
tfx.extensions.google_cloud_ai_platform.ENABLE_VERTEX_KEY:
257+
True,
258+
tfx.extensions.google_cloud_ai_platform.VERTEX_REGION_KEY:
259+
GOOGLE_CLOUD_REGION,
260+
tfx.extensions.google_cloud_ai_platform.experimental.TUNING_ARGS_KEY:
261+
vertex_job_spec,
262+
'use_gpu':
263+
USE_GPU,
264+
'ai_platform_tuning_args': ai_platform_tuning_args,
265+
tfx.extensions.google_cloud_ai_platform.experimental.REMOTE_TRIALS_WORKING_DIR_KEY: os.path.join(PIPELINE_ROOT, 'trials'),
266+
267+
}
268+
)
269+
```
210270
!!! Note
211271
Each trial in each parallel search is conducted on a single machine in the
212272
worker flock, i.e., each trial does not take advantage of multi-worker

0 commit comments

Comments
 (0)