Skip to content

Commit 7479317

Browse files
pritamdodeja and Pritam Dodeja authored
Add feature to optionally provide stats to trainer (#7734)
Enables the user to use the number-of-examples information computed by StatisticsGen in their training code. Passing statistics to the Trainer enables the use of fn_args.num_examples['train'], etc., in run_fn. More details at: #7700. Co-authored-by: Pritam Dodeja <pdodeja@distml.home.int>
1 parent 723bdac commit 7479317

File tree

4 files changed

+38
-0
lines changed

4 files changed

+38
-0
lines changed

tfx/components/trainer/component.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class Trainer(base_component.BaseComponent):
7272

7373
def __init__(
7474
self,
75+
statistics: Optional[types.BaseChannel] = None,
7576
examples: Optional[types.BaseChannel] = None,
7677
transformed_examples: Optional[types.BaseChannel] = None,
7778
transform_graph: Optional[types.BaseChannel] = None,
@@ -170,6 +171,7 @@ def run_fn(trainer.fn_args_utils.FnArgs)
170171
model = types.Channel(type=standard_artifacts.Model)
171172
model_run = types.Channel(type=standard_artifacts.ModelRun)
172173
spec = standard_component_specs.TrainerSpec(
174+
statistics=statistics,
173175
examples=examples,
174176
transform_graph=transform_graph,
175177
schema=schema,

tfx/components/trainer/component_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from tfx.components.trainer import executor
1919
from tfx.dsl.components.base import executor_spec
2020
from tfx.orchestration import data_types
21+
from tfx.types import artifact_utils
2122
from tfx.proto import trainer_pb2
2223
from tfx.types import channel_utils
2324
from tfx.types import standard_artifacts
@@ -30,6 +31,10 @@ def setUp(self):
3031
super().setUp()
3132

3233
self.examples = channel_utils.as_channel([standard_artifacts.Examples()])
34+
statistics_artifact = standard_artifacts.ExampleStatistics()
35+
statistics_artifact.split_names = artifact_utils.encode_split_names(
36+
['train', 'eval'])
37+
self.statistics = channel_utils.as_channel([statistics_artifact])
3338
self.transform_graph = channel_utils.as_channel(
3439
[standard_artifacts.TransformGraph()])
3540
self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
@@ -62,6 +67,23 @@ def testConstructFromModuleFile(self):
6267
'{"test": 10}', trainer.spec.exec_properties[
6368
standard_component_specs.CUSTOM_CONFIG_KEY])
6469

70+
def testConstructFromModuleFileWithStatistics(self):
71+
module_file = '/path/to/module/file'
72+
trainer = component.Trainer(
73+
module_file=module_file,
74+
examples=self.examples,
75+
statistics=self.statistics,
76+
transform_graph=self.transform_graph,
77+
schema=self.schema,
78+
custom_config={'test': 10})
79+
self._verify_outputs(trainer)
80+
self.assertEqual(
81+
module_file,
82+
trainer.spec.exec_properties[standard_component_specs.MODULE_FILE_KEY])
83+
self.assertEqual(
84+
'{"test": 10}', trainer.spec.exec_properties[
85+
standard_component_specs.CUSTOM_CONFIG_KEY])
86+
6587
def testConstructWithParameter(self):
6688
module_file = data_types.RuntimeParameter(name='module-file', ptype=str)
6789
n_steps = data_types.RuntimeParameter(name='n-steps', ptype=int)

tfx/components/trainer/executor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tfx.components.trainer import constants
2323
from tfx.components.trainer import fn_args_utils
2424
from tfx.components.util import udf_utils
25+
from tfx.components.statistics_gen import stats_artifact_utils
2526
from tfx.dsl.components.base import base_executor
2627
from tfx.dsl.io import fileio
2728
from tfx.types import artifact_utils
@@ -87,6 +88,15 @@ class GenericExecutor(base_executor.BaseExecutor):
8788
def _GetFnArgs(self, input_dict: Dict[str, List[types.Artifact]],
8889
output_dict: Dict[str, List[types.Artifact]],
8990
exec_properties: Dict[str, Any]) -> fn_args_utils.FnArgs:
91+
if standard_component_specs.STATISTICS_KEY in input_dict.keys():
92+
stats_artifact = artifact_utils.get_single_instance(
93+
input_dict[standard_component_specs.STATISTICS_KEY])
94+
split_names = artifact_utils.decode_split_names(stats_artifact.split_names)
95+
num_examples = {}
96+
for split in split_names:
97+
stats = stats_artifact_utils.load_statistics(stats_artifact,
98+
split).proto()
99+
num_examples[split] = stats.datasets[0].num_examples
90100
if input_dict.get(standard_component_specs.HYPERPARAMETERS_KEY):
91101
hyperparameters_file = io_utils.get_only_uri_in_dir(
92102
artifact_utils.get_single_uri(
@@ -115,6 +125,8 @@ def _GetFnArgs(self, input_dict: Dict[str, List[types.Artifact]],
115125
result.model_run_dir = model_run_dir
116126
result.schema_file = result.schema_path
117127
result.hyperparameters = hyperparameters_config
128+
if standard_component_specs.STATISTICS_KEY in input_dict.keys():
129+
result.num_examples = num_examples
118130
return result
119131

120132
def Do(self, input_dict: Dict[str, List[types.Artifact]],

tfx/types/standard_component_specs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,8 @@ class TrainerSpec(ComponentSpec):
411411
HYPERPARAMETERS_KEY:
412412
ChannelParameter(
413413
type=standard_artifacts.HyperParameters, optional=True),
414+
STATISTICS_KEY:
415+
ChannelParameter(type=standard_artifacts.ExampleStatistics, optional=True),
414416
}
415417
OUTPUTS = {
416418
MODEL_KEY: ChannelParameter(type=standard_artifacts.Model),

0 commit comments

Comments
 (0)