TensorFlow.js in Jupyter

We’re going to use a pre-trained model based on the MNIST dataset in order to recognize digits written on a <canvas> element on this notebook.

Free-drawing in Canvas

The browser library canvas-free-drawing is a fairly simple way to collect user-drawn input.

First we load it into the browser:

// NOTE(review): `resolve` is imported but never used in this cell —
// presumably a leftover from an earlier revision; tslab cells share one
// module scope, so confirm no later cell needs it before removing.
import { resolve } from 'path';

// Inject a <script> tag that pulls canvas-free-drawing from the jsDelivr
// CDN into the notebook DOM. Once the browser evaluates the tag, the
// library is reachable as the global `CanvasFreeDrawing`.
document.body.innerHTML = `
<script src="https://cdn.jsdelivr.net/npm/canvas-free-drawing@2.1.0/umd/canvas-free-drawing.min.js"></script>

Done.
`;

// Flush the DOM assembled above to the notebook cell output.
jupyter.renderDom();
Done.

Then we render a small canvas to draw on:

// Render a 140x140 free-drawing canvas plus a Clear button. The inline
// <script> runs in the browser: it guards on `CanvasFreeDrawing` being
// loaded (by the CDN cell above) and stores the instance on
// `window.cfdTest` so the button's onclick handler can reach it.
document.body.innerHTML = `
<canvas id="freeDrawCanvasTest" style="border: 1px solid #888888;"></canvas>
<div id="canvasTestStatus"></div>
<button onclick="window.cfdTest.clear()">Clear</button>
<script>
  if (
    typeof CanvasFreeDrawing !== 'undefined'
  ) {
    window.cfdTest = new CanvasFreeDrawing.default({
      elementId: 'freeDrawCanvasTest',
      width: 140,
      height: 140,
    });

    // set properties
    window.cfdTest.setLineWidth(10); // in px
    window.cfdTest.setStrokeColor([0, 0, 0]); // in RGB
  }
</script>
`;

// Flush the DOM assembled above to the notebook cell output.
jupyter.renderDom();

Note

Once exported, this is the only demo that will be interactive. The rest will be rendered static. You will have to download the template from github and run it locally to interact with the other examples on this notebook.

Back-end / Front-end Communication

We should be able to set up a WebSocket server that can provide two-way communication between the TypeScript back-end and the HTML front-end:

import {
  readFileSync,
  writeFileSync,
  rmSync,
} from 'fs';
import { WebSocketServer } from 'ws';

(async () => {
  // When you run this cell the process never returns an output to Jupyter,
  // nor does Jupyter know how to engage with a running script.
  //
  // To make sure we do not spawn too many back-end processes and/or
  // hit an EADDRINUSE error, we keep track of the pid of the server,
  // then make sure to kill the previously running instance every time.
  try {
    // BUG FIX: readFileSync without an encoding returns a Buffer, but
    // process.kill requires a numeric pid. Passing the Buffer throws
    // ERR_INVALID_ARG_TYPE straight into the catch below, so the stale
    // server was never actually killed. Parse the pid explicitly.
    const previousPid = Number(readFileSync('.wsComm.pid', 'utf8'));
    process.kill(previousPid, 'SIGKILL');
    rmSync('.wsComm.pid');
    // Give the OS a moment to release the port before we re-bind it.
    await new Promise((resolve) => setTimeout(resolve, 1000));
  } catch {
    // No pid file, or the process is already gone — nothing to clean up.
  }

  // Record our own pid so the next run of this cell can kill us.
  writeFileSync('.wsComm.pid', `${process.pid}`);

  // Run a WebSocket server the front-end cells connect to.
  const wss = new WebSocketServer({ port: 9890 });

  wss.on('connection', (ws) => {
    ws.on('message', (eventData) => {
      // Handle incoming data from the front-end.
      // NOTE(review): assumes the payload is a JSON-serialized ImageData
      // (RGBA bytes keyed by index) — confirm against the front-end cell.
      const { data, width, height } = JSON.parse(eventData.toString());

      // Do something trivial with the image data for now, just to make
      // sure all is working as it should. A serialized ImageData `data`
      // arrives as a plain index-keyed object, so recover the bytes first.
      const imageData = Object.values(data);
      const whitePercent = Math.round((imageData.filter((n) => n === 255).length / imageData.length) * 100);
      const blackPercent = Math.round((imageData.filter((n) => n === 0).length / imageData.length) * 100);

      // Respond with a result to display on the front-end.
      ws.send(
        JSON.stringify(
          {
            response: `Image ${width}x${height} | ${whitePercent}% white | ${blackPercent}% black`
          }
        )
      );
    });
  });
})();

and the HTML front-end should be able to communicate with the back-end:

// Render the communication-test canvas. The inline <script> runs in the
// browser: it (re)connects a WebSocket to the back-end cell on port 9890,
// shows server responses in #canvasCommStatus, and — when the pointer
// leaves the canvas — ships the raw 150x150 ImageData to the back-end.
document.body.innerHTML = `
<canvas id="freeDrawCanvasComm" style="border: 1px solid #888888;"></canvas>
<div id="canvasCommStatus"></div>
<button onclick="window.cfdComm.clear()">Clear</button>
<script>
  // Every time we run this cell we should reset any previously
  // connected WebSockets
  if (typeof window.socketComm !== 'undefined') {
    window.socketComm.close();
  }

  // Connect to the back-end
  window.socketComm = new WebSocket('ws://localhost:9890');
  window.socketComm.binaryType = "arraybuffer";

  // Connection opened
  window.socketComm.addEventListener('open', function (event) {
    canvasCommStatus.innerHTML = 'Connected!';
  });

  // Listen for messages
  window.socketComm.addEventListener('message', function (event) {
    canvasCommStatus.innerHTML = 'Message from server: ' + JSON.parse(event.data.toString()).response;
  });

  if (
    typeof CanvasFreeDrawing !== 'undefined'
  ) {
    window.cfdComm = new CanvasFreeDrawing.default({
      elementId: 'freeDrawCanvasComm',
      width: 150,
      height: 150,
    });

    // set properties
    window.cfdComm.setLineWidth(10); // in px
    window.cfdComm.setStrokeColor([0, 0, 0]); // in RGB

    // Send the image data to the back-end every time we're done drawing
    window.cfdComm.on({ event: 'mouseleave' }, () => {
      const { data, width, height } = cfdComm.context.getImageData(0, 0, 150, 150);

      window.socketComm.send(JSON.stringify({ data, width, height }));
    });
  }
</script>
`;

// Flush the DOM assembled above to the notebook cell output.
jupyter.renderDom();

Message from server: Image 150x150 | 86% white | 14% black

Note

This is a static example of the output you would get when running the notebook. To see the interactive demo download the template from github and run it locally.

Using a digit recognition model with TensorFlow

Now that we have validated that we can run a back-end and front-end from Jupyter with TypeScript, and that both processes can communicate together, we can move forward and translate a model from h5 (Keras) to a set of files that can be used with tfjs using the command:

tensorflowjs_converter --input_format keras mnist.h5 .

Then update the back-end to receive the data from the front-end, create a tensor from it, and make a prediction using the model (which it sends back to the front-end):

import { resolve } from 'path';
import {
    readFileSync,
    writeFileSync,
    rmSync,
} from 'fs';
import { WebSocketServer } from 'ws';
import * as tf from '@tensorflow/tfjs';
import * as tfn from '@tensorflow/tfjs-node';

(async () => {
  // Kill any previously running instance of this back-end; see the
  // pid-file rationale in the communication-test server cell.
  try {
    // BUG FIX: readFileSync without an encoding returns a Buffer, but
    // process.kill requires a numeric pid. Passing the Buffer throws
    // ERR_INVALID_ARG_TYPE straight into the catch below, so the stale
    // server was never actually killed. Parse the pid explicitly.
    const previousPid = Number(readFileSync('.ws.pid', 'utf8'));
    process.kill(previousPid, 'SIGKILL');
    rmSync('.ws.pid');
    await new Promise((resolve) => setTimeout(resolve, 1000));
  } catch {
    // No pid file, or the process is already gone — nothing to clean up.
  }

  // Load the Keras model previously converted with tensorflowjs_converter.
  const handler = tfn.io.fileSystem(resolve(__dirname, 'model/model.json'));
  const model = await tf.loadLayersModel(handler);

  // Record our own pid so the next run of this cell can kill us.
  writeFileSync('.ws.pid', `${process.pid}`);

  const wss = new WebSocketServer({ port: 9898 });

  wss.on('connection', (ws) => {
    ws.on('message', (eventData) => {
      // The front-end sends 28x28 single-channel pixel values plus the
      // image dimensions; `data` arrives as an index-keyed plain object.
      const { data: rawData, width, height } = JSON.parse(eventData.toString());
      const data = Object.values(rawData);

      // We run our predictions inside of tf.tidy to make sure we avoid
      // memory leaks by cleaning up intermediate memory allocated to the
      // tensors. Note: tf.tidy requires a *synchronous* callback and its
      // result here is void, so the previous `await` on it was a no-op
      // and has been removed.
      tf.tidy(() => {
        // Shape the flat pixel list into a single-image batch of
        // 28x28x1 floats — the layout the MNIST model was trained on.
        let img = tf.tensor(data);
        img = img.reshape([1, 28, 28, 1]);
        img = tf.cast(img, 'float32');

        const output = model.predict(img) as any;

        const predictions = Array.from(output.dataSync()) as number[];

        // The predicted digit is the index of the highest-confidence
        // class in the softmax output.
        const prediction = predictions.indexOf(Math.max(...predictions));

        // We get the confidence of the previously found
        // prediction and turn it to a percentage.
        const confidence = predictions[prediction] * 100;

        // And send the results to the front-end via the WebSocket.
        ws.send(
          JSON.stringify(
            {
              response: `the number is <strong>${prediction}</strong> <i>(${confidence}% confidence)</i>`
            }
          )
        );
      });
    });
  });
})();

Then simply make sure to send the image data from the front-end’s <canvas> to the back-end for processing in real-time:

import { readFileSync } from 'fs';

// NOTE(review): `canvasFreeDrawingSrc` is read here but never interpolated
// into the template below — presumably intended for an inline fallback to
// the CDN <script> tag; confirm before removing.
const canvasFreeDrawingSrc = readFileSync('./canvas-free-drawing.min.js');

// Render the digit-drawing canvas. The inline <script> runs in the
// browser: it connects to the prediction back-end on port 9898, and on
// mouseleave scales the 140x140 drawing down to the 28x28 input the
// MNIST model expects before sending it over the WebSocket.
document.body.innerHTML = `
<canvas id="freeDrawCanvas" style="border: 1px solid #888888;"></canvas>
<div id="canvasStatus"></div>
<button onclick="window.cfd.clear()">Clear</button>
<script>
    if (typeof window.socket !== 'undefined') {
        window.socket.close();
    }

    window.socket = new WebSocket('ws://localhost:9898');
    window.socket.binaryType = "arraybuffer";

    // Connection opened
    window.socket.addEventListener('open', function (event) {
        canvasStatus.innerHTML = 'Connected!';
    });

    // Listen for messages
    window.socket.addEventListener('message', function (event) {
        canvasStatus.innerHTML = 'Message from server: ' + JSON.parse(event.data.toString()).response;
    });

    if (
      typeof CanvasFreeDrawing !== 'undefined'
    ) {
      window.cfd = new CanvasFreeDrawing.default({
        elementId: 'freeDrawCanvas',
        backgroundColor: [0, 0, 0],
        width: 140,
        height: 140,
      });

      // set properties
      window.cfd.setLineWidth(5); // in px
      window.cfd.setStrokeColor([255, 255, 255]); // in RGB

      window.cfd.on({ event: 'mouseleave' }, async () => {
        const imageData = cfd.context.getImageData(0, 0, 140, 140);

        const resizeWidth = 28 >> 0;
        const resizeHeight = 28 >> 0;
        const ibm = await window.createImageBitmap(
          imageData,
          0,
          0,
          imageData.width,
          imageData.height,
          {
            resizeWidth,
            resizeHeight
          }
        );

        // The image is scaled down to 28x28 pixels since that
        // is what the model was trained on
        const resizeCanvas = document.createElement('canvas');
        resizeCanvas.width = 28;
        resizeCanvas.height = 28;
        const resizeCtx = resizeCanvas.getContext('2d');
        resizeCtx.drawImage(ibm, 0, 0);
        const { data: rawData, width, height } = resizeCtx.getImageData(0, 0, 28, 28);

        // Since the image is black and white, we can simply
        // take only every 4th value (due to pixels being in RGBA)
        const data = rawData.reduce(
          (a, b, i) => {if (i % 4 === 0) {a.push(b)} return a;},
          []
        );

        window.socket.send(JSON.stringify({ data, width, height }));
      });
    }
</script>
`;

// Flush the DOM assembled above to the notebook cell output.
jupyter.renderDom();

Note

The colors are inverted since the model used was trained on images with black background and white foreground.

Message from server: the number is 7 (100% confidence)

Note

This is a static example of the output you would get when running the notebook. To see the interactive demo download the template from github and run it locally.

Mandatory victory GIF

// Celebrate a working end-to-end demo with the obligatory victory GIF.
const victoryGifUrl = 'https://media.giphy.com/media/HQWLR3lyeGbawyfYIw/giphy.gif';
document.body.innerHTML = `<img src="${victoryGifUrl}" />`;
jupyter.renderDom();