Skip to content

Commit 2d94da5

Browse files
authored
Update TFX to be compatible with Keras3 (#7621)
* Update trainer module to be compatiable with keras3 * Add xfail keras model test which is not compatible with Keras3
1 parent 271801e commit 2d94da5

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tfx/components/testdata/module_file/trainer_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _build_keras_model(
240240
output = tf.keras.layers.Dense(1, activation='sigmoid')(
241241
tf.keras.layers.concatenate([deep, wide])
242242
)
243-
output = tf.squeeze(output, -1)
243+
output = tf.keras.layers.Reshape((1,))(output)
244244

245245
model = tf.keras.Model(input_layers, output)
246246
model.compile(
@@ -365,4 +365,4 @@ def run_fn(fn_args: fn_args_utils.FnArgs):
365365
model, tf_transform_output
366366
),
367367
}
368-
model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
368+
tf.saved_model.save(model, fn_args.serving_model_dir, signatures=signatures)

tfx/experimental/templates/taxi/models/keras_model/model_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
# limitations under the License.
1414

1515
import tensorflow as tf
16+
import pytest
1617

1718
from tfx.experimental.templates.taxi.models.keras_model import model
1819

1920

21+
@pytest.mark.xfail(run=False, reason="_build_keras_model is not compatible with Keras3.")
2022
class ModelTest(tf.test.TestCase):
2123

2224
def testBuildKerasModel(self):

0 commit comments

Comments
 (0)