Developing Complex Neural Networks in Keras

Most Keras examples show neural networks built with the Sequential class. This is the simplest type of neural network, where one input gives one output. The constructor of the Sequential class takes a list of layers: the first layer in the list is the lowest layer and the last one is the highest. It can also be pictured as a stack of layers (see Figure 1). In Figure 1 the arrow shows the flow of data when we use the model in prediction mode (feed-forward).

Figure 1: Stack of layers
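Here is a minimal sketch of the Sequential style. It assumes TensorFlow 2.x with its bundled Keras, and the 8-feature input and layer sizes are illustrative choices rather than values from this post. The list order is the stack order: the first entry sits at the bottom and the last entry produces the output.

```python
import numpy as np
from tensorflow.keras import layers, models

# List order == stack order: first entry is the lowest layer, last is the highest
model = models.Sequential([
    layers.Input(shape=(8,)),             # assumed: 8 input features
    layers.Dense(16, activation="relu"),  # hidden layer
    layers.Dense(1),                      # single output
])

# Feed-forward (prediction) pass: one input batch in, one output batch out
x = np.random.rand(4, 8).astype("float32")
y = model.predict(x, verbose=0)
print(y.shape)  # (4, 1)
```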

The Sequential class does not allow us to build complex models that need to join two different sets of layers or fork out from an intermediate layer.

Why would we need such a structure? We may need a video-processing model that takes two types of input: the audio stream and the video stream. For example, if we are trying to classify a video segment as fake or real, we might want to use both the video stream and the audio stream to help with the classification.

To do this we would pass the audio and the video through encoders trained for each specific input type and then, in a higher layer, combine their features to produce a classification (see Figure 2).

Figure 2: Combining two stacked network layers.
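A rough sketch of the audio/video idea with the Functional API could look like the following. It assumes each clip has already been reduced to fixed-length feature vectors. AUDIO_DIM, VIDEO_DIM and the small Dense encoders are hypothetical stand-ins for real audio and video encoders, and the training data is random.

```python
import numpy as np
from tensorflow.keras import layers, models

# Hypothetical pre-extracted feature sizes per clip (illustrative only)
AUDIO_DIM, VIDEO_DIM = 64, 128

# Audio encoder stack (stand-in for a real audio encoder)
audio_in = layers.Input(shape=(AUDIO_DIM,), name="audio")
a = layers.Dense(32, activation="relu")(audio_in)

# Video encoder stack (stand-in for a real video encoder)
video_in = layers.Input(shape=(VIDEO_DIM,), name="video")
v = layers.Dense(32, activation="relu")(video_in)

# Higher layers: join the two feature vectors, then classify
joined = layers.concatenate([a, v])
h = layers.Dense(16, activation="relu")(joined)
is_fake = layers.Dense(1, activation="sigmoid", name="is_fake")(h)

model = models.Model(inputs=[audio_in, video_in], outputs=is_fake)
model.compile(optimizer="adam", loss="binary_crossentropy")

# Random stand-in data: 6 clips, label 1 = fake, 0 = real
audio = np.random.rand(6, AUDIO_DIM).astype("float32")
video = np.random.rand(6, VIDEO_DIM).astype("float32")
labels = np.random.randint(0, 2, size=(6, 1)).astype("float32")
model.fit([audio, video], labels, epochs=1, verbose=0)

# Sigmoid output: one probability of "fake" per clip
probs = model.predict([audio, video], verbose=0)
print(probs.shape)  # (6, 1)
```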

Another use case is generating lip movements from audio segments (Google LipSync3D), where a single input (an audio segment) generates both a 3D mesh around the mouth (for the lip movements) and a set of textures to map onto that mesh. These are combined to generate a video with realistic facial movements.
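Forking is the mirror image of joining: one input feeds a shared trunk that splits into several output heads. The sketch below is a toy illustration of that shape only, not the actual LipSync3D architecture. The head names mesh and texture and every size are hypothetical.

```python
import numpy as np
from tensorflow.keras import layers, models

# Illustrative sizes only; not taken from LipSync3D
AUDIO_DIM = 64     # features per audio segment
MESH_DIM = 30      # e.g. flattened mouth-region vertex offsets
TEXTURE_DIM = 48   # e.g. a tiny flattened texture

audio_in = layers.Input(shape=(AUDIO_DIM,), name="audio")
shared = layers.Dense(32, activation="relu")(audio_in)  # common trunk

# Fork: both heads read the same shared features
mesh_hidden = layers.Dense(32, activation="relu")(shared)
mesh = layers.Dense(MESH_DIM, name="mesh")(mesh_hidden)

texture_hidden = layers.Dense(32, activation="relu")(shared)
texture = layers.Dense(TEXTURE_DIM, name="texture")(texture_hidden)

# One input, two outputs
model = models.Model(inputs=audio_in, outputs=[mesh, texture])

x = np.random.rand(3, AUDIO_DIM).astype("float32")
mesh_pred, texture_pred = model.predict(x, verbose=0)
print(mesh_pred.shape, texture_pred.shape)  # (3, 30) (3, 48)
```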

This common requirement of combining two stacks, or forking from a common layer, is the reason we have the Keras Functional API and the Model class.

Keras Functional API and Model Class

The Functional API gives full freedom to create neural networks with non-linear topologies.

The key class here is tf.keras.Model, which allows us to build a graph (a Directed Acyclic Graph, to be exact) of layers instead of restricting us to a list of layers.
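To see why a graph is more general than a list, here is a small sketch with illustrative sizes. One tensor fans out into two branches that are joined again, and an extra skip edge is added with layers.add. A Sequential list cannot describe these connections.

```python
import numpy as np
from tensorflow.keras import layers, models

inp = layers.Input(shape=(16,))
trunk = layers.Dense(16, activation="relu")(inp)

# Two parallel branches from the same tensor (a fork)...
left = layers.Dense(8, activation="relu")(trunk)
right = layers.Dense(8, activation="tanh")(trunk)

# ...joined again: (None, 8) + (None, 8) -> (None, 16)
merged = layers.concatenate([left, right])

# Skip edge: element-wise sum with the trunk output (both are (None, 16))
skip = layers.add([merged, trunk])
out = layers.Dense(1)(skip)

model = models.Model(inputs=inp, outputs=out)

x = np.random.rand(2, 16).astype("float32")
print(model.predict(x, verbose=0).shape)  # (2, 1)
```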

Make sure you use keras.utils.plot_model to keep an eye on the graph you are creating (see below); it needs the pydot package and Graphviz to be installed. Figure 3 shows an example of a toy model with two input encoding stacks and a common output stack. This is similar to Figure 2, except that the inputs are shown at the top.

keras.utils.plot_model(model_object, "<output image>.png")
Figure 3: Output of plot_model method.
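Here is a minimal sketch of visualising a model, assuming TensorFlow's bundled Keras. plot_model needs pydot and the Graphviz binaries and typically raises ImportError when they are missing outside a notebook, so this version falls back to model.summary(). The layer names, sizes and the file name model.png are illustrative.

```python
from tensorflow import keras
from tensorflow.keras import layers, models

inp = layers.Input(shape=(4,), name="features")
hidden = layers.Dense(8, name="hidden")(inp)
out = layers.Dense(1, name="score")(hidden)
model = models.Model(inputs=inp, outputs=out)

try:
    # show_shapes=True annotates each node with its input/output shapes
    keras.utils.plot_model(model, "model.png", show_shapes=True)
    print("wrote model.png")
except ImportError:
    # Without pydot/Graphviz, a text summary still lists layers and shapes
    model.summary()
```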

The code for this is shown below. The main difference is that instead of passing layers in a list, we assemble each stack of layers ourselves (see input stacks 1 and 2 below), starting with a tf.keras.layers.Input layer and calling each layer on the output of the previous one. The stacks are then connected through a merging layer (tf.keras.layers.concatenate in this example) to the common part of the network. The Model constructor takes the list of these Inputs as well as the final output.

The Input layers mark the starting points of the graph, and the output layer (in this example) marks its end. Activations flow from the Input layers to the output layer.

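from tensorflow.keras import layers, models

WIDTH = 8  # number of input features (assumed; the post does not give a value)

input1 = layers.Input(shape=(WIDTH,))  # input stack 1
l1 = layers.Dense(20)(input1)          # each layer is called on the previous tensor
l2 = layers.Dense(10)(l1)

input2 = layers.Input(shape=(WIDTH,))  # input stack 2
l21 = layers.Dense(20)(input2)
l22 = layers.Dense(10)(l21)

# Common output stack: merge both stacks, then continue as one
common = layers.concatenate([l2, l22])
interm = layers.Dense(10)(common)
output = layers.Dense(1)(interm)

# The graph starts at the two Inputs and ends at the single output
model = models.Model(inputs=[input1, input2], outputs=output)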
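Once built, the two-input model is trained and called with one array per Input, in the same order as the inputs list. The sketch below rebuilds the toy model compactly so that it runs on its own. WIDTH = 8, the random data and the MSE loss are assumptions for illustration. It also shows that any intermediate tensor can act as an end point: here input stack 1 is reused as a standalone encoder that shares its weights with the full model.

```python
import numpy as np
from tensorflow.keras import layers, models

WIDTH = 8  # assumed input width; the post leaves WIDTH unspecified

input1 = layers.Input(shape=(WIDTH,))
l2 = layers.Dense(10)(layers.Dense(20)(input1))
input2 = layers.Input(shape=(WIDTH,))
l22 = layers.Dense(10)(layers.Dense(20)(input2))
common = layers.concatenate([l2, l22])  # (None, 10) + (None, 10) -> (None, 20)
output = layers.Dense(1)(layers.Dense(10)(common))
model = models.Model(inputs=[input1, input2], outputs=output)

# One array per Input, in the same order as `inputs`
x1 = np.random.rand(32, WIDTH).astype("float32")
x2 = np.random.rand(32, WIDTH).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.compile(optimizer="adam", loss="mse")
model.fit([x1, x2], y, epochs=2, verbose=0)
pred = model.predict([x1, x2], verbose=0)
print(pred.shape)  # (32, 1)

# Any tensor in the graph can be an end point: stack 1 as a standalone encoder
encoder1 = models.Model(inputs=input1, outputs=l2)
print(encoder1.predict(x1, verbose=0).shape)  # (32, 10)
```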

Azahar Machwe (2022)
