Skip to content

Commit 7615e5a

Browse files
authored
Merge pull request #6952 from janasangeetha/sangeethajana-patch2
Fix Ruff B008 errors
2 parents 8393d14 + 2f6a67d commit 7615e5a

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

tfx/dsl/component/experimental/decorators_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from tfx.types.system_executions import SystemExecution
4343

4444
_TestBeamPipelineArgs = ['--my_testing_beam_pipeline_args=foo']
45+
_TestEmptyBeamPipeline = beam.Pipeline()
4546

4647

4748
class _InputArtifact(types.Artifact):
@@ -140,7 +141,7 @@ def verify_beam_pipeline_arg(a: int) -> OutputDict(b=float): # pytype: disable=
140141

141142
def verify_beam_pipeline_arg_non_none_default_value(
142143
a: int,
143-
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
144+
beam_pipeline: BeamComponentParameter[beam.Pipeline] = _TestEmptyBeamPipeline,
144145
) -> OutputDict(b=float): # pytype: disable=invalid-annotation,wrong-arg-types
145146
del beam_pipeline
146147
return {'b': float(a)}

tfx/dsl/component/experimental/decorators_typeddict_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from tfx.types.system_executions import SystemExecution
4141

4242
_TestBeamPipelineArgs = ['--my_testing_beam_pipeline_args=foo']
43+
_TestEmptyBeamPipeline = beam.Pipeline()
4344

4445

4546
class _InputArtifact(types.Artifact):
@@ -140,7 +141,7 @@ def verify_beam_pipeline_arg(a: int) -> TypedDict('Output6', dict(b=float)): #
140141

141142
def verify_beam_pipeline_arg_non_none_default_value(
142143
a: int,
143-
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
144+
beam_pipeline: BeamComponentParameter[beam.Pipeline] = _TestEmptyBeamPipeline,
144145
) -> TypedDict('Output7', dict(b=float)): # pytype: disable=wrong-arg-types
145146
del beam_pipeline
146147
return {'b': float(a)}

tfx/examples/bert/utils/bert_models.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,15 @@ def build_bert_classifier(bert_layer: tf.keras.layers.Layer,
5959

6060
def compile_bert_classifier(
6161
model: tf.keras.Model,
62-
loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy(
63-
from_logits=True),
62+
loss: tf.keras.losses.Loss | None = None,
6463
learning_rate: float = 2e-5,
6564
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None):
6665
"""Compile the BERT classifier using suggested parameters.
6766
6867
Args:
6968
model: A keras model. Most likely the output of build_bert_classifier.
70-
loss: tf.keras.losses. The suggested loss function expects integer labels
71-
(e.g. 0, 1, 2). If the labels are one-hot encoded, consider using
69+
loss: Default None will use tf.keras.losses. The suggested loss function expects
70+
integer labels (e.g. 0, 1, 2). If the labels are one-hot encoded, consider using
7271
tf.keras.lossesCategoricalCrossEntropy with from_logits set to true.
7372
learning_rate: Suggested learning rate to be used in
7473
tf.keras.optimizer.Adam. The three suggested learning_rates for
@@ -79,6 +78,8 @@ def compile_bert_classifier(
7978
Returns:
8079
None.
8180
"""
81+
if loss is None:
82+
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
8283
if metrics is None:
8384
metrics = ["sparse_categorical_accuracy"]
8485

0 commit comments

Comments
 (0)