
Build, Train, and Save Models Using Keras and tf.Module

Content Overview

  • Keras layers
  • The build step
  • Keras models
  • Saving Keras models
  • Checkpointing Keras models

Note that up until this point, there has been no mention of Keras. You can build your own high-level API on top of tf.Module, and people have.

In this section, you will examine how Keras uses tf.Module. A complete user guide to Keras models can be found in the Keras guide.

Keras layers and models have many extra features, including:

  • Optional losses
  • Support for metrics
  • Built-in support for an optional training argument to differentiate between training and inference use
  • Saving and restoring Python objects instead of just black-box functions
  • get_config and from_config methods that accurately store a layer's configuration, so that models can be cloned in Python

These features allow for far more complex models through subclassing, such as a custom GAN or a Variational AutoEncoder (VAE) model. Read about them in the full guide to custom layers and models.

Keras models also come with extra functionality that makes them easy to train, evaluate, load, save, and even train on multiple machines.
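As a rough illustration of two items from the feature list, the hypothetical NoisyDense layer below accepts the optional training argument (it adds Gaussian noise only in training mode) and implements get_config so that from_config can clone it. The layer, its noise behavior, and the stddev parameter are assumptions made for this sketch, and it creates its weights with the Layer.add_weight helper rather than raw tf.Variable objects.

```python
import numpy as np
import tensorflow as tf


class NoisyDense(tf.keras.layers.Layer):
  """Hypothetical dense layer that adds Gaussian noise only when training."""

  def __init__(self, units, stddev=0.1, **kwargs):
    super().__init__(**kwargs)
    self.units = units
    self.stddev = stddev

  def build(self, input_shape):
    # add_weight creates a variable that the layer tracks automatically
    self.w = self.add_weight(
        shape=(input_shape[-1], self.units), initializer="random_normal", name="w")
    self.b = self.add_weight(shape=(self.units,), initializer="zeros", name="b")

  def call(self, inputs, training=False):
    y = tf.matmul(inputs, self.w) + self.b
    if training:
      # Noise is injected only in training mode
      y = y + tf.random.normal(tf.shape(y), stddev=self.stddev)
    return y

  def get_config(self):
    # Start from the base config (name, trainable, dtype) and add our arguments
    config = super().get_config()
    config.update({"units": self.units, "stddev": self.stddev})
    return config


layer = NoisyDense(units=2, stddev=0.5, name="noisy")
x = tf.ones((1, 3))
inference_1 = layer(x, training=False).numpy()
inference_2 = layer(x, training=False).numpy()
training_out = layer(x, training=True).numpy()

# from_config builds an equivalent, not-yet-built layer from the stored arguments
clone = NoisyDense.from_config(layer.get_config())
print(clone.units, clone.stddev, clone.name)
```

Inference calls are deterministic, while the training call differs. The clone receives the same constructor arguments but fresh weights once it is built.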

Keras layers

tf.keras.layers.Layer is the base class of all Keras layers, and it inherits from tf.Module.

You can convert a module into a Keras layer just by swapping out the parent and then changing __call__ to call:

import tensorflow as tf

class MyDense(tf.keras.layers.Layer):
  # Adding **kwargs to support base Keras layer arguments
  def __init__(self, in_features, out_features, **kwargs):
    super().__init__(**kwargs)

    # This will soon move to the build step; see below
    self.w = tf.Variable(
      tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')

  def call(self, x):
    # Same dense + ReLU computation as the module's former __call__
    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

simple_layer = MyDense(name="simple", in_features=3, out_features=3)

Keras layers have their own __call__ that does some bookkeeping described in the next section and then calls call(). You should notice no change in functionality.

simple_layer([[2.0, 2.0, 2.0]])
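To see that Keras's __call__ only wraps call() with bookkeeping, the sketch below runs the same input through both paths and compares the results. It is a minimal variant of MyDense that assumes weights are created with add_weight instead of tf.Variable, and it passes a tensor to both calls, since call() invoked directly skips any input handling that __call__ performs.

```python
import numpy as np
import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
  def __init__(self, in_features, out_features, **kwargs):
    super().__init__(**kwargs)
    # Weights are created up front, as in the MyDense example
    self.w = self.add_weight(
        shape=(in_features, out_features), initializer="random_normal", name="w")
    self.b = self.add_weight(shape=(out_features,), initializer="zeros", name="b")

  def call(self, x):
    return tf.nn.relu(tf.matmul(x, self.w) + self.b)


layer = MyDense(in_features=3, out_features=3, name="simple")
x = tf.constant([[2.0, 2.0, 2.0]])

# Public entry point: Keras __call__ does its bookkeeping, then runs call()
via_dunder_call = layer(x).numpy()
# Direct call() skips the bookkeeping but performs the same computation
via_call = layer.call(x).numpy()

print(np.allclose(via_dunder_call, via_call))
```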

The build step

As noted, it’s convenient in many cases to wait to create variables until you are sure of the input shape.

Keras layers come with an extra lifecycle step that allows you more flexibility in how you define your layers. This is defined in the build function.

build is called exactly once, and it is called with the shape of the input. It’s usually used to create variables (weights).
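The claim that build runs exactly once can be checked with a small counter. The hypothetical CountingDense layer below increments an attribute each time Keras invokes build and records the last dimension of the shape it received. After two calls with different batch sizes, it has still been built only once.

```python
import tensorflow as tf


class CountingDense(tf.keras.layers.Layer):
  """Hypothetical layer that records how often build() runs."""

  def __init__(self, out_features, **kwargs):
    super().__init__(**kwargs)
    self.out_features = out_features
    self.build_calls = 0
    self.in_features = None

  def build(self, input_shape):
    # Keras passes the input's shape; only the last axis matters here
    self.build_calls += 1
    self.in_features = int(input_shape[-1])
    self.w = self.add_weight(
        shape=(self.in_features, self.out_features), initializer="random_normal", name="w")

  def call(self, inputs):
    return tf.matmul(inputs, self.w)


layer = CountingDense(out_features=2)
layer(tf.ones((4, 5)))  # first call: Keras runs build() with shape (4, 5)
layer(tf.ones((7, 5)))  # later calls reuse the existing weights
print(layer.build_calls, layer.in_features)
```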

You can rewrite the MyDense layer above to be flexible to the size of its inputs:

class FlexibleDense(tf.keras.layers.Layer):
  # Note the added `**kwargs`, as Keras supports many arguments
  def __init__(self, out_features, **kwargs):
    super().__init__(**kwargs)
    self.out_features = out_features

  def build(self, input_shape):  # Create the state of the layer (weights)
    self.w = tf.Variable(
      tf.random.normal([input_shape[-1], self.out_features]), name='w')
    self.b = tf.Variable(tf.zeros([self.out_features]), name='b')

  def call(self, inputs):  # Defines the computation from inputs to outputs
    return tf.matmul(inputs, self.w) + self.b

# Create the instance of the layer
flexible_dense = FlexibleDense(out_features=3)

At this point, the layer has not been built, so there are no variables:

flexible_dense.variables
[]

Calling the layer allocates appropriately sized variables:

# Call it, with predictably random results
print("Model results:", flexible_dense(tf.constant([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])))
Model results: tf.Tensor(
[[-2.531786  -5.5550847 -0.4248762]
 [-3.7976792 -8.332626  -0.6373143]], shape=(2, 3), dtype=float32)
flexible_dense.variables

The layer now holds two variables: w, with shape (3, 3), and b, with shape (3,).

Since build is only called once, inputs will be rejected if the input shape is not compatible with the layer’s variables:

try:
  print("Model results:", flexible_dense(tf.constant([[2.0, 2.0, 2.0, 2.0]])))
except tf.errors.InvalidArgumentError as e:
  print("Failed:", e)
Failed: Exception encountered when calling layer 'flexible_dense' (type FlexibleDense).

{{function_node __wrapped__MatMul_device_/job:localhost/replica:0/task:0/device:CPU:0}} Matrix size-incompatible: In[0]: [1,4], In[1]: [3,3] [Op:MatMul] name: 

Call arguments received by layer 'flexible_dense' (type FlexibleDense):
  • inputs=tf.Tensor(shape=(1, 4), dtype=float32)

Keras models

You can define your model as nested Keras layers.

However, Keras also provides a full-featured model class called tf.keras.Model. It inherits from tf.keras.layers.Layer, so a Keras model can be used and nested in the same way as Keras layers. Keras models come with extra functionality that makes them easy to train, evaluate, load, save, and even train on multiple machines.

You can define the SequentialModule from above with nearly identical code, again converting __call__ to call() and changing the parent:

@tf.keras.saving.register_keras_serializable()
class MySequentialModel(tf.keras.Model):
  def __init__(self, name=None, **kwargs):
    # Forward `name` to the base class so that name="the_model" takes effect
    super().__init__(name=name, **kwargs)

    self.dense_1 = FlexibleDense(out_features=3)
    self.dense_2 = FlexibleDense(out_features=2)

  def call(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

# You have made a Keras model!
my_sequential_model = MySequentialModel(name="the_model")

# Call it on a tensor, with random results
print("Model results:", my_sequential_model(tf.constant([[2.0, 2.0, 2.0]])))
Model results: tf.Tensor([[ 0.26034355 16.431221  ]], shape=(1, 2), dtype=float32)

All the same features are available, including tracking variables and submodules.

Note: A raw tf.Module nested inside a Keras layer or model will not get its variables collected for training or saving. Instead, nest Keras layers inside of Keras layers.

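my_sequential_model.variables

This returns four variables: w and b of dense_1, with shapes (3, 3) and (3,), and w and b of dense_2, with shapes (3, 2) and (2,).
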
my_sequential_model.submodules
(<__main__.FlexibleDense at 0x7f790c7e0e80>,
 <__main__.FlexibleDense at 0x7f790c7e6940>)
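To check this tracking yourself, the sketch below nests two standard tf.keras.layers.Dense layers (used instead of FlexibleDense so the example is self-contained) in a hypothetical TwoLayerModel and inspects what Keras collected through model.layers and model.trainable_variables.

```python
import tensorflow as tf


class TwoLayerModel(tf.keras.Model):
  """Hypothetical nested model built from two standard Dense layers."""

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.dense_1 = tf.keras.layers.Dense(3)
    self.dense_2 = tf.keras.layers.Dense(2)

  def call(self, x):
    return self.dense_2(self.dense_1(x))


model = TwoLayerModel()
model(tf.ones((1, 3)))  # the first call creates the variables

# Sublayers assigned as attributes are tracked automatically
print([layer.name for layer in model.layers])
shapes = sorted(tuple(v.shape) for v in model.trainable_variables)
print(shapes)
```

The kernel and bias of each Dense layer show up in the model's trainable variables without any explicit registration.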

Subclassing tf.keras.Model is a very Pythonic approach to building TensorFlow models. If you are migrating models from other frameworks, this can be very straightforward.

If you are constructing models that are simple assemblages of existing layers and inputs, you can save time and space by using the functional API, which comes with additional features around model reconstruction and architecture.

Here is the same model with the functional API:

inputs = tf.keras.Input(shape=[3,])

x = FlexibleDense(3)(inputs)
x = FlexibleDense(2)(x)

my_functional_model = tf.keras.Model(inputs=inputs, outputs=x)

my_functional_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 3)]               0         
                                                                 
 flexible_dense_3 (Flexible  (None, 3)                 12        
 Dense)                                                          
                                                                 
 flexible_dense_4 (Flexible  (None, 2)                 8         
 Dense)                                                          
                                                                 
=================================================================
Total params: 20 (80.00 Byte)
Trainable params: 20 (80.00 Byte)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
my_functional_model(tf.constant([[2.0, 2.0, 2.0]]))

The major difference here is that the input shape is specified up front as part of the functional construction process. The shape passed to tf.keras.Input does not have to be completely specified; you can leave some dimensions as None.

Note: You do not need to specify input_shape or an InputLayer in a subclassed model; these arguments and layers will be ignored.
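A minimal sketch of a partially specified input shape follows. It uses the standard tf.keras.layers.Dense layer rather than FlexibleDense for brevity, and the sequence-length axis is an assumption chosen for illustration. Because that axis is left as None, one functional model accepts inputs of different lengths.

```python
import tensorflow as tf

# Leave the sequence length unspecified (None); only the feature axis is fixed
inputs = tf.keras.Input(shape=(None, 3))
outputs = tf.keras.layers.Dense(2)(inputs)  # Dense acts on the last axis
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# The same model accepts different batch sizes and sequence lengths
short_seq = model(tf.ones((1, 4, 3)))
long_seq = model(tf.ones((2, 7, 3)))
print(short_seq.shape, long_seq.shape)
```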

Saving Keras models

Keras models have their own specialized zip archive saving format, marked by the .keras extension. When calling tf.keras.Model.save, add a .keras extension to the filename. For example:

my_sequential_model.save("exname_of_file.keras")

Just as easily, they can be loaded back in:

reconstructed_model = tf.keras.models.load_model("exname_of_file.keras")

Keras zip archives — .keras files — also save metric, loss, and optimizer states.

This reconstructed model can be used and will produce the same result when called on the same data:

reconstructed_model(tf.constant([[2.0, 2.0, 2.0]]))
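As an end-to-end sketch of the save and load round trip for a model containing a custom layer, the hypothetical ScaledDense layer below is registered with tf.keras.saving.register_keras_serializable and implements get_config, which the .keras format relies on to rebuild custom layers. The package name "docs_sketch" and the temporary file path are arbitrary choices for this example.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# Registering the class lets load_model find it by name during deserialization
@tf.keras.saving.register_keras_serializable(package="docs_sketch")
class ScaledDense(tf.keras.layers.Layer):
  """Hypothetical custom layer: a dense layer whose output is multiplied by `scale`."""

  def __init__(self, units, scale=1.0, **kwargs):
    super().__init__(**kwargs)
    self.units = units
    self.scale = scale

  def build(self, input_shape):
    self.w = self.add_weight(
        shape=(input_shape[-1], self.units), initializer="random_normal", name="w")

  def call(self, inputs):
    return self.scale * tf.matmul(inputs, self.w)

  def get_config(self):
    # Constructor arguments must be in the config so the layer can be rebuilt
    config = super().get_config()
    config.update({"units": self.units, "scale": self.scale})
    return config


inputs = tf.keras.Input(shape=(3,))
model = tf.keras.Model(inputs=inputs, outputs=ScaledDense(2, scale=0.5)(inputs))

x = tf.constant([[2.0, 2.0, 2.0]])
path = os.path.join(tempfile.mkdtemp(), "scaled_model.keras")
model.save(path)  # the .keras extension selects the zip-archive format
restored = tf.keras.models.load_model(path)

print(np.allclose(model(x).numpy(), restored(x).numpy()))
```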

Checkpointing Keras models

Keras models can also be checkpointed, and that looks the same as it does for tf.Module: pass the model to tf.train.Checkpoint.
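For example, the sketch below writes a checkpoint of a hypothetical TinyModel with tf.train.Checkpoint.write, overwrites its weights with zeros, and restores them with Checkpoint.read. The model, the zeroing step, and the temporary path exist only to make the round trip observable.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyModel(tf.keras.Model):
  """Hypothetical model used only to demonstrate checkpointing."""

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.dense = tf.keras.layers.Dense(2)

  def call(self, x):
    return self.dense(x)


model = TinyModel()
x = tf.constant([[2.0, 2.0, 2.0]])
before = model(x).numpy()  # the first call creates the variables

# Same pattern as for tf.Module: wrap the object in a tf.train.Checkpoint
checkpoint = tf.train.Checkpoint(model=model)
save_path = checkpoint.write(os.path.join(tempfile.mkdtemp(), "tiny_model"))

# Overwrite the weights, then restore them from the checkpoint
for variable in model.trainable_variables:
  variable.assign(tf.zeros(variable.shape))
checkpoint.read(save_path)
after = model(x).numpy()

print(np.allclose(before, after))
```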

There is more to know about saving and serialization of Keras models, including how to give custom layers configuration methods so that they support these features. Check out the guide to saving and serialization.

What’s next

For more details about Keras, see the existing Keras guides.

Another example of a high-level API built on tf.Module is Sonnet from DeepMind, which is covered on their site.


Originally published on the TensorFlow website, this article appears here under a new headline and is licensed under CC BY 4.0. Code samples shared under the Apache 2.0 License.
