Build, Train, and Save Models Using Keras and tf.Module
Content Overview
- Keras layers
- The build step
- Keras models
- Saving Keras models
- Checkpointing Keras models
Note that up until this point, there is no mention of Keras. You can build your own high-level API on top of `tf.Module`, and people have.
In this section, you will examine how Keras uses `tf.Module`. A complete user guide to Keras models can be found in the Keras guide.
Keras layers and models have many additional features, including:
- Optional losses
- Support for metrics
- Built-in support for an optional `training` argument to differentiate between training and inference use (see the sketch after this list)
- Saving and restoring Python objects instead of just black-box functions
- `get_config` and `from_config` methods that allow you to accurately store configurations, enabling model cloning in Python
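To make the `training` argument concrete, here is a minimal sketch, not from the original guide, of a layer whose behavior differs between training and inference (the `NoisyDense` name and noise scale are illustrative):

```python
import tensorflow as tf

class NoisyDense(tf.keras.layers.Layer):
  """Illustrative layer: adds noise in training mode, none at inference."""

  def __init__(self, in_features, out_features, **kwargs):
    super().__init__(**kwargs)
    self.w = tf.Variable(
        tf.random.normal([in_features, out_features]), name='w')

  def call(self, x, training=False):
    y = tf.matmul(x, self.w)
    if training:
      # Noise is applied only when the caller (or Keras itself,
      # e.g. inside `fit`) passes training=True
      y = y + tf.random.normal(tf.shape(y), stddev=0.1)
    return y

noisy = NoisyDense(in_features=3, out_features=2)
x = tf.ones([1, 3])
print(noisy(x, training=True))   # stochastic output
print(noisy(x, training=False))  # deterministic output
```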
These features allow for far more complex models through subclassing, such as a custom GAN or a Variational AutoEncoder (VAE) model. Read about them in the full guide to custom layers and models.
Keras models also come with extra functionality that makes them easy to train, evaluate, load, save, and even train on multiple machines.
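As a taste of that functionality, here is a brief sketch of the built-in `compile`/`fit`/`evaluate` workflow; the architecture and random data are illustrative only:

```python
import tensorflow as tf

# Illustrative model and fake data, just to exercise the built-in loop
model = tf.keras.Sequential([
    tf.keras.layers.Dense(3, activation="relu"),
    tf.keras.layers.Dense(2),
])
model.compile(optimizer="adam", loss="mse")

x = tf.random.normal([64, 3])  # fake inputs
y = tf.random.normal([64, 2])  # fake targets
model.fit(x, y, epochs=2, verbose=0)    # built-in training loop
print(model.evaluate(x, y, verbose=0))  # built-in evaluation
```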
Keras layers
`tf.keras.layers.Layer` is the base class of all Keras layers, and it inherits from `tf.Module`.

You can convert a module into a Keras layer just by swapping out the parent and then changing `__call__` to `call`:
```python
import tensorflow as tf

class MyDense(tf.keras.layers.Layer):
  # Adding **kwargs to support base Keras layer arguments
  def __init__(self, in_features, out_features, **kwargs):
    super().__init__(**kwargs)
    # This will soon move to the build step; see below
    self.w = tf.Variable(
        tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')

  def call(self, x):
    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

simple_layer = MyDense(name="simple", in_features=3, out_features=3)
```
Keras layers have their own `__call__` that does some bookkeeping described in the next section and then calls `call()`. You should notice no change in functionality.
```python
simple_layer([[2.0, 2.0, 2.0]])
```
The `build` step
As noted, it’s convenient in many cases to wait to create variables until you are sure of the input shape.
Keras layers come with an extra lifecycle step that allows you more flexibility in how you define your layers. This is defined in the `build` function.

`build` is called exactly once, and it is called with the shape of the input. It's usually used to create variables (weights).
You can rewrite the `MyDense` layer above to be flexible to the size of its inputs:
```python
class FlexibleDense(tf.keras.layers.Layer):
  # Note the added `**kwargs`, as Keras supports many arguments
  def __init__(self, out_features, **kwargs):
    super().__init__(**kwargs)
    self.out_features = out_features

  def build(self, input_shape):  # Create the state of the layer (weights)
    self.w = tf.Variable(
        tf.random.normal([input_shape[-1], self.out_features]), name='w')
    self.b = tf.Variable(tf.zeros([self.out_features]), name='b')

  def call(self, inputs):  # Defines the computation from inputs to outputs
    return tf.matmul(inputs, self.w) + self.b

# Create the instance of the layer
flexible_dense = FlexibleDense(out_features=3)
```
At this point, because the layer has not been built yet, there are no variables:

```python
flexible_dense.variables
```
```
[]
```
Calling the function allocates appropriately-sized variables:
```python
# Call it, with predictably random results
print("Model results:", flexible_dense(tf.constant([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])))
```
```
Model results: tf.Tensor(
[[-2.531786  -5.5550847 -0.4248762]
 [-3.7976792 -8.332626  -0.6373143]], shape=(2, 3), dtype=float32)
```
```python
flexible_dense.variables
```
```
[<tf.Variable 'flexible_dense/w:0' shape=(3, 3) dtype=float32, numpy=...>,
 <tf.Variable 'flexible_dense/b:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]
```
Since `build` is only called once, inputs will be rejected if the input shape is not compatible with the layer's variables:
```python
try:
  print("Model results:", flexible_dense(tf.constant([[2.0, 2.0, 2.0, 2.0]])))
except tf.errors.InvalidArgumentError as e:
  print("Failed:", e)
```
```
Failed: Exception encountered when calling layer 'flexible_dense' (type FlexibleDense).

{{function_node __wrapped__MatMul_device_/job:localhost/replica:0/task:0/device:CPU:0}} Matrix size-incompatible: In[0]: [1,4], In[1]: [3,3] [Op:MatMul] name: 

Call arguments received by layer 'flexible_dense' (type FlexibleDense):
  • inputs=tf.Tensor(shape=(1, 4), dtype=float32)
```
Keras models
You can define your model as nested Keras layers.
However, Keras also provides a full-featured model class called `tf.keras.Model`. It inherits from `tf.keras.layers.Layer`, so a Keras model can be used and nested in the same way as Keras layers, and it keeps the extra training, evaluation, and saving functionality described above.
You can define the `SequentialModule` from above with nearly identical code, again converting `__call__` to `call()` and changing the parent:
```python
# Registering the class lets Keras serialize and restore it by name later
@tf.keras.saving.register_keras_serializable()
class MySequentialModel(tf.keras.Model):
  def __init__(self, name=None, **kwargs):
    super().__init__(name=name, **kwargs)
    self.dense_1 = FlexibleDense(out_features=3)
    self.dense_2 = FlexibleDense(out_features=2)

  def call(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

# You have made a Keras model!
my_sequential_model = MySequentialModel(name="the_model")

# Call it on a tensor, with random results
print("Model results:", my_sequential_model(tf.constant([[2.0, 2.0, 2.0]])))
```
```
Model results: tf.Tensor([[ 0.26034355 16.431221  ]], shape=(1, 2), dtype=float32)
```
All the same features are available, including tracking variables and submodules.
Note: A raw `tf.Module` nested inside a Keras layer or model will not get its variables collected for training or saving. Instead, nest Keras layers inside of Keras layers.
```python
my_sequential_model.variables
```
```
[<tf.Variable 'flexible_dense_1/w:0' shape=(3, 3) dtype=float32, numpy=...>,
 <tf.Variable 'flexible_dense_1/b:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'flexible_dense_2/w:0' shape=(3, 2) dtype=float32, numpy=...>,
 <tf.Variable 'flexible_dense_2/b:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]
```
```python
my_sequential_model.submodules
```
```
(<__main__.FlexibleDense at 0x7f790c7e0e80>,
 <__main__.FlexibleDense at 0x7f790c7e6940>)
```
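Here is a small sketch, with illustrative names, of the pitfall described in the note above:

```python
class RawModule(tf.Module):
  def __init__(self):
    super().__init__()
    self.v = tf.Variable(1.0, name='raw_v')

class WrapperLayer(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.inner = RawModule()  # a raw tf.Module, not a Keras layer

wrapper = WrapperLayer()
# Per the note above, the module's variable is not collected by Keras,
# so this is expected to print an empty list:
print(wrapper.trainable_variables)
```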
Overriding `tf.keras.Model` is a very Pythonic approach to building TensorFlow models. If you are migrating models from other frameworks, this can be very straightforward.
If you are constructing models that are simple assemblages of existing layers and inputs, you can save time and space by using the functional API, which comes with additional features around model reconstruction and architecture.
Here is the same model with the functional API:
```python
inputs = tf.keras.Input(shape=[3,])

x = FlexibleDense(3)(inputs)
x = FlexibleDense(2)(x)

my_functional_model = tf.keras.Model(inputs=inputs, outputs=x)

my_functional_model.summary()
```
```
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 3)]               0         
                                                                 
 flexible_dense_3 (Flexible  (None, 3)                 12        
 Dense)                                                          
                                                                 
 flexible_dense_4 (Flexible  (None, 2)                 8         
 Dense)                                                          
                                                                 
=================================================================
Total params: 20 (80.00 Byte)
Trainable params: 20 (80.00 Byte)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
```
```python
my_functional_model(tf.constant([[2.0, 2.0, 2.0]]))
```
The major difference here is that the input shape is specified up front as part of the functional construction process. The `input_shape` argument in this case does not have to be completely specified; you can leave some dimensions as `None` (see the sketch after the note below).
Note: You do not need to specify `input_shape` or an `InputLayer` in a subclassed model; these arguments and layers will be ignored.
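For example, here is a sketch of a partially specified input shape; the extra sequence dimension is an illustrative assumption, not part of the original example:

```python
# Leave the middle (e.g. sequence-length) dimension as None so the
# model accepts variable-length inputs
inputs = tf.keras.Input(shape=[None, 3])
outputs = FlexibleDense(2)(inputs)
variable_length_model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Batches with different sequence lengths are both accepted:
print(variable_length_model(tf.ones([1, 5, 3])).shape)  # (1, 5, 2)
print(variable_length_model(tf.ones([1, 9, 3])).shape)  # (1, 9, 2)
```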
Saving Keras models
Keras models have their own specialized zip archive saving format, marked by the `.keras` extension. When calling `tf.keras.Model.save`, add a `.keras` extension to the filename. For example:
```python
my_sequential_model.save("exname_of_file.keras")
```
Just as easily, they can be loaded back in:
```python
reconstructed_model = tf.keras.models.load_model("exname_of_file.keras")
```
Keras zip archives (`.keras` files) also save metric, loss, and optimizer states.
This reconstructed model can be used and will produce the same result when called on the same data:
```python
reconstructed_model(tf.constant([[2.0, 2.0, 2.0]]))
```
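As a quick sanity check, a sketch rather than part of the original guide, you can compare the two models' outputs directly:

```python
import numpy as np

x = tf.constant([[2.0, 2.0, 2.0]])
# Raises if the reconstruction does not reproduce the original outputs
np.testing.assert_allclose(
    my_sequential_model(x).numpy(),
    reconstructed_model(x).numpy())
```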
Checkpointing Keras models
Keras models can also be checkpointed, and that will look the same as `tf.Module`.
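For reference, here is a minimal sketch of that checkpointing workflow; the checkpoint path is illustrative:

```python
# Write a checkpoint of the model's variables, just as with a tf.Module
checkpoint = tf.train.Checkpoint(model=my_sequential_model)
checkpoint.write("my_checkpoint")

# Restore the weights into a fresh instance of the same architecture
new_model = MySequentialModel()
new_checkpoint = tf.train.Checkpoint(model=new_model)
new_checkpoint.restore("my_checkpoint")
```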
There is more to know about saving and serialization of Keras models, including providing configuration methods for custom layers for feature support. Check out the guide to saving and serialization.
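As a preview, here is a hedged sketch of what those configuration methods can look like for `FlexibleDense`; the subclass name is illustrative, and `tf.keras.saving.register_keras_serializable` assumes a recent TensorFlow release:

```python
@tf.keras.saving.register_keras_serializable()  # assumes a recent TF release
class ConfigurableDense(FlexibleDense):
  def get_config(self):
    # Record the constructor argument alongside the base layer config
    config = super().get_config()
    config.update({"out_features": self.out_features})
    return config

layer = ConfigurableDense(out_features=3)
# from_config rebuilds an equivalent (freshly initialized) layer
clone = ConfigurableDense.from_config(layer.get_config())
```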
What’s next
If you want to know more details about Keras, you can follow the existing Keras guides here.
Another example of a high-level API built on `tf.Module` is Sonnet from DeepMind, which is covered on their site.