Skip to content

Commit 1ccbc83

Browse files
committed
Fix Ruff B008 errors
1 parent 50a8599 commit 1ccbc83

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

tfx/dsl/component/experimental/decorators_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,10 @@ def verify_beam_pipeline_arg(a: int) -> OutputDict(b=float): # pytype: disable=
140140

141141
def verify_beam_pipeline_arg_non_none_default_value(
142142
a: int,
143-
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
143+
beam_pipeline: BeamComponentParameter[beam.Pipeline] = 0,
144144
) -> OutputDict(b=float): # pytype: disable=invalid-annotation,wrong-arg-types
145+
if beam_pipeline == 0:
146+
beam_pipeline = beam.Pipeline()
145147
del beam_pipeline
146148
return {'b': float(a)}
147149

tfx/dsl/component/experimental/decorators_typeddict_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,10 @@ def verify_beam_pipeline_arg(a: int) -> TypedDict('Output6', dict(b=float)): #
140140

141141
def verify_beam_pipeline_arg_non_none_default_value(
142142
a: int,
143-
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
143+
beam_pipeline: BeamComponentParameter[beam.Pipeline] = 0,
144144
) -> TypedDict('Output7', dict(b=float)): # pytype: disable=wrong-arg-types
145+
if beam_pipeline == 0:
146+
beam_pipeline = beam.Pipeline()
145147
del beam_pipeline
146148
return {'b': float(a)}
147149

@@ -807,4 +809,4 @@ def testListOfArtifacts(self):
807809
],
808810
)
809811

810-
beam_dag_runner.BeamDagRunner().run(test_pipeline)
812+
beam_dag_runner.BeamDagRunner().run(test_pipeline)

tfx/examples/bert/utils/bert_models.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,15 @@ def build_bert_classifier(bert_layer: tf.keras.layers.Layer,
5959

6060
def compile_bert_classifier(
6161
model: tf.keras.Model,
62-
loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy(
63-
from_logits=True),
62+
loss: tf.keras.losses.Loss | None = None,
6463
learning_rate: float = 2e-5,
6564
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None):
6665
"""Compile the BERT classifier using suggested parameters.
6766
6867
Args:
6968
model: A keras model. Most likely the output of build_bert_classifier.
70-
loss: tf.keras.losses. The suggested loss function expects integer labels
71-
(e.g. 0, 1, 2). If the labels are one-hot encoded, consider using
69+
loss: Default None will use tf.keras.losses. The suggested loss function expects
70+
integer labels (e.g. 0, 1, 2). If the labels are one-hot encoded, consider using
7271
tf.keras.lossesCategoricalCrossEntropy with from_logits set to true.
7372
learning_rate: Suggested learning rate to be used in
7473
tf.keras.optimizer.Adam. The three suggested learning_rates for
@@ -79,6 +78,8 @@ def compile_bert_classifier(
7978
Returns:
8079
None.
8180
"""
81+
if loss is None:
82+
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
8283
if metrics is None:
8384
metrics = ["sparse_categorical_accuracy"]
8485

0 commit comments

Comments
 (0)