Running Your ML Model on the Client Side with TF.js
In this article, we will load a model created in this Google Colab notebook. The model classifies 28x28 images into one of the 10 classes of the Fashion MNIST dataset, and it runs COMPLETELY ON THE CLIENT SIDE!
Imagine you have built a great ML model and you want to put it on your website, but you don’t want to pay extra to host a server that runs the ML on the backend. With TF.js, you can load and use a model entirely on the client side by leveraging the browser’s WebGL for tensor operations.
The only prerequisite is that you execute the notebook and download the TF.js model (note: you could also use your own model, but make sure you adapt the preprocessing step to fit your model’s expected input).
Step 1) Uploading the model
Make sure you put the .bin and .json files in the public directory, like I have done in the image below.
This lets us access the files using relative paths once we start our website.
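For reference, the layout I am assuming looks roughly like this (the shard file name below is only an example of what the TF.js converter produces; yours may differ):
public/
  model/
    model.json
    group1-shard1of1.bin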
Step 2) Imports
Besides React, we need TF.js. We also create a list of labels, which lets us map a prediction from a number (0–9) to its class name.
import React, {useEffect, useState} from 'react';
import * as tf from '@tensorflow/tfjs';
import {LayersModel, Tensor} from "@tensorflow/tfjs";
const labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle Boot"];
Step 3) Loading the model
function App() {
  const [model, setModel] = useState<LayersModel>();

  // when the component mounts, load the model
  useEffect(() => {
    loadModel();
  }, []);

  const loadModel = async () => {
    setModel(await tf.loadLayersModel("./model/model.json"));
  }

  return (
    <div>
      Hi
    </div>
  );
}
We use the useEffect hook with an empty dependency array so that it fires once when the component mounts. Inside it, we set our model state to the result of tf.loadLayersModel.
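If the path to model.json is wrong, loadLayersModel will reject. As a debugging aid (not part of the original code), a minimal variation with a try/catch and a summary() call makes that easier to spot:
const loadModel = async () => {
  try {
    const loaded = await tf.loadLayersModel("./model/model.json");
    loaded.summary(); // logs the layer stack to the browser console
    setModel(loaded);
  } catch (e) {
    console.error("Could not load the model", e);
  }
}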
Step 4) Add preprocessing
const srgbToBlackAndWhite = (imageData: ImageData) => {
  const data = imageData.data;
  const bw = [];
  for (let i = 0; i < data.length; i += 4) {
    const r = data[i];
    const g = data[i + 1];
    const b = data[i + 2];
    // convert the pixel to grayscale using the luminosity method and flip it
    const gray = Math.abs(255 - (0.2126 * r + 0.7152 * g + 0.0722 * b));
    // add the new gray pixel, scaled to the 0-1 range
    bw.push(gray / 255.0);
  }
  return bw;
}
When our image is uploaded, it will be in RGBA format with pixel values in the 0–255 range. This function converts the image to black and white, inverts it, and scales it to 0–1 (which is the format our model expects).
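As a quick sanity check (illustrative only, not part of the app): because the function inverts the luminosity before scaling, a pure white pixel should map to roughly 0 and a pure black pixel to exactly 1.
// a 1x1 white pixel and a 1x1 black pixel, both fully opaque
const white = new ImageData(new Uint8ClampedArray([255, 255, 255, 255]), 1, 1);
const black = new ImageData(new Uint8ClampedArray([0, 0, 0, 255]), 1, 1);
console.log(srgbToBlackAndWhite(white)); // [0] (up to floating-point rounding)
console.log(srgbToBlackAndWhite(black)); // [1]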
Step 5) Add HTML
return (
  <div>
    <label htmlFor={"image-input"}>Image</label>
    <input type={"file"} id={"image-input"} accept={"image/png"}/>
    <button onClick={predictModel}>Predict</button>
    <p id={"result"}/>
    <canvas id={"canvas"} style={{display: "none"}}></canvas>
  </div>
);
Our HTML is quite simple for this component: an input field to upload the image, a button to run the prediction, a p tag to show the result, and a hidden canvas that lets us extract pixel data from the image.
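An aside on the hidden canvas: TF.js can also read pixels straight from an image element with tf.browser.fromPixels, so the canvas is only needed because we do the preprocessing by hand. If you prefer tensor ops over the loop from Step 4, a sketch like this mirrors the same luminosity-plus-inversion math (it assumes an already loaded 28x28 img element):
const rgb = tf.browser.fromPixels(img).toFloat();                        // [28, 28, 3], values 0-255
const weights = tf.tensor1d([0.2126, 0.7152, 0.0722]);                   // luminosity weights
const gray = rgb.mul(weights).sum(2);                                    // [28, 28]
const input = tf.scalar(255).sub(gray).div(255).reshape([1, 28, 28, 1]); // invert, scale, add batch/channel dims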
Step 6) predictModel()
const predictModel = () => {
  // get document element references
  const result = document.getElementById("result") as HTMLElement;
  const input = document.getElementById("image-input") as HTMLInputElement;
  const canvas = document.getElementById("canvas") as HTMLCanvasElement;
  const context = canvas.getContext("2d");

  const reader = new FileReader();
  // when the reader loads something...
  reader.onload = (event) => {
    const img = new Image();
    // when the image loads...
    img.onload = () => {
      // draw it on the canvas
      canvas.width = img.width;
      canvas.height = img.height;
      context!.drawImage(img, 0, 0);
      // get the RGBA pixel data
      const sRGBImageData = context!.getImageData(0, 0, canvas.width, canvas.height);
      // turn the image into a black-and-white image with a 0-1 range
      const bwScaledImage = srgbToBlackAndWhite(sRGBImageData);
      // feed the image into the model and get the argmax of the output;
      // our model expects a tensor of shape [1, 28, 28, 1], so we shape it to that
      const res = model?.predict(tf.tensor4d(bwScaledImage, [1, 28, 28, 1])) as Tensor;
      const arr = res.bufferSync().values;
      const max = arr.indexOf(Math.max(...arr));
      // set the html element to the label corresponding to the argmax
      result.innerText = labels[max];
    };
    // load the image as the result of the read
    img.src = event!.target!.result as string;
  };
  // read the input image as a data URL
  reader.readAsDataURL(input.files![0]);
}
Now we can create our predictModel function, which is invoked when the button is clicked. It will…
- get references to the HTML elements we need.
- create a FileReader and give it an onload handler, which creates a new Image and sets its src to the data URL the reader produced.
- give the Image an onload handler which will: draw the image on the canvas, extract the RGBA data, convert it to black-and-white image data, use the model to predict the outcome (an array of length 10), take the index of the label with the highest probability, and finally set the result to that label (a TF.js-only way to take that argmax is sketched after this list).
- read the file in the input as a data URL, which triggers the reader’s onload and in turn triggers the image’s onload.
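As promised, TF.js can also compute the argmax for you. This sketch is equivalent to the last few lines of the handler above (it assumes the model has finished loading and bwScaledImage is as defined earlier):
const res = model!.predict(tf.tensor4d(bwScaledImage, [1, 28, 28, 1])) as Tensor;
const max = res.argMax(1).dataSync()[0]; // index of the highest probability
result.innerText = labels[max];
If you run predictions often, wrapping this work in tf.tidy() will also clean up the intermediate tensors for you.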
Done!
Now you can test out your model in the browser by running npm start. I have attached two sample PNGs; my model gets the pullover correct but the shoe wrong. See if you can do better!