import {Drawer} from './components/Drawer.js'
viewof drawing_pad = {
  clear;
  return Drawer({
    width: 8 * 28,
    height: 8 * 28
  });
}
Digit recognition
Illustrates CNN and KNN algorithms for digit classification.
Draw a digit in the box below and the algorithms on this page will attempt to identify the digit you've drawn.
What’s going on?
This demo is meant to illustrate the application of two machine learning algorithms:
- a CNN, or Convolutional Neural Network, and
- a KNN, or K-Nearest Neighbors, algorithm.
Classification by CNN is the go-to approach for image recognition. While not as strong, KNN can do surprisingly well and is simple enough to explain on day one of an introductory machine learning class. How a CNN works, by contrast, might take until the last day of class to describe.
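To make the "day one" claim concrete, here is a minimal KNN sketch in plain JavaScript. This is not the TensorFlow.js implementation used on this page; it simply illustrates the idea: measure the distance from the query to every training point and take a majority vote among the k closest.

```javascript
// Euclidean distance between two equal-length numeric vectors.
function distance(a, b) {
  let sum = 0;
  for (let i = 0; i < a.length; i++) {
    const d = a[i] - b[i];
    sum += d * d;
  }
  return Math.sqrt(sum);
}

// Classify `query` by majority vote among its k nearest training points.
function knnClassify(trainData, trainLabels, query, k = 3) {
  const neighbors = trainData
    .map((x, i) => ({ label: trainLabels[i], dist: distance(x, query) }))
    .sort((p, q) => p.dist - q.dist)
    .slice(0, k);
  const votes = {};
  for (const { label } of neighbors) {
    votes[label] = (votes[label] || 0) + 1;
  }
  // Return the label with the most votes.
  return Number(
    Object.keys(votes).reduce((a, b) => (votes[a] >= votes[b] ? a : b))
  );
}
```

For MNIST, `trainData` would hold the \(784\)-dimensional image vectors and `trainLabels` the corresponding digits; the only "training" is storing the data.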
Both of these models were built with the JavaScript version of TensorFlow.
Why does KNN take so much longer to load than CNN?
A Convolutional Neural Network is a parametric model. This means that the model defines a function that performs the classification. That function depends on a number of parameters and, once those parameters are determined, classification can proceed. Importantly, once the model function is generated, that's all you need to perform classification. TensorFlow provides a tool to export the model in a standardized form.
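The contrast can be made concrete with a toy example: a parametric model is just a function plus a fixed set of numbers. Here is a tiny linear classifier (purely illustrative, and far simpler than the CNN on this page):

```javascript
// Toy parametric classifier: once the parameters (weights and bias)
// are fixed, nothing else is needed to classify -- the training data
// can be discarded entirely.
function makeClassifier(weights, bias) {
  return (x) => {
    let s = bias;
    for (let i = 0; i < weights.length; i++) s += weights[i] * x[i];
    return s >= 0 ? 1 : 0;
  };
}

const f = makeClassifier([1, -1], 0);
```

Exporting such a model amounts to saving the parameter values; that is why the CNN loads quickly, while the KNN must ship its entire dataset.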
The K-Nearest Neighbors algorithm, by contrast, is nonparametric. With no model function, this technique requires access to the actual data. In addition, TensorFlow packs that data into an internal format for efficient access. There's no standard export tool, so loading and packing that data takes a few moments.
The KNN takes about 10 seconds to load on my laptop with a good connection.
What’s machine learning, anyway?
Machine learning consists of a class of algorithms that consume data to produce a function that performs a desired task. Often, the data is labeled into classes and the objective is to classify new, similar data into those same classes.
Colloquially, we say that the algorithm “learns” from data.
What does the data look like in this case?
The data for the task on this page is known as the MNIST Digit Dataset and consists of two key components:
- A list of \(70,000\) vectors of length \(784\). Each vector can be partitioned into a \(28\times28\) matrix. The entries are all floating point numbers in \([0,1]\) and can be interpreted as grayscale values that determine an image.
- A list of \(70,000\) labels, each of which is an integer from \(0\) to \(9\) indicating what digit the corresponding matrix represents.
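The partition of a vector into a matrix is just a matter of slicing off consecutive rows of \(28\) entries. A minimal sketch in plain JavaScript:

```javascript
// Reshape one flat MNIST vector of length 784 into a 28x28 matrix of
// grayscale values, row by row.
function toMatrix(vector, size = 28) {
  const rows = [];
  for (let r = 0; r < size; r++) {
    rows.push(vector.slice(r * size, (r + 1) * size));
  }
  return rows;
}
```

Entry \((r, c)\) of the matrix is then entry \(28r + c\) of the original vector.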
The slider below allows you to peruse the exact \(10,000\) images used to build the KNN classifier, together with their labels.
Where can I learn more?
- Like all the demos on Mark’s Math, the full code can be found in my GitHub repo.
- The code defining the KNN algorithm is quite simple and can be found in the index.qmd file for this page.
- The CNN model is much deeper and was trained with this Observable notebook.
- If you want to learn all the mathematics, statistics, and theory behind this stuff, you should take my Math for Machine Learning Course!