Skip to content

Commit f42957d

Browse files
authored
Update template models not to use deprecated Keras apis (#7723)
1 parent 34c7147 commit f42957d

File tree

2 files changed

+59
-84
lines changed

2 files changed

+59
-84
lines changed

tfx/experimental/templates/taxi/models/keras_model/model.py

Lines changed: 57 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -106,98 +106,73 @@ def _build_keras_model(hidden_units, learning_rate):
106106
Returns:
107107
A keras Model.
108108
"""
109-
real_valued_columns = [
110-
tf.feature_column.numeric_column(key, shape=())
111-
for key in features.transformed_names(features.DENSE_FLOAT_FEATURE_KEYS)
112-
]
113-
categorical_columns = [
114-
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
115-
key,
116-
num_buckets=features.VOCAB_SIZE + features.OOV_SIZE,
117-
default_value=0)
118-
for key in features.transformed_names(features.VOCAB_FEATURE_KEYS)
119-
]
120-
categorical_columns += [
121-
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
122-
key,
123-
num_buckets=num_buckets,
124-
default_value=0) for key, num_buckets in zip(
125-
features.transformed_names(features.BUCKET_FEATURE_KEYS),
126-
features.BUCKET_FEATURE_BUCKET_COUNT)
127-
]
128-
categorical_columns += [
129-
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
130-
key,
131-
num_buckets=num_buckets,
132-
default_value=0) for key, num_buckets in zip(
133-
features.transformed_names(features.CATEGORICAL_FEATURE_KEYS),
134-
features.CATEGORICAL_FEATURE_MAX_VALUES)
135-
]
136-
indicator_column = [
137-
tf.feature_column.indicator_column(categorical_column)
138-
for categorical_column in categorical_columns
139-
]
140-
141-
model = _wide_and_deep_classifier(
142-
# TODO(b/140320729) Replace with premade wide_and_deep keras model
143-
wide_columns=indicator_column,
144-
deep_columns=real_valued_columns,
145-
dnn_hidden_units=hidden_units,
146-
learning_rate=learning_rate)
147-
return model
148-
149-
150-
def _wide_and_deep_classifier(wide_columns, deep_columns, dnn_hidden_units,
151-
learning_rate):
152-
"""Build a simple keras wide and deep model.
153-
154-
Args:
155-
wide_columns: Feature columns wrapped in indicator_column for wide (linear)
156-
part of the model.
157-
deep_columns: Feature columns for deep part of the model.
158-
dnn_hidden_units: [int], the layer sizes of the hidden DNN.
159-
learning_rate: [float], learning rate of the Adam optimizer.
160-
161-
Returns:
162-
A Wide and Deep Keras model
163-
"""
164-
# Keras needs the feature definitions at compile time.
165-
# TODO(b/139081439): Automate generation of input layers from FeatureColumn.
166-
input_layers = {
167-
colname: tf.keras.layers.Input(name=colname, shape=(), dtype=tf.float32)
168-
for colname in features.transformed_names(
169-
features.DENSE_FLOAT_FEATURE_KEYS)
109+
deep_input = {
110+
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype=tf.float32)
111+
for colname in features.transformed_names(features.DENSE_FLOAT_FEATURE_KEYS)
170112
}
171-
input_layers.update({
172-
colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
113+
wide_vocab_input = {
114+
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
173115
for colname in features.transformed_names(features.VOCAB_FEATURE_KEYS)
174-
})
175-
input_layers.update({
176-
colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
116+
}
117+
wide_bucket_input = {
118+
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
177119
for colname in features.transformed_names(features.BUCKET_FEATURE_KEYS)
178-
})
179-
input_layers.update({
180-
colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32') for
181-
colname in features.transformed_names(features.CATEGORICAL_FEATURE_KEYS)
182-
})
183-
184-
# TODO(b/161952382): Replace with Keras premade models and
185-
# Keras preprocessing layers.
186-
deep = tf.keras.layers.DenseFeatures(deep_columns)(input_layers)
187-
for numnodes in dnn_hidden_units:
120+
}
121+
wide_categorical_input = {
122+
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
123+
for colname in features.transformed_names(features.CATEGORICAL_FEATURE_KEYS)
124+
}
125+
input_layers = {
126+
**deep_input,
127+
**wide_vocab_input,
128+
**wide_bucket_input,
129+
**wide_categorical_input,
130+
}
131+
132+
deep = tf.keras.layers.concatenate(
133+
[tf.keras.layers.Normalization()(layer) for layer in deep_input.values()]
134+
)
135+
for numnodes in (hidden_units or [100, 70, 50, 25]):
188136
deep = tf.keras.layers.Dense(numnodes)(deep)
189-
wide = tf.keras.layers.DenseFeatures(wide_columns)(input_layers)
190137

191-
output = tf.keras.layers.Dense(
192-
1, activation='sigmoid')(
193-
tf.keras.layers.concatenate([deep, wide]))
194-
output = tf.squeeze(output, -1)
138+
wide_layers = []
139+
for key in features.transformed_names(features.VOCAB_FEATURE_KEYS):
140+
wide_layers.append(
141+
tf.keras.layers.CategoryEncoding(num_tokens=features.VOCAB_SIZE + features.OOV_SIZE)(
142+
input_layers[key]
143+
)
144+
)
145+
for key, num_tokens in zip(
146+
features.transformed_names(features.BUCKET_FEATURE_KEYS),
147+
features.BUCKET_FEATURE_BUCKET_COUNT,
148+
):
149+
wide_layers.append(
150+
tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)(
151+
input_layers[key]
152+
)
153+
)
154+
for key, num_tokens in zip(
155+
features.transformed_names(features.CATEGORICAL_FEATURE_KEYS),
156+
features.CATEGORICAL_FEATURE_MAX_VALUES,
157+
):
158+
wide_layers.append(
159+
tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)(
160+
input_layers[key]
161+
)
162+
)
163+
wide = tf.keras.layers.concatenate(wide_layers)
164+
165+
output = tf.keras.layers.Dense(1, activation='sigmoid')(
166+
tf.keras.layers.concatenate([deep, wide])
167+
)
168+
output = tf.keras.layers.Reshape((1,))(output)
195169

196170
model = tf.keras.Model(input_layers, output)
197171
model.compile(
198172
loss='binary_crossentropy',
199173
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
200-
metrics=[tf.keras.metrics.BinaryAccuracy()])
174+
metrics=[tf.keras.metrics.BinaryAccuracy()],
175+
)
201176
model.summary(print_fn=logging.info)
202177
return model
203178

tfx/experimental/templates/taxi/models/keras_model/model_test.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ class ModelTest(tf.test.TestCase):
2222
def testBuildKerasModel(self):
2323
built_model = model._build_keras_model(
2424
hidden_units=[1, 1], learning_rate=0.1) # pylint: disable=protected-access
25-
self.assertEqual(len(built_model.layers), 10)
25+
self.assertEqual(len(built_model.layers), 13)
2626

2727
built_model = model._build_keras_model(hidden_units=[1], learning_rate=0.1) # pylint: disable=protected-access
28-
self.assertEqual(len(built_model.layers), 9)
28+
self.assertEqual(len(built_model.layers), 12)

0 commit comments

Comments
 (0)