Reading the weights of a TensorflowJs model

When it comes to machine learning, everything revolves around the weights between the neurons. One of the main things a TensorflowJs model does during training is find the weight values needed to emulate the behavior of a function.

The TensorflowJs API provides the getWeights() function for accessing the values of the weights in a model. It returns an array of tensors.

Reading all the weights of a TensorflowJs model

Let's say we have the following sequential model:

const model = tf.sequential();
model.add(tf.layers.dense({inputShape: [2], units: 1}));
model.add(tf.layers.dense({units: 1}));
model.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });

It's a 3-layer model, with 2 input neurons, 1 neuron in the hidden layer, and 1 output neuron.

If we want to print all of the weights of this model we can do:

const printAllWeights = (model) => {
  // getWeights() returns an array of tensors, so call it once and print each one
  model.getWeights().forEach((weights) => weights.print())
}

Remember that each layer also has a bias, so the output consists of 4 tensors holding 5 values in total. The initial bias weights are zero:

"Tensor
    [[-0.1103895],
     [0.8178059 ]]"
"Tensor
    [0]"
"Tensor
     [[1.3078513],]"
"Tensor
    [0]"
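
The total of 5 values follows from how a dense layer is wired: every unit has one weight per input plus one bias. Here is a rough sketch of that arithmetic. The helper name countDenseParams is made up for this illustration and is not part of TensorflowJs, and it assumes every layer is a dense layer with a bias.

```javascript
// Hypothetical helper, not part of the TensorflowJs API.
// Counts the trainable values in a stack of dense layers that all use a bias.
const countDenseParams = (inputSize, unitsPerLayer) => {
  let total = 0;
  let inputs = inputSize;
  for (const units of unitsPerLayer) {
    total += inputs * units + units; // weights plus one bias per unit
    inputs = units; // this layer's outputs feed the next layer
  }
  return total;
};

// The model above: 2 inputs, then two dense layers with 1 unit each.
console.log(countDenseParams(2, [1, 1])); // 5
```

For a built model, model.countParams() should report the same number.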

A neural network starts from small random weight values, so you should see different values in the output each time you run the code.

If you need to test what would happen with a given set of weights you can use the setWeights() function to load specific values for the weights of your model.
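
As a minimal sketch of that idea, the helper below turns plain arrays into tensors and hands them to setWeights(). The name setWeightsFromArrays is hypothetical, not a TensorflowJs function. The code assumes tf is already in scope, as in the snippets above, and that the values are listed in the same order and shapes that getWeights() returns.

```javascript
// Hypothetical helper; assumes `tf` is already loaded, as in the snippets above.
const setWeightsFromArrays = (model, values) => {
  const current = model.getWeights();
  if (values.length !== current.length) {
    throw new Error(`Expected ${current.length} arrays, got ${values.length}`);
  }
  // Passing the expected shape lets tf.tensor() reject values that do not fit it.
  const tensors = values.map((v, i) => tf.tensor(v, current[i].shape));
  model.setWeights(tensors);
};
```

For the model above, a call could look like setWeightsFromArrays(model, [[[0.5], [-0.5]], [0.1], [[2]], [0]]), which sets the weights and bias of the first layer, then those of the second.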

After the training is done, the values for the weights will be different as the network is now tailored to mimic a specific behavior.
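
One way to see that change, sketched here under the assumption that you take a snapshot of the weights before and after training, is to copy the values out of the tensors and compare them. Both helper names are made up for this example. dataSync() is the tensor method that returns the values as a flat typed array.

```javascript
// Hypothetical helpers, not part of the TensorflowJs API.
// Copies every weight tensor of the model into a plain array of numbers.
const snapshotWeights = (model) =>
  model.getWeights().map((w) => Array.from(w.dataSync()));

// Largest absolute difference between two snapshots of the same model.
const maxWeightChange = (before, after) => {
  let max = 0;
  before.forEach((values, i) => {
    values.forEach((v, j) => {
      max = Math.max(max, Math.abs(after[i][j] - v));
    });
  });
  return max;
};
```

Take one snapshot before calling model.fit() and another after it resolves. If training updated the weights, maxWeightChange(before, after) should then be greater than zero.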

Reading the weights of a single layer in a TensorflowJs model

If we want to read the weights of just one layer from a model we can do as follows:

const printWeightsFromLayer = (model, layerNo) => {
  const layer = model.layers[layerNo]
  // for a dense layer, getWeights() returns the weights followed by the bias
  const [weights, bias] = layer.getWeights()
  // print layer weights
  weights.print()
  // print bias weights
  bias.print()
}
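
print() only writes to the console. If you need the numbers themselves, a similar sketch can return them as nested arrays with arraySync(). The function name and the shape of the returned object are assumptions for illustration, and the code assumes a dense layer whose getWeights() returns the weights followed by an optional bias.

```javascript
// Hypothetical helper, not part of the TensorflowJs API.
const readWeightsFromLayer = (model, layerNo) => {
  const [weights, bias] = model.layers[layerNo].getWeights();
  return {
    weights: weights.arraySync(),
    // a layer created with useBias: false has no bias tensor
    bias: bias ? bias.arraySync() : null,
  };
};
```

For the model above, readWeightsFromLayer(model, 0) would return the 2 weights and the single bias of the first dense layer.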

You can check out the full codepen of this example.
