Last time, we began to talk about how to build models worthy of our users' trust. As a refresher, we said that trustworthy models require at least three things:
- Prediction -- An estimate for some unknown value
- Confidence -- A description of how uncertain the model is about the prediction
- Explanation -- The reasoning behind the model's prediction
Today, we'll take a pass at actually implementing such a model.
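To make that concrete, here's one hypothetical shape the output of such a model might take (the class and field names are my own invention, just to fix ideas):

```python
from dataclasses import dataclass

@dataclass
class TrustworthyPrediction:
    """A hypothetical container for the three pieces above."""
    prediction: float                # the point estimate
    interval: tuple[float, float]    # e.g., bounds of a 95% prediction interval
    explanation: dict[str, float]    # per-feature contributions to the estimate
```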
For pedagogical reasons, we're using a dataset of fish sold at a fish market. Here are a few rows from the dataset:
| Species | Weight | Length1 | Length2 | Length3 | Height | Width |
|---------|--------|---------|---------|---------|---------|--------|
| Perch | 250.0 | 25.9 | 28.0 | 29.4 | 7.8204 | 4.2042 |
| Bream | 714.0 | 32.7 | 36.0 | 41.5 | 16.517 | 5.8515 |
| Perch | 145.0 | 22.0 | 24.0 | 25.5 | 6.375 | 3.825 |
| Perch | 145.0 | 20.7 | 22.7 | 24.2 | 5.9532 | 3.63 |
| Bream | 975.0 | 37.4 | 41.0 | 45.9 | 18.6354 | 6.7473 |
The first step, of course, is to load it up!
```python
import os

import pandas as pd

# Load the fish market dataset (downloaded ahead of time).
fish = pd.read_csv(os.path.expanduser("~/Downloads/Fish.csv"))
```
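A quick sanity check never hurts. If the load worked, we should see 159 rows and seven species (these expectations are inferred from the model output further down; adjust if your copy of the CSV differs):

```python
print(fish.shape)  # expect (159, 7)
print(sorted(fish["Species"].unique()))
# expect: ['Bream', 'Parkki', 'Perch', 'Pike', 'Roach', 'Smelt', 'Whitefish']
```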
For our exercise today, let's see if we can predict `Weight` given the values of the other columns. We're going to use `statsmodels` to build a simple linear model.
```python
import statsmodels.formula.api as smf

# Regress weight on species (one-hot encoded) and the size measurements.
model = smf.ols(
    formula="Weight ~ C(Species) + Length1 + Length2 + Length3 + Height + Width",
    data=fish,
).fit()
```
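Before unpacking that formula: the `C(Species)` term tells `statsmodels` to treat `Species` as categorical. Here's a rough, hand-rolled sketch of the encoding it builds (for intuition only; patsy does this for you under the hood):

```python
# One-hot encode Species, dropping the first (alphabetical) level, Bream,
# which becomes the baseline absorbed into the intercept.
dummies = pd.get_dummies(fish["Species"], prefix="Species", drop_first=True)
design = pd.concat(
    [fish[["Length1", "Length2", "Length3", "Height", "Width"]], dummies],
    axis=1,
)
print(design.columns.tolist())
```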
If you've never used `statsmodels` before, think of this as fitting a linear model with `Species` one-hot encoded, much as sketched above. `statsmodels` has a nice way of getting basic information about the model:
```python
model.summary()
```
```
                            OLS Regression Results
==============================================================================
Dep. Variable:                 Weight   R-squared:                       0.936
Model:                            OLS   Adj. R-squared:                  0.931
Method:                 Least Squares   F-statistic:                     195.7
Date:                Sun, 14 Jun 2020   Prob (F-statistic):           6.85e-82
Time:                        15:00:23   Log-Likelihood:                -941.46
No. Observations:                 159   AIC:                             1907.
Df Residuals:                     147   BIC:                             1944.
Df Model:                          11
Covariance Type:            nonrobust
===========================================================================================
                              coef    std err          t      P>|t|      [0.025      0.975]
-------------------------------------------------------------------------------------------
Intercept                -918.3321    127.083     -7.226      0.000   -1169.478    -667.186
C(Species)[T.Parkki]      164.7227     75.699      2.176      0.031      15.123     314.322
C(Species)[T.Perch]       137.9489    120.314      1.147      0.253     -99.819     375.717
C(Species)[T.Pike]       -208.4294    135.306     -1.540      0.126    -475.826      58.968
C(Species)[T.Roach]       103.0400     91.308      1.128      0.261     -77.407     283.487
C(Species)[T.Smelt]       446.0733    119.430      3.735      0.000     210.051     682.095
C(Species)[T.Whitefish]    93.8742     96.658      0.971      0.333     -97.145     284.893
Length1                   -80.3030     36.279     -2.214      0.028    -151.998      -8.608
Length2                    79.8886     45.718      1.747      0.083     -10.461     170.238
Length3                    32.5354     29.300      1.110      0.269     -25.369      90.439
Height                      5.2510     13.056      0.402      0.688     -20.551      31.053
Width                      -0.5154     23.913     -0.022      0.983     -47.773      46.742
==============================================================================
Omnibus:                       43.558   Durbin-Watson:                   0.973
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               97.422
Skew:                           1.184   Prob(JB):                     7.00e-22
Kurtosis:                       6.016   Cond. No.                     2.03e+03
==============================================================================

Warnings:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
[2] The condition number is large, 2.03e+03. This might indicate that there are
strong multicollinearity or other numerical problems.
```
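That condition-number warning is worth taking seriously: the three length measurements of the same fish are nearly collinear. One way to quantify this is with variance inflation factors (a sketch using `variance_inflation_factor` from `statsmodels.stats.outliers_influence`):

```python
from statsmodels.stats.outliers_influence import variance_inflation_factor

# A rough look at multicollinearity: VIFs well above ~10 are a red flag.
# (The Intercept row can be ignored.)
exog = model.model.exog
for i, name in enumerate(model.model.exog_names):
    print(f"{name}: {variance_inflation_factor(exog, i):.1f}")
```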
At this point, we can achieve our first objective: to provide a prediction!
```python
# Describe a new fish. (Weight is what we're predicting; the -1 is just a
# placeholder and is ignored by model.predict.)
new_fish = pd.DataFrame(
    [
        {
            "Species": "Bream",
            "Weight": -1,
            "Length1": 31.3,
            "Length2": 34,
            "Length3": 39.5,
            "Height": 15.1285,
            "Width": 5.5695,
        }
    ]
)

model.predict(new_fish)
```
The model predicts that this fish weighs about 646 grams.
The main reason I've chosen to use `statsmodels` (rather than scikit-learn) is that it provides built-in support for prediction intervals. Take a look:
```python
frame = model.get_prediction(new_fish).summary_frame(alpha=0.95)
frame.round(2)
```
| mean | mean_se | mean_ci_lower | mean_ci_upper | obs_ci_lower | obs_ci_upper |
|--------|---------|---------------|---------------|--------------|--------------|
| 646.12 | 18.32 | 644.96 | 647.27 | 640.11 | 652.12 |
`mean` here is the prediction, and `obs_ci_lower` and `obs_ci_upper` bound a prediction interval for an individual observation. One caveat: in `summary_frame`, `alpha` is a significance level (coverage is 1 - alpha), so the `alpha=0.95` call above actually produces a very narrow 5% interval. That's why the bounds of 640 and 652 grams hug the prediction so tightly.
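For a conventional 95% prediction interval, pass `alpha=0.05` instead; the resulting bounds will be much wider than the 5% interval above:

```python
# alpha is the significance level: coverage = 1 - alpha,
# so alpha=0.05 yields a 95% prediction interval.
frame95 = model.get_prediction(new_fish).summary_frame(alpha=0.05)
print(frame95[["obs_ci_lower", "obs_ci_upper"]].round(2))
```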
We're two thirds of the way there!
We can use the structure of the model to provide an explanation. Since Bream is the baseline species (its effect is absorbed into the intercept), the prediction for our new fish is:

```
  -918.3        (the intercept)
- 80.3 * 31.3   (Length1)
+ 79.9 * 34     (Length2)
+ 32.5 * 39.5   (Length3)
+  5.3 * 15.1   (Height)
-  0.5 * 5.6    (Width)
-------------
≈ 646.12
```
Here's one way to display how the various features contribute to the overall prediction:
```python
def fish_to_feats(a_fish, model):
    """Convert a raw row into the model's feature space."""
    feats = a_fish.copy()
    feats["Intercept"] = 1.0
    # One-hot encode Species to match the fitted C(Species)[T.*] columns.
    for species_feat in model.params.index:
        if not species_feat.startswith("C(Species)"):
            continue
        species = species_feat.split(".")[1].replace("]", "")  # This is ugly
        feats[species_feat] = (feats["Species"] == species).astype(int)
    del feats["Species"]
    return feats[model.params.index]

# Each feature's contribution is its value times its coefficient.
contributions = fish_to_feats(new_fish, model) * model.params

for name, amount in sorted(
    contributions.round(1).items(), key=lambda t: -t[1].abs()[0]
):
    # Skip contributions that are essentially zero.
    if -1e-3 < amount[0] < 1e-3:
        continue
    print(f"{name}: {amount[0]}")
```
This produces the following output:
```
Length2: 2716.2
Length1: -2513.5
Length3: 1285.1
Intercept: -918.3
Height: 79.4
Width: -2.9
```
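As a quick sanity check, these contributions should sum back to the model's prediction:

```python
# The dot product of feature values and coefficients is exactly the prediction.
print(contributions.sum(axis=1).round(2))  # 646.12
```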
This could certainly be made more user-friendly, but it does give some kind of explanation for why the model believes this fish weighs 646 grams: the biggest contributors are `Length2` (pushes the prediction higher), `Length1` (pushes it lower), and `Length3` (pushes it higher).

And with that, we've built a model that can provide trustworthy predictions: a point estimate, a measure of confidence, and an explanation.

I highly recommend attacking machine learning problems by starting with an incredibly simple model. Implementing that end-to-end lets you focus on the truly difficult parts of machine learning (i.e., not the ML bits). For some use cases, this post provides yet another reason to love linear models: they are trustworthy by default!
Comments? Questions? Concerns? Please tweet me @SamuelDataT or email me. Thanks!