šŸŽ Checkout my Learn React by Making a Game course and get 1 month Free on Skillshare!

TensorflowJs functional model – using the tf.model() function

In the previous post we saw how to create a TensorflowJs sequential model.

But what if we want to create a more general model? Take, for example, the model architecture below:
[Figure: TensorflowJs functional model architecture, built with the tf.model() function]

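We cannot implement this one with the tf.sequential() function. A sequential TensorflowJs model is a linear stack of layers: each layer takes its input only from the layer right before it, so the model cannot branch.

This is where TensorflowJs functional models come into play. With the functional API we connect layers by calling apply() on each layer with the output of another layer, and then pass the resulting inputs and outputs to the tf.model() function. Because the output of one layer can be fed to several layers, this lets us build a branching model.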
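To make the difference concrete, here is a dependency-free sketch that does not use TensorflowJs at all: the "layers" are hypothetical plain functions on arrays of numbers. A sequential model is just function composition, so it can only pass one value down a single chain, while a branching model keeps an intermediate result and feeds it to more than one place.

```javascript
// Hypothetical toy "layers": plain functions from an array to an array.
const double = (xs) => xs.map((x) => 2 * x);
const addOne = (xs) => xs.map((x) => x + 1);
const negate = (xs) => xs.map((x) => -x);

// A sequential model is function composition: each layer's output is the
// only input of the next layer, and the model has exactly one output.
function sequential(layers) {
  return (input) => layers.reduce((acc, layer) => layer(acc), input);
}

const sequentialModel = sequential([double, addOne]);

// A branching model keeps the intermediate result and uses it twice,
// which a linear list of layers cannot express.
function branchingModel(input) {
  const hidden = addOne(double(input)); // shared trunk, computed once
  const y1 = hidden;                    // first output
  const y2 = negate(hidden);            // second output, branching off the trunk
  return [y1, y2];
}
```

For the input [1, 2], sequentialModel returns [3, 5], and branchingModel returns both [3, 5] and [-3, -5] from the same trunk.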

To implement the architecture from the picture above, we can write:

```javascript
import * as tf from '@tensorflow/tfjs';

// Define the input layer, which has a size of 3.
const input = tf.input({shape: [3]});

// First dense layer
const denseLayer1 = tf.layers.dense({units: 4});
// Second dense layer
const denseLayer2 = tf.layers.dense({units: 4});

// Create the branching layer
const branchingLayer = tf.layers.dense({units: 4});

// Pass the input through the 1st hidden layer once and reuse the result
const hidden1 = denseLayer1.apply(input);

// Obtain the first output - Y1 (the output of the 2nd hidden layer)
const y1 = denseLayer2.apply(hidden1);

// Obtain the second branching output - Y2, fed from the 2nd hidden layer's output
const y2 = branchingLayer.apply(y1);

// Create the model with one input and two outputs
const model = tf.model({inputs: input, outputs: [y1, y2]});

// Print the layers, their output shapes and their parameter counts
model.summary();
```

One more layer is added after the 2nd hidden layer. The output of the 2nd hidden layer is used as the first output, y1.

At the same time, the output of the 2nd hidden layer is fed into one more layer, branchingLayer, which predicts a second output, y2.

This way, one network predicts multiple outputs at the same time. With the sequential API we would have had to build two separate neural networks to predict y1 and y2, while the functional API lets a single network predict both outputs.
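The forward pass of this architecture can be sketched in plain JavaScript, without TensorflowJs. This is only an illustration under stated assumptions: the weights below are hypothetical fixed values standing in for the weights a trained tf.layers.dense would learn, and, like the default tf.layers.dense, the toy layers have a bias and no activation. The point is that the 2nd hidden layer's output is computed once and serves as y1 and as the input of the branching layer.

```javascript
// Toy dense layer (bias, no activation): out[j] = bias[j] + sum_i input[i] * kernel[i][j]
function dense(kernel, bias) {
  return (input) =>
    bias.map((b, j) => input.reduce((sum, x, i) => sum + x * kernel[i][j], b));
}

// n x n matrix with `scale` on the diagonal and 0 elsewhere
function diag(n, scale) {
  return Array.from({ length: n }, (_, i) =>
    Array.from({ length: n }, (_, j) => (i === j ? scale : 0))
  );
}

// Hypothetical weights; shapes match the article: [3] -> [4] -> [4], plus a [4] branch.
const denseLayer1 = dense(
  [
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
  ],
  [0, 0, 0, 1]
);
const denseLayer2 = dense(diag(4, 2), [0, 0, 0, 0]);
const branchingLayer = dense(diag(4, 1), [1, 1, 1, 1]);

// One forward pass through one network yields both outputs.
function predict(input) {
  const hidden = denseLayer2(denseLayer1(input)); // 2nd hidden layer output, computed once
  const y1 = hidden;                              // first output
  const y2 = branchingLayer(hidden);              // branching output
  return [y1, y2];
}

const [y1, y2] = predict([1, 2, 3]);
// y1 is [2, 4, 6, 2]
// y2 is [3, 5, 7, 3]
```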

You can always check whether the architecture of your TensorflowJs model is implemented correctly with the model.summary() method, which prints a table of the model's layers with their output shapes and parameter counts.
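As a sanity check on that table, the parameter counts can be worked out by hand. Assuming the default useBias: true, a dense layer has a kernel of inputDim × units weights plus one bias per unit. The sketch below applies that formula to the three layers of the model above; it mirrors the arithmetic only and does not call TensorflowJs.

```javascript
// Parameters of a dense layer with a bias: inputDim * units kernel weights + units biases
function denseParamCount(inputDim, units) {
  return inputDim * units + units;
}

// The article's layers as [inputDim, units]
const layerShapes = {
  denseLayer1: [3, 4],
  denseLayer2: [4, 4],
  branchingLayer: [4, 4],
};

const perLayer = Object.fromEntries(
  Object.entries(layerShapes).map(([name, [inputDim, units]]) => [
    name,
    denseParamCount(inputDim, units),
  ])
);

const totalParams = Object.values(perLayer).reduce((a, b) => a + b, 0);
// perLayer is { denseLayer1: 16, denseLayer2: 20, branchingLayer: 20 }
// totalParams is 56
```

If the total reported by model.summary() differs from this hand count, a layer is probably wired to the wrong input size.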
