How to classify images from your webcam in the browser without a server using TensorFlow.js and WebRTC.
In March 2018 Google introduced TensorFlow.js, an open-source library that can be used to define, train, and run machine learning models entirely in the browser, using JavaScript.
They also created a tool that converts Keras models into the TensorFlow.js format so they can be used in the browser, for example, for image classification.
TensorFlow.js can take source images from various HTML elements, including the video element. Using this option, we can classify images coming from WebRTC video streams.
As an example, I'm going to run the MobileNet model that ships with the Keras library in the browser. The model has already been pre-trained on the 1000 image classes of the ImageNet database.
Let's start. First, you need to download the pre-trained model and save it in the HDF5 (.h5) format.
Create a python file with the following content:
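from keras.applications.mobilenet import MobileNet
# 128x128 RGB input; Keras provides ImageNet weights for 128, 160, 192 and 224 pixel inputs
INPUT_SHAPE = (128, 128, 3)
# Download MobileNet with ImageNet weights, including the 1000-class classifier on top
model = MobileNet(weights='imagenet', include_top=True, input_shape=INPUT_SHAPE)
# Save the architecture and the weights into a single HDF5 file
model.save('model.h5')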
Before running this script, install the tensorflowjs library for Python:
pip install tensorflowjs
Now run the script. It will create a model.h5 file in the current directory.
Then use the following command to convert model.h5 into the TensorFlow.js format and save it to the jsmodel directory:
tensorflowjs_converter --input_format keras ./model.h5 ./jsmodel/
Now let's create an HTML page that grabs the local webcam stream and feeds it to our model. The full version of the file is located here: https://github.com/alexkorep/webrtc-tensorflowjs-example/blob/master/index.html
First, let's grab the video from the local camera and show it in the video element.
HTML:
<video id="gum-local" width="128" height="128" autoplay playsinline></video>
JavaScript:
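// Ask for video only; the classifier does not need audio
const constraints = { audio: false, video: true };
function handleSuccess(stream) {
  // Attach the camera stream to the <video id="gum-local"> element
  const video = document.getElementById('gum-local');
  video.srcObject = stream;
}
navigator.mediaDevices.getUserMedia(constraints)
  .then(handleSuccess)
  .catch((error) => console.error('getUserMedia error:', error));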
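getUserMedia takes a MediaStreamConstraints object that describes which camera and which resolution you want. If you want to hint at a resolution close to the model's 128x128 input, or pick the rear camera on a phone, a minimal sketch could look like the following. `buildCameraConstraints` is a hypothetical helper name; `width`, `height`, `ideal` and `facingMode` are standard constraint properties. Browsers treat `ideal` values as preferences, so the delivered frame size can still differ.

```javascript
// Hypothetical helper: builds a MediaStreamConstraints object for getUserMedia.
// `ideal` values are hints, not hard requirements, so the browser may
// deliver a different frame size.
function buildCameraConstraints({ size = 128, facingMode = 'user' } = {}) {
  return {
    audio: false, // the classifier only needs video
    video: {
      width: { ideal: size },
      height: { ideal: size },
      facingMode, // 'user' = front camera, 'environment' = rear camera
    },
  };
}

// Usage in the browser (not executed here):
// navigator.mediaDevices
//   .getUserMedia(buildCameraConstraints({ facingMode: 'environment' }))
//   .then(handleSuccess);
```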
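In the page's <head>, load the library:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.12.0"> </script>
This example uses version 0.12.0. In TensorFlow.js 1.0 and later, tf.loadModel was renamed to tf.loadLayersModel and tf.fromPixels to tf.browser.fromPixels.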
To load the model from the jsmodel directory, pass the path of the model.json file that the converter created:
tf.loadModel('jsmodel/model.json')
  .then(function (model) {
    // Use the model
  });
Now we need to grab frames from the video element and feed them to our model. Ideally this would be done for every video frame, but in practice how many frames per second you can classify depends heavily on the performance of your computer or phone.
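requestAnimationFrame typically fires once per display refresh (often 60 times per second), which a slow device may not keep up with. One option is to skip frames and classify at most a fixed number of times per second. A minimal sketch of that rate limiting, where `createFrameThrottle` and `maxFps` are hypothetical names; it expects the timestamp in milliseconds that requestAnimationFrame passes to its callback.

```javascript
// Hypothetical helper: returns a function that says whether the current
// frame should be classified, so inference runs at most `maxFps` times
// per second no matter how often requestAnimationFrame fires.
function createFrameThrottle(maxFps) {
  const minIntervalMs = 1000 / maxFps;
  let lastRunMs = -Infinity; // guarantees the very first frame is classified

  return function shouldRun(nowMs) {
    if (nowMs - lastRunMs >= minIntervalMs) {
      lastRunMs = nowMs;
      return true;
    }
    return false;
  };
}

// Usage in the browser (not executed here):
// const shouldRun = createFrameThrottle(10); // at most 10 classifications/s
// function loop(timestamp) {
//   if (shouldRun(timestamp)) {
//     // classify the current video frame here
//   }
//   window.requestAnimationFrame(loop);
// }
// window.requestAnimationFrame(loop);
```

Skipping frames keeps the page responsive at the cost of reacting a little later to what the camera sees.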
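Let's drive the loop with window.requestAnimationFrame. It passes a timestamp to its callback, so we wrap our function to hand it the model and the video element: window.requestAnimationFrame(() => onFrame(model, video)). Here is the function: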
function onFrame(model, webrtcElement) {
  // tf.tidy() disposes every intermediate tensor created inside it;
  // without it, memory leaks on every frame
  const index = tf.tidy(() => {
    // Build a tensor from the current video frame
    const frame = tf.fromPixels(webrtcElement);
    // Make sure the frame matches the 128x128 input the model was saved with
    const resized = tf.image.resizeBilinear(frame, [128, 128]);
    // Keras' MobileNet expects each RGB component in the -1..1 range,
    // but the tensor contains 0..255 values, so map x to x / 127.5 - 1
    const batch = resized.expandDims(0).asType('float32').div(127.5).sub(1);
    // Do the actual prediction
    const pred = model.predict(batch);
    // Prediction returns a probability for each class;
    // find the class with the highest probability
    return tf.argMax(pred, 1).dataSync()[0];
  });
  // Now we should find which text label corresponds to the found index;
  // that is our classification for the image.
  // Schedule the classification of the next frame
  window.requestAnimationFrame(() => onFrame(model, webrtcElement));
}
As a result, we get the index variable, whose value is the number of the predicted image class.
The full list of image class labels can be found in the same index.html file on my GitHub: https://github.com/alexkorep/webrtc-tensorflowjs-example/blob/master/index.html.
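A single index only gives the best guess. When the model is unsure, it can help to show the few most likely classes instead. A minimal sketch in plain JavaScript, assuming you copy the probabilities out of the prediction tensor with `pred.dataSync()` and that `labels` is an array of class names in the model's output order; `topK` is a hypothetical helper name.

```javascript
// Returns the k most probable classes from a flat array of probabilities,
// such as the Float32Array returned by pred.dataSync() for one image.
// `labels` is assumed to list class names in the model's output order.
function topK(probabilities, labels, k = 3) {
  return Array.from(probabilities, (probability, index) => ({ index, probability }))
    // Highest probability first
    .sort((a, b) => b.probability - a.probability)
    .slice(0, k)
    .map(({ index, probability }) => ({
      index,
      // Fall back to a generic name if no label is known for this index
      label: index < labels.length ? labels[index] : `class ${index}`,
      probability,
    }));
}

// Usage in the browser (not executed here), with `pred` being the
// prediction tensor and IMAGENET_LABELS a hypothetical array of 1000 names:
// const best = topK(pred.dataSync(), IMAGENET_LABELS, 3);
// best.forEach((p) => console.log(p.label, p.probability.toFixed(3)));
```

Sorting all 1000 probabilities per frame is cheap compared with running the model itself.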
The full project can be found here: https://github.com/alexkorep/webrtc-tensorflowjs-example and the demo is located here: https://alexkorep.github.io/webrtc-tensorflowjs-example/