Tensorflow.js in Jupyter
Contents
Tensorflow.js in Jupyter¶
We’re going to use a pre-trained model based on the MNIST dataset in order to recognize digits written on a <canvas>
element in this notebook.
Free-drawing in Canvas¶
The browser library canvas-free-drawing
is a fairly simple way to collect user-drawn input.
First we load it into the browser:
import { resolve } from 'path';
// Replace the cell's document body with a <script> tag that pulls the
// canvas-free-drawing 2.1.0 UMD bundle from the jsDelivr CDN, followed by a
// plain "Done." marker so the cell shows some visible output.
document.body.innerHTML = `
<script src="https://cdn.jsdelivr.net/npm/canvas-free-drawing@2.1.0/umd/canvas-free-drawing.min.js"></script>
Done.
`;
// NOTE(review): presumably hands the document body built above to the
// notebook front-end as this cell's output — confirm against the kernel docs.
jupyter.renderDom();
Then we render a small canvas to draw on:
// Stand-alone drawing demo: the markup below declares a bordered canvas, an
// (unused here) status <div> and a "Clear" button wired to
// window.cfdTest.clear().
//
// The inline script only runs its setup when the CanvasFreeDrawing global is
// defined. It then creates a 140x140 drawing surface on the canvas, keeps the
// instance on window.cfdTest so the button can reach it, and configures a
// 10px black stroke.
document.body.innerHTML = `
<canvas id="freeDrawCanvasTest" style="border: 1px solid #888888;"></canvas>
<div id="canvasTestStatus"></div>
<button onclick="window.cfdTest.clear()">Clear</button>
<script>
if (
typeof CanvasFreeDrawing !== 'undefined'
) {
window.cfdTest = new CanvasFreeDrawing.default({
elementId: 'freeDrawCanvasTest',
width: 140,
height: 140,
});
// set properties
window.cfdTest.setLineWidth(10); // in px
window.cfdTest.setStrokeColor([0, 0, 0]); // in RGB
}
</script>
`;
// NOTE(review): presumably hands the document body built above to the
// notebook front-end as this cell's output — confirm against the kernel docs.
jupyter.renderDom();
Note
Once exported, this is the only demo that will be interactive. The rest will be rendered static. You will have to download the template from GitHub and run it locally to interact with the other examples in this notebook.
Back-end / Front-end Communication¶
We should be able to set up a WebSocket server that can provide two-way communication between the TypeScript back-end and the HTML front-end:
import {
readFileSync,
writeFileSync,
rmSync,
} from 'fs';
import { WebSocketServer } from 'ws';
(async () => {
  // When you run this process it never returns an output to Jupyter
  // nor does Jupyter know how to engage with a running script.
  //
  // To make sure we do not spawn too many back-end processes and/or
  // hit an EADDRINUSE error, we keep track of the pid of it, then
  // make sure to kill the previous running instance every time.
  try {
    // The pid file stores the pid as decimal text. Parse it explicitly and
    // only signal a strictly positive integer: an empty or corrupted file
    // must never be coerced into a pid that does not name a single process.
    const previousPid = Number.parseInt(readFileSync('.wsComm.pid', 'utf8'), 10);
    if (Number.isInteger(previousPid) && previousPid > 0) {
      process.kill(previousPid, 'SIGKILL');
      rmSync('.wsComm.pid');
      // Give the killed process a moment to go away before we bind the
      // same port again below.
      await new Promise((done) => setTimeout(done, 1000));
    }
  } catch {
    // Best effort on purpose: a missing pid file (first run) or a pid that
    // no longer exists simply means there is nothing to clean up.
  }
  writeFileSync('.wsComm.pid', `${process.pid}`);

  // Run a WebSocket server
  const wss = new WebSocketServer({ port: 9890 });
  wss.on('connection', (ws) => {
    ws.on('message', (eventData) => {
      let response = '';
      try {
        // Handle incoming data from the front-end. The pixel buffer arrives
        // as a JSON object keyed by index, hence Object.values().
        const { data, width, height } = JSON.parse(eventData.toString());
        const imageData = Object.values(data);

        // Do something trivial with the image data for now, just to make
        // sure all is working as it should: count how many entries are
        // exactly 255 ("white") or exactly 0 ("black") in a single pass.
        let whiteCount = 0;
        let blackCount = 0;
        for (const value of imageData) {
          if (value === 255) {
            whiteCount += 1;
          } else if (value === 0) {
            blackCount += 1;
          }
        }

        // Guard the division so an empty payload reports 0% instead of NaN%.
        const total = Math.max(imageData.length, 1);
        const whitePercent = Math.round((whiteCount / total) * 100);
        const blackPercent = Math.round((blackCount / total) * 100);

        response = `Image ${width}x${height} | ${whitePercent}% white | ${blackPercent}% black`;
      } catch (error) {
        // A malformed message must not take the whole server down; report
        // the problem to the client that sent it instead.
        const reason = error instanceof Error ? error.message : String(error);
        response = `Could not process message: ${reason}`;
      }

      // Respond with a result to display on the front-end
      ws.send(JSON.stringify({ response }));
    });
  });
})();
and the HTML front-end should be able to communicate with the back-end:
// Front-end half of the communication demo. The markup declares a bordered
// canvas, a status <div> and a "Clear" button wired to window.cfdComm.clear().
//
// The inline script:
//   1. closes a WebSocket left on window.socketComm by a previous run,
//   2. opens a new one to ws://localhost:9890 (the port the back-end cell
//      listens on) and writes 'Connected!' / 'Message from server: ...' into
//      the status <div> as events arrive,
//   3. if the CanvasFreeDrawing global is defined, creates a 150x150 drawing
//      surface with a 10px black stroke, and on every 'mouseleave' event
//      sends the canvas' getImageData() result (data, width, height) to the
//      back-end as a JSON string.
//
// NOTE(review): the 'mouseleave' handler calls send() without checking that
// the socket is open — confirm whether drawing before 'Connected!' shows up
// needs handling.
document.body.innerHTML = `
<canvas id="freeDrawCanvasComm" style="border: 1px solid #888888;"></canvas>
<div id="canvasCommStatus"></div>
<button onclick="window.cfdComm.clear()">Clear</button>
<script>
// Every time we run this cell we should reset any previously
// connected WebSockets
if (typeof window.socketComm !== 'undefined') {
window.socketComm.close();
}
// Connect to the back-end
window.socketComm = new WebSocket('ws://localhost:9890');
window.socketComm.binaryType = "arraybuffer";
// Connection opened
window.socketComm.addEventListener('open', function (event) {
canvasCommStatus.innerHTML = 'Connected!';
});
// Listen for messages
window.socketComm.addEventListener('message', function (event) {
canvasCommStatus.innerHTML = 'Message from server: ' + JSON.parse(event.data.toString()).response;
});
if (
typeof CanvasFreeDrawing !== 'undefined'
) {
window.cfdComm = new CanvasFreeDrawing.default({
elementId: 'freeDrawCanvasComm',
width: 150,
height: 150,
});
// set properties
window.cfdComm.setLineWidth(10); // in px
window.cfdComm.setStrokeColor([0, 0, 0]); // in RGB
// Send the image data to the back-end every time we're done drawing
window.cfdComm.on({ event: 'mouseleave' }, () => {
const { data, width, height } = cfdComm.context.getImageData(0, 0, 150, 150);
window.socketComm.send(JSON.stringify({ data, width, height }));
});
}
</script>
`;
// NOTE(review): presumably hands the document body built above to the
// notebook front-end as this cell's output — confirm against the kernel docs.
jupyter.renderDom();
Note
This is a static example of the output you would get when running the notebook. To see the interactive demo, download the template from GitHub and run it locally.
Using digit recognition model with Tensorflow¶
Now that we have validated that we can run a back-end and a front-end from Jupyter with TypeScript, and that both processes can communicate with each other, we can move on and convert a model from h5
(Keras) to a set of files that can be used with tfjs
using the command:
tensorflowjs_converter --input_format keras mnist.h5 .
Then update the back-end to receive the data from the front-end, create a tensor from it, and make a prediction using the model (which it sends back to the front-end):
import { resolve } from 'path';
import {
readFileSync,
writeFileSync,
rmSync,
} from 'fs';
import { WebSocketServer } from 'ws';
import * as tf from '@tensorflow/tfjs';
import * as tfn from '@tensorflow/tfjs-node';
(async () => {
  // Same single-instance guard as the communication demo: kill the back-end
  // started by a previous run of this cell so that the port is free again.
  try {
    // The pid file stores the pid as decimal text. Parse it explicitly and
    // only signal a strictly positive integer: an empty or corrupted file
    // must never be coerced into a pid that does not name a single process.
    const previousPid = Number.parseInt(readFileSync('.ws.pid', 'utf8'), 10);
    if (Number.isInteger(previousPid) && previousPid > 0) {
      process.kill(previousPid, 'SIGKILL');
      rmSync('.ws.pid');
      // Give the killed process a moment to go away before binding the port.
      // (Callback named `done` so it does not shadow path's `resolve`.)
      await new Promise((done) => setTimeout(done, 1000));
    }
  } catch {
    // Best effort on purpose: a missing pid file (first run) or a pid that
    // no longer exists simply means there is nothing to clean up.
  }

  // Load the converted model from ./model/model.json next to this notebook.
  const handler = tfn.io.fileSystem(resolve(__dirname, 'model/model.json'));
  const model = await tf.loadLayersModel(handler);

  writeFileSync('.ws.pid', `${process.pid}`);

  const wss = new WebSocketServer({ port: 9898 });
  wss.on('connection', (ws) => {
    ws.on('message', (eventData) => {
      let response = '';
      try {
        // The pixel buffer arrives as a JSON object keyed by index, hence
        // Object.values().
        const { data: rawData, width, height } = JSON.parse(eventData.toString());
        const data = Object.values(rawData);

        // The tensor below is reshaped to [1, 28, 28, 1]; reject anything
        // that cannot fill that shape with a readable message instead of
        // letting reshape() throw.
        if (data.length !== 28 * 28) {
          throw new Error(
            `expected 784 values (28x28), received ${data.length} (${width}x${height})`
          );
        }

        // We run our predictions inside of tf.tidy to make sure
        // we avoid memory leaks by cleaning up intermediate
        // memory allocated to the tensors. tf.tidy is synchronous and
        // returns whatever the callback returns, so only plain numbers
        // leave this scope.
        const predictions = tf.tidy(() => {
          const img = tf.cast(tf.tensor(data).reshape([1, 28, 28, 1]), 'float32');
          const output = model.predict(img);
          // predict() may return a single tensor or a list of tensors; use
          // the first one in the latter case.
          const scores = Array.isArray(output) ? output[0] : output;
          return Array.from(scores.dataSync());
        });

        // Once we have our predictions we look for the one
        // with the highest confidence (first index wins on ties).
        let prediction = 0;
        for (let index = 1; index < predictions.length; index += 1) {
          if (predictions[index] > predictions[prediction]) {
            prediction = index;
          }
        }

        // We get the confidence of the previously found
        // prediction and turn it to a percentage
        const confidence = predictions[prediction] * 100;

        response = `the number is <strong>${prediction}</strong> <i>(${confidence}% confidence)</i>`;
      } catch (error) {
        // A malformed message must not take the whole server down; report
        // the problem to the client that sent it instead.
        const reason = error instanceof Error ? error.message : String(error);
        response = `Could not classify the drawing: ${reason}`;
      }

      // And send the results to the front-end via the WebSocket
      ws.send(JSON.stringify({ response }));
    });
  });
})();
Then simply make sure to send the image data from the front-end’s <canvas>
to the back-end for processing in real-time:
import { readFileSync } from 'fs';
// NOTE(review): canvasFreeDrawingSrc is read from disk here but is not
// referenced anywhere else in this cell — confirm whether it was meant to be
// inlined into the markup below or can be dropped.
const canvasFreeDrawingSrc = readFileSync('./canvas-free-drawing.min.js');
// Front-end half of the digit-recognition demo. The markup declares a
// bordered canvas, a status <div> and a "Clear" button wired to
// window.cfd.clear().
//
// The inline script:
//   1. closes a WebSocket left on window.socket by a previous run and opens a
//      new one to ws://localhost:9898 (the port the prediction back-end
//      listens on), writing 'Connected!' / 'Message from server: ...' into the
//      status <div>,
//   2. if the CanvasFreeDrawing global is defined, creates a 140x140 drawing
//      surface with a black background and a 5px white stroke,
//   3. on every 'mouseleave' event grabs the 140x140 pixels, asks
//      createImageBitmap() for a 28x28 resized bitmap, draws that onto a
//      scratch 28x28 canvas, keeps every 4th byte of the RGBA data (the red
//      channel), and sends { data, width, height } as JSON.
//
// NOTE(review): the handler calls send() without checking that the socket is
// open — confirm whether drawing before 'Connected!' shows up needs handling.
document.body.innerHTML = `
<canvas id="freeDrawCanvas" style="border: 1px solid #888888;"></canvas>
<div id="canvasStatus"></div>
<button onclick="window.cfd.clear()">Clear</button>
<script>
if (typeof window.socket !== 'undefined') {
window.socket.close();
}
window.socket = new WebSocket('ws://localhost:9898');
window.socket.binaryType = "arraybuffer";
// Connection opened
window.socket.addEventListener('open', function (event) {
canvasStatus.innerHTML = 'Connected!';
});
// Listen for messages
window.socket.addEventListener('message', function (event) {
canvasStatus.innerHTML = 'Message from server: ' + JSON.parse(event.data.toString()).response;
});
if (
typeof CanvasFreeDrawing !== 'undefined'
) {
window.cfd = new CanvasFreeDrawing.default({
elementId: 'freeDrawCanvas',
backgroundColor: [0, 0, 0],
width: 140,
height: 140,
});
// set properties
window.cfd.setLineWidth(5); // in px
window.cfd.setStrokeColor([255, 255, 255]); // in RGB
window.cfd.on({ event: 'mouseleave' }, async () => {
const imageData = cfd.context.getImageData(0, 0, 140, 140);
const resizeWidth = 28 >> 0;
const resizeHeight = 28 >> 0;
const ibm = await window.createImageBitmap(
imageData,
0,
0,
imageData.width,
imageData.height,
{
resizeWidth,
resizeHeight
}
);
// The image is scaled down to 28x28 pixels since that
// is what the model was trained on
const resizeCanvas = document.createElement('canvas');
resizeCanvas.width = 28;
resizeCanvas.height = 28;
const resizeCtx = resizeCanvas.getContext('2d');
resizeCtx.drawImage(ibm, 0, 0);
const { data: rawData, width, height } = resizeCtx.getImageData(0, 0, 28, 28);
// Since the image is black and white, we can simply
// take only every 4th value (due to pixels being in RGBA)
const data = rawData.reduce(
(a, b, i) => {if (i % 4 === 0) {a.push(b)} return a;},
[]
);
window.socket.send(JSON.stringify({ data, width, height }));
});
}
</script>
`;
// NOTE(review): presumably hands the document body built above to the
// notebook front-end as this cell's output — confirm against the kernel docs.
jupyter.renderDom();
Note
The colors are inverted since the model used was trained on images with black background and white foreground.
Note
This is a static example of the output you would get when running the notebook. To see the interactive demo, download the template from GitHub and run it locally.
Mandatory victory GIF¶
// Replace the cell's document body with a single remote GIF hosted on Giphy.
document.body.innerHTML = '<img src="https://media.giphy.com/media/HQWLR3lyeGbawyfYIw/giphy.gif" />';
// NOTE(review): presumably hands the document body built above to the
// notebook front-end as this cell's output — confirm against the kernel docs.
jupyter.renderDom();
