Skip to content

Commit 58fa4a8

Browse files
authored
Add a workaround for the removal of top-level TFMA attributes (#7637)
* Add a workaround for the removal of top-level TFMA attributes in TFMA 0.47.0
1 parent a3aa157 commit 58fa4a8

File tree

4 files changed

+46
-10
lines changed

4 files changed

+46
-10
lines changed

tfx/components/model_validator/executor.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,12 @@
2828
from tfx.utils import io_utils
2929
from tfx.utils import path_utils
3030

31+
try:
32+
# Try to access EvalResult from tfma directly
33+
_EvalResult = tfma.EvalResult
34+
except AttributeError:
35+
# If tfma doesn't have EvalResult, use the one from view_types
36+
from tensorflow_model_analysis.view.view_types import EvalResult as _EvalResult
3137

3238
class Executor(base_beam_executor.BaseBeamExecutor):
3339
"""DEPRECATED: Please use `Evaluator` instead.
@@ -51,13 +57,13 @@ class Executor(base_beam_executor.BaseBeamExecutor):
5157
"""
5258

5359
# TODO(jyzhao): customized threshold support.
54-
def _pass_threshold(self, eval_result: tfma.EvalResult) -> bool:
60+
def _pass_threshold(self, eval_result: _EvalResult) -> bool:
5561
"""Check threshold."""
5662
return True
5763

5864
# TODO(jyzhao): customized validation support.
59-
def _compare_eval_result(self, current_model_eval_result: tfma.EvalResult,
60-
blessed_model_eval_result: tfma.EvalResult) -> bool:
65+
def _compare_eval_result(self, current_model_eval_result: _EvalResult,
66+
blessed_model_eval_result: _EvalResult) -> bool:
6167
"""Compare accuracy of all metrics and return true if current is better or equal."""
6268
for current_metric, blessed_metric in zip(
6369
current_model_eval_result.slicing_metrics,

tfx/components/testdata/module_file/evaluator_module.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,24 @@
1919
from tfx_bsl.tfxio import tensor_adapter
2020

2121

22+
try:
23+
# Try to access EvalSharedModel from tfma directly
24+
_EvalSharedModel = tfma.EvalSharedModel
25+
except AttributeError:
26+
# If tfma doesn't have EvalSharedModel, use the one from api.types
27+
from tensorflow_model_analysis.api.types import EvalSharedModel as _EvalSharedModel
28+
29+
try:
30+
# Try to access MaybeMultipleEvalSharedModels from tfma directly
31+
_MaybeMultipleEvalSharedModels = tfma.MaybeMultipleEvalSharedModels
32+
except AttributeError:
33+
# If tfma doesn't have MaybeMultipleEvalSharedModels, use the one from api.types
34+
from tensorflow_model_analysis.api.types import MaybeMultipleEvalSharedModels as _MaybeMultipleEvalSharedModels
35+
36+
2237
def custom_eval_shared_model(eval_saved_model_path: str, model_name: str,
2338
eval_config: tfma.EvalConfig,
24-
**kwargs: Dict[str, Any]) -> tfma.EvalSharedModel:
39+
**kwargs: Dict[str, Any]) -> _EvalSharedModel:
2540
return tfma.default_eval_shared_model(
2641
eval_saved_model_path=eval_saved_model_path,
2742
model_name=model_name,
@@ -30,7 +45,7 @@ def custom_eval_shared_model(eval_saved_model_path: str, model_name: str,
3045

3146

3247
def custom_extractors(
33-
eval_shared_model: tfma.MaybeMultipleEvalSharedModels,
48+
eval_shared_model: _MaybeMultipleEvalSharedModels,
3449
eval_config: tfma.EvalConfig,
3550
tensor_adapter_config: tensor_adapter.TensorAdapterConfig,
3651
) -> List[tfma.extractors.Extractor]:

tfx/examples/penguin/experimental/sklearn_predict_extractor.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,9 +25,16 @@
2525

2626
_PREDICT_EXTRACTOR_STAGE_NAME = 'SklearnPredict'
2727

28+
try:
29+
# Try to access EvalSharedModel from tfma directly
30+
_EvalSharedModel = tfma.EvalSharedModel
31+
except AttributeError:
32+
# If tfma doesn't have EvalSharedModel, use the one from api.types
33+
from tensorflow_model_analysis.api.types import EvalSharedModel as _EvalSharedModel
34+
2835

2936
def _make_sklearn_predict_extractor(
30-
eval_shared_model: tfma.EvalSharedModel,) -> tfma.extractors.Extractor:
37+
eval_shared_model: _EvalSharedModel,) -> tfma.extractors.Extractor:
3138
"""Creates an extractor for performing predictions using a scikit-learn model.
3239
3340
The extractor's PTransform loads and runs the serving pickle against
@@ -54,7 +61,7 @@ def _make_sklearn_predict_extractor(
5461
class _TFMAPredictionDoFn(tfma.utils.DoFnWithModels):
5562
"""A DoFn that loads the models and predicts."""
5663

57-
def __init__(self, eval_shared_models: Dict[str, tfma.EvalSharedModel]):
64+
def __init__(self, eval_shared_models: Dict[str, _EvalSharedModel]):
5865
super().__init__({k: v.model_loader for k, v in eval_shared_models.items()})
5966

6067
def setup(self):
@@ -116,7 +123,7 @@ def process(self, elem: tfma.Extracts) -> Iterable[tfma.Extracts]:
116123
@beam.typehints.with_output_types(tfma.Extracts)
117124
def _ExtractPredictions( # pylint: disable=invalid-name
118125
extracts: beam.pvalue.PCollection,
119-
eval_shared_models: Dict[str, tfma.EvalSharedModel],
126+
eval_shared_models: Dict[str, _EvalSharedModel],
120127
) -> beam.pvalue.PCollection:
121128
"""A PTransform that adds predictions and possibly other tensors to extracts.
122129
@@ -139,7 +146,7 @@ def _custom_model_loader_fn(model_path: str):
139146
# TFX Evaluator will call the following functions.
140147
def custom_eval_shared_model(
141148
eval_saved_model_path, model_name, eval_config,
142-
**kwargs) -> tfma.EvalSharedModel:
149+
**kwargs) -> _EvalSharedModel:
143150
"""Returns a single custom EvalSharedModel."""
144151
model_path = os.path.join(eval_saved_model_path, 'model.pkl')
145152
return tfma.default_eval_shared_model(

tfx/experimental/pipeline_testing/executor_verifier_utils.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,14 @@
3333
from tensorflow_metadata.proto.v0 import anomalies_pb2
3434

3535

36+
try:
37+
# Try to access EvalResult from tfma directly
38+
_EvalResult = tfma.EvalResult
39+
except AttributeError:
40+
# If tfma doesn't have EvalResult, use the one from view_types
41+
from tensorflow_model_analysis.view.view_types import EvalResult as _EvalResult
42+
43+
3644
def compare_dirs(dir1: str, dir2: str):
3745
"""Recursively compares contents of the two directories.
3846
@@ -159,7 +167,7 @@ def verify_file_dir(output_uri: str,
159167

160168

161169
def _group_metric_by_slice(
162-
eval_result: tfma.EvalResult) -> Dict[str, Dict[str, float]]:
170+
eval_result: _EvalResult) -> Dict[str, Dict[str, float]]:
163171
"""Returns a dictionary holding metric values for every slice.
164172
165173
Args:

0 commit comments

Comments
 (0)