Linear regression with Python

Here are some tools we can use to do a full linear regression:

In [1]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import linregress
import pandas as pd

Real world regression with a data file

Here's one of our commonly used datasets:

In [2]:
df = pd.read_csv('https://www.marksmath.org/data/cdc.csv')
df.head()
Out[2]:
   Unnamed: 0    genhlth  exerany  hlthplan  smoke100  height  weight  wtdesire  age gender
0           1       good        0         1         0      70     175       175   77      m
1           2       good        0         1         1      64     125       115   33      f
2           3       good        1         1         1      60     105       105   49      f
3           4       good        1         1         0      66     132       124   42      f
4           5  very good        0         1         0      61     150       130   55      f

Let's grab a sample from there and see how weight is related to height.

In [3]:
sam = df.sample(200, random_state=1)
sam.plot.scatter('height', 'weight')
[Figure: scatter plot of weight against height for the 200 sampled rows]

Now, let's perform a linear regression:

In [4]:
lr = linregress(sam.height, sam.weight)
lr
Out[4]:
LinregressResult(slope=5.664277180406205, intercept=-214.64014336917512, rvalue=0.5756930765891695, pvalue=4.771056232490472e-19, stderr=0.5717376182138059)
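Before interpreting these numbers, it may help to see where they come from. The slope and intercept reported by linregress are the ordinary least-squares estimates, and rvalue is Pearson's correlation coefficient. The sketch below computes all three directly from the textbook formulas and checks them against linregress, and also against np.polyfit (a separate NumPy routine that fits the same degree-1 polynomial). It uses a small hard-coded data set (the same numbers as the homework-style example) so it runs without downloading the CSV.

```python
import numpy as np
from scipy.stats import linregress

# Small hard-coded data set so the example is self-contained
x = np.array([1, 2, 3, 8], dtype=float)
y = np.array([4, 3, 6, 9], dtype=float)

# Deviations from the means
dx = x - x.mean()
dy = y - y.mean()

# Least-squares slope: sum(dx*dy) / sum(dx^2); intercept: ybar - slope*xbar
slope = np.sum(dx * dy) / np.sum(dx ** 2)
intercept = y.mean() - slope * x.mean()

# Pearson correlation coefficient
r = np.sum(dx * dy) / np.sqrt(np.sum(dx ** 2) * np.sum(dy ** 2))

# Compare with scipy and with numpy's polynomial fit ([slope, intercept])
lr = linregress(x, y)
coeffs = np.polyfit(x, y, 1)

print(slope, lr.slope, coeffs[0])          # all 23/29, about 0.7931
print(intercept, lr.intercept, coeffs[1])  # all about 2.7241
print(r, lr.rvalue)                        # both about 0.9320
```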

There are a number of things to interpret here. The first is the regression line, which is determined by the slope and intercept:

In [5]:
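# Regression line: predicted weight as a function of height
def f(x): return lr.slope*x + lr.intercept
sam.plot.scatter('height', 'weight')
# Draw the fitted line from height 55 to height 80 over the scatter plot
plt.plot([55,80], [f(55), f(80)], 'black')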
[Figure: the same scatter plot of weight against height, with the fitted regression line drawn in black from height 55 to 80]
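Two other things worth reading off the fit are the line's prediction at a particular height and the square of rvalue (the coefficient of determination, R²), which measures the fraction of the variation in weight that the line accounts for. Here's a minimal sketch that copies the slope, intercept and rvalue printed in Out[4]. The units are an assumption: the values suggest height in inches and weight in pounds, but that isn't stated above.

```python
# Values copied from Out[4] (the 200-row sample of the CDC data)
slope = 5.664277180406205
intercept = -214.64014336917512
rvalue = 0.5756930765891695

def f(x):
    """Predicted weight for a given height, using the fitted line."""
    return slope * x + intercept

# Predicted weight at height 70 (assumed inches; weight assumed in pounds)
pred_70 = f(70)

# One extra unit of height changes the prediction by exactly the slope
diff = f(71) - f(70)

# Coefficient of determination: share of weight's variance explained by height
r_squared = rvalue ** 2

print(round(pred_70, 1))    # 181.9
print(round(r_squared, 3))  # 0.331
```

So the slope reads as "about 5.66 more pounds of predicted weight per extra inch of height", while an R² of roughly 0.33 says height explains only about a third of the variation in weight.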

Simple homework-style linear regression

For homework, you might want to just enter a small data set, like so:

In [6]:
x = [1,2,3,8]
y = [4,3,6,9]
plt.plot(x,y,'bo')
[Figure: the four (x, y) points plotted as blue dots]

And do a regression:

In [7]:
lr = linregress(x, y)
lr
Out[7]:
LinregressResult(slope=0.7931034482758621, intercept=2.7241379310344827, rvalue=0.9320070332440739, pvalue=0.06799296675592614, stderr=0.21808811449437113)
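The pvalue here tests the null hypothesis that the true slope is zero. Even though r is about 0.93, there are only four points, so the p-value of about 0.068 is not below the conventional 0.05 threshold. The sketch below shows the standard way this p-value is obtained: divide the slope by its standard error to get a t statistic, then take a two-sided tail probability from a t distribution with n - 2 degrees of freedom. As far as we can tell this reproduces what linregress reports; the check at the end compares the two.

```python
from scipy.stats import linregress, t

x = [1, 2, 3, 8]
y = [4, 3, 6, 9]
lr = linregress(x, y)

n = len(x)
# t statistic for the null hypothesis "slope = 0"
t_stat = lr.slope / lr.stderr
# Two-sided p-value from a t distribution with n - 2 degrees of freedom
p = 2 * t.sf(abs(t_stat), df=n - 2)

print(t_stat)        # about 3.64
print(p, lr.pvalue)  # both about 0.068
```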

And visualize it:

In [8]:
def f(x): return lr.slope*x + lr.intercept
plt.plot(x,y, 'bo')
plt.plot([1,8], [f(1), f(8)], 'black')
[Figure: the four points as blue dots with the fitted regression line drawn in black from x = 1 to x = 8]
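The vertical gaps between the dots and the line are the residuals, and "least squares" means the line is chosen to make the sum of their squares as small as possible. A short sketch, using the same four points, that computes the residuals and recovers R² as 1 - SS_res / SS_tot; for simple linear regression this agrees with rvalue squared.

```python
import numpy as np
from scipy.stats import linregress

x = np.array([1, 2, 3, 8], dtype=float)
y = np.array([4, 3, 6, 9], dtype=float)
lr = linregress(x, y)

# Fitted values on the regression line and residuals (observed minus fitted)
fitted = lr.slope * x + lr.intercept
residuals = y - fitted

# Residual and total sums of squares
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((y - y.mean()) ** 2)
r_squared = 1 - ss_res / ss_tot

print(residuals.sum())            # essentially 0 for a least-squares line with an intercept
print(r_squared, lr.rvalue ** 2)  # both about 0.869
```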