The argMax() function returns the index of the maximum value in a tensor.
Why do we need it?
Well, we saw in the post about one-hot encoding that, when it comes to sorting things into categories, neural networks prefer to return an answer in the form of an array rather than a single output value.
So, let's say that we want to tell cats, dogs, and birds apart. Instead of printing something like "the animal in this image is a cat", a neural network will return an array like
[0.86, 0.12, 0.02], meaning that there is an
86% chance that this is a cat,
12% that it is a dog and
2% that it is a bird.
Since we receive a tensor with a probability for each category, we need to go through all the returned values and find the category with the highest probability.
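To make the idea concrete, here is a minimal sketch of that search in plain JavaScript, with no TensorFlow.js involved. `indexOfMax` is a hypothetical helper name; keeping the first index on ties and returning -1 for an empty array are choices of this sketch, not documented argMax() behavior.

```javascript
// Hypothetical helper: a plain-JavaScript version of the "find the highest value" search.
function indexOfMax(values) {
  if (values.length === 0) {
    return -1; // nothing to pick from
  }
  let bestIndex = 0;
  for (let i = 1; i < values.length; i++) {
    // Strictly greater: on ties, the first occurrence is kept.
    if (values[i] > values[bestIndex]) {
      bestIndex = i;
    }
  }
  return bestIndex;
}

const types = ['cat', 'dog', 'bird', 'mouse'];
const scores = [0.12, 0.15, 0.98, 0.02];
console.log(types[indexOfMax(scores)]); // bird
```

This loop is exactly the work argMax() saves us from writing by hand, and argMax() does it directly on the tensor.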
And this is where the TensorFlow.js argMax() function comes into play: it does that search for us and returns the index of the maximum value in the tensor.
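```javascript
// With a bundler; in the CodePen the script tag already provides a global `tf`.
import * as tf from '@tensorflow/tfjs';

const results = tf.tensor1d([0.12, 0.15, 0.98, 0.02]);
results.argMax().print();
// Tensor
//     2
```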
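We can then use that index to look up the name of the category:

```javascript
const types = ['cat', 'dog', 'bird', 'mouse'];
const predictions = tf.tensor1d([0.12, 0.15, 0.98, 0.02]);
// arraySync() turns the scalar result tensor into a plain number (here 2).
const index = predictions.argMax().arraySync();
console.log('In the picture we have a ' + types[index]);
// In the picture we have a bird
```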
Note that argMax() returns a tensor containing the index where the highest value is found, NOT the highest value itself.
We can also call argMax() through the tf namespace, passing the tensor as an argument:

```javascript
const results = tf.tensor1d([0.12, 0.15, 0.98, 0.02]);
tf.argMax(results).print();
// Tensor
//     2
```
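It also works for multi-dimensional tensors. By default argMax() works along axis 0, so for a 2D tensor it returns, for each column, the index of the row that holds the largest value:

```javascript
const multiDTensor = tf.tensor2d([[0.12, 0.15], [0.98, 0.02]]);
// Column 0: 0.98 (row 1) beats 0.12; column 1: 0.15 (row 0) beats 0.02.
multiDTensor.argMax().print();
// Tensor
//     [1, 0]
```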
Here is the link to the CodePen if you want to play with it.