K Nearest Neighbors

Sat, Apr 26, 2025

K Nearest Neighbors (KNN)

Today, we’re going to take a quick look at K Nearest Neighbor (or KNN) algorithms, which are non-parametric, supervised techniques for both classification and regression.

It’s also worth distinguishing K Nearest Neighbors from K-Means, which is a non-parametric, unsupervised clustering algorithm. The two techniques are really quite different but sometimes confused due to their similar names.

Imports

Here are a few imports that we’ll be using today, in addition to our usuals.

# Our usuals
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# KNN for (well) KNN, and
# OneHotEncoder for dealing with categorical data
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.preprocessing import OneHotEncoder

# Some other tools for building and evaluating models
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# And access to OpenML's data library
from sklearn.datasets import fetch_openml

# The gower package (pip install gower), which computes the Gower distance for mixed data
import gower

Basics

K Nearest Neighbors defines a class of supervised algorithms, so we work with labeled data. The goal is to predict the label of a new data point from the labels of the data we already have.

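KNN is non-parametric. Thus, it makes no assumption about a model function for the data, so there are no model parameters to optimize. (The number of neighbors \(k\) is still a choice we make, but it's a hyperparameter rather than something learned from the data.)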

Basic idea

The fundamental idea behind KNN is to examine other data near the point under consideration and to guess the label based on those other values. We might examine 5 nearest neighbors, 8 nearest neighbors, or 15 nearest neighbors. In general, we examine \(k\) nearest neighbors, which is where the name comes from.

How we determine our guess for the label depends on the nature of the data.

  • If the label is numeric, we often average the neighboring values.
  • If the label is categorical, we might use majority vote.
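Both rules fit in a few lines of NumPy. Here's a minimal from-scratch sketch, assuming Euclidean distance; the helper name knn_predict and the toy data are hypothetical, chosen only to illustrate the two bullets above.

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x_new, k=5, task="classification"):
    """Predict a label for one new point from its k nearest training points.

    Distances are Euclidean. Classification uses a majority vote;
    regression averages the neighbors' numeric labels.
    """
    X_train = np.asarray(X_train, dtype=float)
    y_train = np.asarray(y_train)
    # Euclidean distance from x_new to every training point
    dists = np.linalg.norm(X_train - np.asarray(x_new, dtype=float), axis=1)
    # Indices of the k closest training points, nearest first
    nearest = np.argsort(dists)[:k]
    if task == "classification":
        # Majority vote; a tie goes to the label appearing first in nearest-first order
        return Counter(y_train[nearest].tolist()).most_common(1)[0][0]
    # Regression: average the neighbors' values
    return float(y_train[nearest].mean())

# Two small, well-separated clusters (hypothetical toy data)
X_train = [[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]]
y_class = ["a", "a", "a", "b", "b", "b"]
y_num = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0]

print(knn_predict(X_train, y_class, [0.2, 0.2], k=3))                   # a
print(knn_predict(X_train, y_num, [5.2, 5.2], k=3, task="regression"))  # 20.0
```

Note that there is no training step at all: every prediction scans the stored data, which is why KNN is cheap to "fit" and comparatively expensive to query.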

Example

Into which species would you classify the first entry, marked ???

species island sex culmen_length_mm culmen_depth_mm flipper_length_mm body_mass_g
??? Dream MALE 48 15 222 4850
... ... ... ... ... ... ...
Gentoo Biscoe FEMALE 46.1 13.2 211.0 4500.0
Adelie Biscoe MALE 37.7 18.7 180.0 3600.0
Gentoo Biscoe MALE 50.0 16.3 230.0 5700.0
Adelie Biscoe FEMALE 37.8 18.3 174.0 3400.0
Chinstrap Dream FEMALE 46.5 17.9 192.0 3500.0
Adelie Dream FEMALE 39.5 16.7 178.0 3250.0
Adelie Dream MALE 37.2 18.1 178.0 3900.0
Gentoo Biscoe FEMALE 48.7 14.1 210.0 4450.0
Chinstrap Dream MALE 51.3 19.2 193.0 3650.0
Chinstrap Dream MALE 50.0 19.5 196.0 3900.0
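To make the question concrete, here's a sketch that answers it with KNN using only the four numeric measurements from the ten labeled rows shown (island and sex are ignored, which is an assumption of this sketch). The features are standardized first because body mass, in grams, would otherwise dominate the distance.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

# The ten labeled rows above, numeric columns only:
# culmen_length_mm, culmen_depth_mm, flipper_length_mm, body_mass_g
X = np.array([
    [46.1, 13.2, 211.0, 4500.0],
    [37.7, 18.7, 180.0, 3600.0],
    [50.0, 16.3, 230.0, 5700.0],
    [37.8, 18.3, 174.0, 3400.0],
    [46.5, 17.9, 192.0, 3500.0],
    [39.5, 16.7, 178.0, 3250.0],
    [37.2, 18.1, 178.0, 3900.0],
    [48.7, 14.1, 210.0, 4450.0],
    [51.3, 19.2, 193.0, 3650.0],
    [50.0, 19.5, 196.0, 3900.0],
])
y = np.array(["Gentoo", "Adelie", "Gentoo", "Adelie", "Chinstrap",
              "Adelie", "Adelie", "Gentoo", "Chinstrap", "Chinstrap"])
mystery = np.array([[48.0, 15.0, 222.0, 4850.0]])  # the ??? row

# Standardize each feature, then vote among the 3 nearest labeled penguins
scaler = StandardScaler().fit(X)
knn = KNeighborsClassifier(n_neighbors=3).fit(scaler.transform(X), y)
prediction = knn.predict(scaler.transform(mystery))
dist, idx = knn.kneighbors(scaler.transform(mystery))
print(prediction)      # ['Gentoo']
print(y[idx[0]])       # species of the 3 nearest neighbors
```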

A plot of that data

Here’s a plot of that same data in the body_mass/culmen_depth plane. The classification is even more obvious now.

Principal component plot

To illustrate the idea even further, let’s plot the first two principal components.

Classification plot

Here’s the resulting nearest neighbor classification plot based on just the first two principal components.
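A plot like this can be produced by classifying every point of a fine grid in the PC1/PC2 plane. Here's a sketch of that computation; the helper pc_decision_grid is hypothetical, and since the penguin loading code isn't shown, scikit-learn's bundled iris data stands in for the penguin measurements.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

def pc_decision_grid(X, y, k=8, steps=200):
    """Fit KNN in the PC1/PC2 plane and predict a class at every grid point."""
    # Standardize, then project onto the first two principal components
    to_pcs = make_pipeline(StandardScaler(), PCA(n_components=2))
    Z = to_pcs.fit_transform(X)
    knn = KNeighborsClassifier(n_neighbors=k).fit(Z, y)
    # A regular grid covering the projected data, with a small margin
    (x0, y0), (x1, y1) = Z.min(axis=0) - 0.5, Z.max(axis=0) + 0.5
    xx, yy = np.meshgrid(np.linspace(x0, x1, steps), np.linspace(y0, y1, steps))
    grid = np.column_stack([xx.ravel(), yy.ravel()])
    labels = knn.predict(grid).reshape(xx.shape)
    return Z, xx, yy, labels

X, y = load_iris(return_X_y=True)  # stand-in for the penguin data
Z, xx, yy, labels = pc_decision_grid(X, y, k=8)
# To draw it: plt.contourf(xx, yy, labels, alpha=0.3); plt.scatter(Z[:, 0], Z[:, 1], c=y)
```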

Comments

  • Given a point in the PC1/PC2 plane, the classification is based on the 8 data points that are closest to that point. Whichever of the three species is most common among those 8 points is assigned to the given point.
  • The accuracy score is about 92%.
  • The classification is based purely on the data and there is no model assumed for that data. Thus, the technique is non-parametric.
  • Purely data-based classification like this can lead to the more jagged classification boundaries that we see in the plot.
  • You can compare this to other classifications we’ve seen before, like

Synthetic example

Here’s a synthetic example to drive the point home.
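The data behind that plot isn't shown, so here's a sketch on a stand-in synthetic data set, scikit-learn's make_moons (an assumption, not necessarily what the slide used). It contrasts a very small \(k\), which memorizes the training data and produces jagged boundaries, with a larger, smoother one.

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Two interleaved, noisy half-moons
X, y = make_moons(n_samples=400, noise=0.3, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

scores = {}
for k in (1, 15):
    knn = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)
    # (training accuracy, test accuracy)
    scores[k] = (knn.score(X_train, y_train), knn.score(X_test, y_test))
    print(f"k={k:2d}  train={scores[k][0]:.3f}  test={scores[k][1]:.3f}")
```

With \(k=1\), every training point is its own nearest neighbor, so training accuracy is perfect; that says nothing about how well the jagged boundary generalizes, which is what the test score measures.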

A couple of real world examples

Let’s check out a couple of real world examples.

  • The MNIST Digit challenge and
  • Mushrooms!

MNIST

We are, of course, already familiar with MNIST so let’s not rehash it. It makes a nice benchmark, though, as we’ve worked with it before.

Recall the following results from our Colab CNN notebook:

  • Logistic Regression: 91.06%
  • Simple NN: 95.76%
  • Convolutional NN: 98.68%

MNIST with KNN

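Here’s how to perform KNN on MNIST using 5 nearest neighbors. The general pattern should be pretty familiar by now, though fit seems like an odd term in this context! For KNN, fitting essentially just stores the training data (possibly organized into a tree for faster neighbor searches); the real work of finding neighbors happens at prediction time.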

X, y = fetch_openml('mnist_784', 
  version=1, return_X_y=True, as_frame=False)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y, test_size=0.2, random_state=1)
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
Accuracy: 0.9447857142857143
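To see that predict is just "find the neighbors, then vote," here's a sketch that reproduces it by hand with the classifier's kneighbors method. It uses scikit-learn's small bundled load_digits images (8x8 pixels) as an offline stand-in for MNIST, so its numbers won't match the accuracy above.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Small 8x8 digit images bundled with scikit-learn (stand-in for MNIST)
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1)

knn = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train)
print(knn.n_samples_fit_)  # fit has simply stored the training points

# Indices of the 5 nearest training images for each test image...
dist, idx = knn.kneighbors(X_test)
# ...and a majority vote over their labels (ties go to the smaller digit)
manual = np.array([np.bincount(y_train[row], minlength=10).argmax() for row in idx])
agreement = np.mean(manual == knn.predict(X_test))
print("Agreement with knn.predict:", agreement)
```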

Mushrooms

Here’s another data set that KNN works very well for: a mushroom data set available at OpenML.

This data set describes 22 characteristics of each mushroom and labels each one as poisonous or edible. The objective is to develop an algorithm that determines whether a mushroom is poisonous based on those characteristics.

Seems worthwhile!

A look at the mushroom data

The data can be loaded via scikit-learn’s fetch_openml. Notably, the data is purely categorical.

X, y = fetch_openml(data_id=24, as_frame=True, return_X_y=True)
X
cap-shape cap-surface cap-color bruises%3F odor gill-attachment gill-spacing gill-size gill-color stalk-shape ... stalk-surface-below-ring stalk-color-above-ring stalk-color-below-ring veil-type veil-color ring-number ring-type spore-print-color population habitat
0 x s n t p f c n k e ... s w w p w o p k s u
1 x s y t a f c b k e ... s w w p w o p n n g
2 b s w t l f c b n e ... s w w p w o p n n m
3 x y w t p f c n n e ... s w w p w o p k s u
4 x s g f n f w b k t ... s w w p w o e n a g
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
8119 k s n f n a c b y e ... s o o p o o p b c l
8120 x s n f n a c b y e ... s o o p n o p b v l
8121 f s n f n a c b n e ... s o o p o o p b c l
8122 k y n f y f c n b t ... k w w p w o e w v l
8123 x s n f n a c b y e ... s o o p o o p o c l

8124 rows × 22 columns

KNN for mushrooms

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
pipeline = make_pipeline(
    OneHotEncoder(handle_unknown='ignore',sparse_output=False),
    KNeighborsClassifier(metric='hamming', n_neighbors=15))
pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
Accuracy: 0.9975381585425899

Metrics

Recall our discussion on Norms and Inner Products back on February 11. We learned that the norm \(\|\mathbf{x}\|\) of a vector \(\mathbf{x}\) generalizes the notion of the length of a vector to an abstract function satisfying a few algebraic properties. Furthermore, given two vectors \(\mathbf{x}\) and \(\mathbf{y}\), \[\|\mathbf{x} - \mathbf{y}\|\] generalizes the notion of distance between the points determined by those vectors. Let’s discuss which norm (and, therefore, which distance) to use for K Nearest Neighbors, particularly when we have categorical or mixed data.
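In NumPy, the distances induced by the common norms are one call each. scikit-learn's KNN estimators default to the Minkowski metric with p=2 (Euclidean distance), and setting p=1 gives the L1 (Manhattan) distance. A small sketch with made-up vectors:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 0.0, 3.0])

l1 = np.linalg.norm(x - y, ord=1)          # |-3| + |2| + |0| = 5
l2 = np.linalg.norm(x - y)                 # sqrt(9 + 4 + 0) = sqrt(13)
linf = np.linalg.norm(x - y, ord=np.inf)   # max(3, 2, 0) = 3
print(l1, l2, linf)

# Choosing the distance in KNN: p=1 selects the L1 (Manhattan) distance
manhattan_knn = KNeighborsClassifier(n_neighbors=5, metric="minkowski", p=1)
```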

Hamming distance

When dealing with purely categorical data, like the mushroom data, the Hamming distance is a very natural tool to use. Given two rows of categorical data, the Hamming distance simply counts the number of variables on which the data disagree.

For example, the distance between row 0 and row 1 below is 2, while the distance between row 0 and row 2 is 3.

cap-shape cap-surface cap-color bruises%3F odor
0 x s n t p
1 x s y t a
2 b s w t l
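A quick check of those two distances in plain Python, using the five columns shown; the helper name hamming is just for illustration.

```python
# The first five columns of rows 0, 1, and 2 from the table above
row0 = ("x", "s", "n", "t", "p")
row1 = ("x", "s", "y", "t", "a")
row2 = ("b", "s", "w", "t", "l")

def hamming(a, b):
    """Count the variables on which two categorical rows disagree."""
    return sum(u != v for u, v in zip(a, b))

print(hamming(row0, row1), hamming(row0, row2))  # 2 3
```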

One-hot encoding

The Hamming distance can be expressed in terms of the L1 norm using one-hot encoding.

Recall that, if a variable has \(v\) possible values, then that single variable contributes \(v\) columns to the encoded data frame. If we label those columns \(c_1,c_2,\ldots,c_v\), then \(c_i=1\) precisely when that variable attains the \(i^{\text{th}}\) value; otherwise \(c_i=0\).

Example

Consulting the mushroom data reference, I notice that cap-surface can have one of four values:

f        g        y      s
fibrous  grooves  scaly  smooth

Since all three rows in the previous example have a smooth cap surface (s), the portion of each row corresponding to this variable would be encoded as

\[[0,0,0,1].\]

This is done for each variable. The mushroom data has 22 variables, each of which can take on several values, so the resulting one-hot encoded vector has length 117.

Example with computation

Now suppose that we have two mushrooms, one with fibrous cap surface and one with smooth cap surface. The portions of the encoded vector corresponding to the variable will look like

\[\begin{aligned} &[\cdots,1,0,0,0,\cdots] \\ &[\cdots,0,0,0,1,\cdots] \end{aligned}\]

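The contribution to the one-norm from this portion will be 2, since the two vectors differ in exactly two coordinates. In general, each variable on which two mushrooms disagree contributes 2 to the one-norm, while each variable on which they agree contributes 0. Thus, the one-norm distance between the encoded vectors is exactly twice the Hamming distance.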

Mixed data types

Often, data types are mixed; you might have some numeric data and some categorical data. You might also want different metrics for data of the same type; perhaps L1 is applied to some numeric variables and L2 to others.

There’s a metric called the Gower distance designed to deal with these differences.

The Gower distance

The Gower distance works on a per-feature basis. Suppose that we have data with \(p\) features, that \(x=(x_j)_{j=1}^p\) and \(y = (y_j)_{j=1}^p\) are two data points, and that we have a metric \(d_j\) associated with each feature.

The Gower distance simply computes the average of these distances. Thus, \[d(x,y) = \frac{1}{p} \sum_{j=1}^p d_j(x_j,y_j).\]

Individual metrics

Popular choices for the individual metrics when defining a Gower distance are:

  • For numeric data, the distance between values is the absolute value of the difference divided by the range of the variable.
  • For categorical data, the distance between two values is 1 if they are different and 0 if they are the same.

Note that, with these choices, the distance between two values of any single variable is at most one, so the Gower distance always lies between 0 and 1. Thus, the Gower distance applies a similar scale to all variables.
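Here's a from-scratch sketch of that definition, applied to Sex, Age, and Fare for the three Titanic passengers sampled in the next section. The helper gower_distances is hypothetical; it assumes no missing values and no constant numeric columns, and the gower package used below may differ in such details, so treat this as an illustration of the formula rather than that package's implementation.

```python
import numpy as np
import pandas as pd

def gower_distances(df, numeric, categorical):
    """Pairwise Gower distances between the rows of df.

    numeric: columns compared by |difference| / range (ranges must be nonzero)
    categorical: columns compared by 0 if equal, 1 if different
    """
    n = len(df)
    total = np.zeros((n, n))
    for col in numeric:
        v = df[col].to_numpy(dtype=float)
        value_range = v.max() - v.min()
        # Range-scaled absolute difference, so each term lies in [0, 1]
        total += np.abs(v[:, None] - v[None, :]) / value_range
    for col in categorical:
        v = df[col].to_numpy()
        # 1 where the categories differ, 0 where they match
        total += v[:, None] != v[None, :]
    # Average over all p features
    return total / (len(numeric) + len(categorical))

passengers = pd.DataFrame({
    "Sex": ["male", "male", "female"],
    "Age": [42.0, 21.0, 24.0],
    "Fare": [26.2875, 8.05, 65.0],
})
D = gower_distances(passengers, numeric=["Age", "Fare"], categorical=["Sex"])
print(D.round(3))
```

For example, the second and third passengers differ in sex (1), in age by 3 out of a range of 21, and in fare by the full range (1), so their distance is \((1 + 1/7 + 1)/3 = 5/7 \approx 0.714\).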

In our last column of slides, we’ll apply this to a fun data set.

Kaggle’s Titanic challenge

Here’s yet another classic data set. It contains much of the passenger manifest from the doomed 1912 voyage of the Titanic, including passenger characteristics like sex, age, class, and fare, along with a column indicating whether each passenger survived.

The challenge is to predict survival based on the other characteristics. Kaggle has a learning competition for exactly this purpose.

The data

Let’s take a look at the data:

titanic_data = pd.read_csv(
  'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv')
titanic_data.sample(3, random_state=2)
PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
707 708 1 1 Calderhead, Mr. Edward Pennington male 42.0 0 0 PC 17476 26.2875 E24 S
37 38 0 3 Cann, Mr. Ernest Charles male 21.0 0 0 A./5. 2152 8.0500 NaN S
615 616 1 2 Herman, Miss. Alice female 24.0 1 2 220845 65.0000 NaN S

KNN prediction setup

Let’s set things up for a manual 1-Nearest Neighbor classification using the Gower metric:

cols = ['Pclass', 'Sex', 'Age', 'Embarked', 'Survived']
titanic_data = titanic_data[cols].dropna()
X = titanic_data.drop(columns='Survived')
y = titanic_data['Survived']
X_train, X_test, y_train, y_test = train_test_split(
  X, y, random_state=1)

The Gower matrix

The following code computes a matrix \(D\) such that \(D_{ij}\) is the Gower distance from the \(i^{\text{th}}\) test point to the \(j^{\text{th}}\) training point.

D = gower.gower_matrix(X_test, X_train)
D[:4,:4]
array([[0.60681075, 0.43782985, 0.36623523, 0.5376979 ],
       [0.7939809 , 0.375     , 0.5534054 , 0.27513194],
       [0.5659714 , 0.14699046, 0.32539585, 0.00314149],
       [0.48495224, 0.5659714 , 0.7443767 , 0.66583943]], dtype=float32)

Testing accuracy

Now we iterate through the rows of the matrix, one per test passenger. For each one, we find the closest training passenger and predict that the test passenger had the same survival outcome. The idea is that passengers who are close in the Gower sense are likely to share the same outcome.

y_pred = []
for row in D:
    i = row.argmin()
    y_pred.append(y_train.iloc[i])
print("Accuracy:", accuracy_score(y_test, y_pred))
Accuracy: 0.7528089887640449
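The loop above is 1-Nearest Neighbor. Here's a sketch generalizing it to \(k\) neighbors with a majority vote, working directly from any (n_test, n_train) distance matrix such as \(D\). The helper name knn_predict_from_distances is hypothetical; scikit-learn also offers a built-in route via KNeighborsClassifier(metric="precomputed"), where fit takes the square train-to-train distance matrix and predict takes the test-to-train matrix.

```python
import numpy as np

def knn_predict_from_distances(D, y_train, k=5):
    """Majority-vote k-NN from a precomputed (n_test, n_train) distance matrix."""
    y_train = np.asarray(y_train)
    # Column indices of the k smallest distances in each row (nearest first)
    nearest = np.argsort(D, axis=1, kind="stable")[:, :k]
    preds = []
    for idx in nearest:
        labels, counts = np.unique(y_train[idx], return_counts=True)
        # np.unique sorts the labels, so a tied vote goes to the smallest label
        preds.append(labels[counts.argmax()])
    return np.array(preds)

# A tiny hypothetical example: 2 test points, 4 training points
D_toy = np.array([[0.10, 0.20, 0.90, 0.80],
                  [0.90, 0.70, 0.10, 0.05]])
y_toy = [1, 1, 0, 0]
print(knn_predict_from_distances(D_toy, y_toy, k=3))  # [1 0]
```

With the Titanic matrix, the call would be knn_predict_from_distances(D, y_train, k=5); we haven't run that here, so no accuracy is claimed for it.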

KNN for regression

It’s worth mentioning that K Nearest Neighbors can also be used for regression, rather than classification. It’s actually easier: since we’re estimating a numeric value, we simply average the values of the nearby data points.

Suppose, for example, that we’ve got the following noisy sine wave:

# True model:
x = np.linspace(0,5,500)
y = np.sin(x)

# Noisy approximation:
X = np.sort(5 * np.random.rand(30, 1), axis=0)
Y = np.sin(X).ravel() + 0.2 * np.random.randn(30)

Applying KNN

Then, using that data, we can use KNN to make predictions on testing data:

knn_model = KNeighborsRegressor(n_neighbors=8)
knn_model.fit(X, Y)
X_test = np.linspace(0, 5, 100).reshape(-1, 1)
Y_pred = knn_model.predict(X_test)
print([X_test.shape, X.shape, Y.shape, Y_pred.shape])
[(100, 1), (30, 1), (30,), (100,)]
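The choice of \(k\) controls how smooth the regression is. Here's a sketch with the same kind of noisy sine data, using NumPy's seeded default_rng for reproducibility (an assumption; the original code is unseeded), that shows the two extremes: \(k=1\) reproduces every training value exactly, while \(k\) equal to the number of training points predicts the overall mean everywhere.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# Noisy sine data like the example above, with a seeded generator
rng = np.random.default_rng(0)
X = np.sort(5 * rng.random((30, 1)), axis=0)
Y = np.sin(X).ravel() + 0.2 * rng.standard_normal(30)
X_test = np.linspace(0, 5, 100).reshape(-1, 1)

# Fit one model per k and predict on the test grid
models = {k: KNeighborsRegressor(n_neighbors=k).fit(X, Y) for k in (1, 8, 30)}
preds = {k: m.predict(X_test) for k, m in models.items()}

# k=1: each training point is its own nearest neighbor, so the fit chases the noise
print(np.allclose(models[1].predict(X), Y))
# k=30 (all points): every prediction averages all 30 values, giving a flat line
print(np.allclose(preds[30], Y.mean()))
```

Intermediate values like \(k=8\) trade between these extremes: large enough to average away some noise, small enough to follow the shape of the sine.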

Visualization

Here’s the result:

plt.figure(figsize=(8, 4))
plt.plot(x,y, linewidth=1, label="Base model")
plt.scatter(X, Y, c="black", label="Training data")
plt.plot(X_test, Y_pred, c="blue", label=f"KNN prediction (k={knn_model.n_neighbors})", linewidth=2.5)
plt.legend()
plt.show()

Imputation

It’s also worth mentioning that we’ve seen KNN used before, though at that point we called it “imputation.” The idea was to fill in the few missing values in a data set with best guesses for those values. We did so using a KNNImputer as one step in a multi-step pipeline. It’s instructive, though, to take a close look at how that works.

Here’s how to import and construct an imputer based on K nearest neighbors:

from sklearn.impute import KNNImputer
imputer = KNNImputer(n_neighbors=3)

Some data

Here’s a bit of data, stored in a data frame named colors, with a couple of NaNs for the imputer to fill in:

r g b
0 5.0 8.0 249
1 248.0 3.0 4
2 NaN 0.0 1
3 5.0 251.0 7
4 254.0 NaN 0
5 5.0 2.0 247
6 0.0 248.0 0
7 251.0 7.0 9
8 248.0 9.0 2
9 247.0 4.0 0
10 1.0 248.0 2
11 247.0 3.0 0

Application

Finally, we apply the imputer to create complete data that can be passed down the pipeline:

imputer.fit_transform(colors)
array([[  5.        ,   8.        , 249.        ],
       [248.        ,   3.        ,   4.        ],
       [249.33333333,   0.        ,   1.        ],
       [  5.        , 251.        ,   7.        ],
       [254.        ,   4.        ,   0.        ],
       [  5.        ,   2.        , 247.        ],
       [  0.        , 248.        ,   0.        ],
       [251.        ,   7.        ,   9.        ],
       [248.        ,   9.        ,   2.        ],
       [247.        ,   4.        ,   0.        ],
       [  1.        , 248.        ,   2.        ],
       [247.        ,   3.        ,   0.        ]])
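To see where the 249.33 in row 2 comes from, here's a sketch that redoes that one computation by hand. KNNImputer's default distance is scikit-learn's nan_euclidean_distances, which ignores coordinates missing in either row and scales up the rest; only rows whose r value is present can serve as donors. The helper impute_one is hypothetical.

```python
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import nan_euclidean_distances

colors = pd.DataFrame({
    "r": [5.0, 248.0, np.nan, 5.0, 254.0, 5.0, 0.0, 251.0, 248.0, 247.0, 1.0, 247.0],
    "g": [8.0, 3.0, 0.0, 251.0, np.nan, 2.0, 248.0, 7.0, 9.0, 4.0, 248.0, 3.0],
    "b": [249, 4, 1, 7, 0, 247, 0, 9, 2, 0, 2, 0],
})

def impute_one(df, row, col, k=3):
    """Fill df.loc[row, col] with the mean of col over the k nearest donor rows."""
    X = df.to_numpy(dtype=float)
    j = df.columns.get_loc(col)
    # Distances from the target row to every row, skipping missing coordinates
    d = nan_euclidean_distances(X[[row]], X)[0]
    # Donors: other rows that actually have a value in the target column
    donors = [i for i in range(len(X)) if i != row and not np.isnan(X[i, j])]
    nearest = sorted(donors, key=lambda i: d[i])[:k]
    return X[nearest, j].mean(), nearest

value, nearest = impute_one(colors, row=2, col="r")
print(nearest, value)  # rows 4, 11, 9 -> (254 + 247 + 247) / 3
```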