Setup: This Notebook Contains All The Sample Code and Solutions To The Exercises in Chapter 3
3 – Classification
This notebook contains all the sample code and solutions to the exercises in chapter 3.
Setup
First, let's import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures. We also check
that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as
well as Scikit-Learn ≥0.20.
# Common imports
import numpy as np
import os
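The setup described above mentions version checks and a figure-saving helper, and save_fig is called throughout this notebook, but neither appears in the cell. Here is a minimal sketch. It assumes figures go to a hypothetical images/classification folder and that the signature save_fig(fig_id, tight_layout=True, ...) fits the calls made later. The Scikit-Learn version check is left out.

```python
import os
import sys

import matplotlib.pyplot as plt

assert sys.version_info >= (3, 5), "Python 3.5 or later is required"

# Hypothetical output folder for the figures; adjust to taste
IMAGES_PATH = os.path.join(".", "images", "classification")

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300,
             images_path=IMAGES_PATH):
    """Save the current Matplotlib figure as <images_path>/<fig_id>.<fig_extension>."""
    os.makedirs(images_path, exist_ok=True)
    path = os.path.join(images_path, fig_id + "." + fig_extension)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format=fig_extension, dpi=resolution)
    return path
```

Returning the path is a small convenience of this sketch. The calls in the notebook simply ignore the return value.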
MNIST
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False)  # NumPy arrays rather than a DataFrame (as_frame needs Scikit-Learn ≥0.22)
mnist.keys()
X, y = mnist["data"], mnist["target"]
X.shape
(70000, 784)
y.shape
(70000,)
28 * 28
784
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
some_digit = X[0]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap=mpl.cm.binary)
plt.axis("off")
save_fig("some_digit_plot")
plt.show()
y[0]
'5'
y = y.astype(np.uint8)
def plot_digit(data):
    image = data.reshape(28, 28)
    plt.imshow(image, cmap = mpl.cm.binary,
               interpolation="nearest")
    plt.axis("off")
# EXTRA
def plot_digits(instances, images_per_row=10, **options):
    size = 28
    images_per_row = min(len(instances), images_per_row)
    images = [instance.reshape(size,size) for instance in instances]
    n_rows = (len(instances) - 1) // images_per_row + 1
    row_images = []
    n_empty = n_rows * images_per_row - len(instances)
    images.append(np.zeros((size, size * n_empty)))
    for row in range(n_rows):
        rimages = images[row * images_per_row : (row + 1) * images_per_row]
        row_images.append(np.concatenate(rimages, axis=1))
    image = np.concatenate(row_images, axis=0)
    plt.imshow(image, cmap = mpl.cm.binary, **options)
    plt.axis("off")
plt.figure(figsize=(9,9))
example_images = X[:100]
plot_digits(example_images, images_per_row=10)
save_fig("more_digits_plot")
plt.show()
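y[0]
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]  # MNIST is pre-split: first 60,000 images for training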
Binary classifier
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
Note: some hyperparameters will have a different default value in future versions of Scikit-Learn, such as max_iter and tol. To be
future-proof, we explicitly set these hyperparameters to their future default values. For simplicity, this is not shown in the book.
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)
sgd_clf.fit(X_train, y_train_5)
SGDClassifier(random_state=42)
sgd_clf.predict([some_digit])
array([ True])
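from sklearn.base import clone
from sklearn.model_selection import StratifiedKFold
skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)  # fold settings assumed
for train_index, test_index in skfolds.split(X_train, y_train_5):
    clone_clf = clone(sgd_clf)  # fresh, unfitted copy of the classifier for each fold
    clone_clf.fit(X_train[train_index], y_train_5[train_index])
    y_pred = clone_clf.predict(X_train[test_index])
    print(np.mean(y_pred == y_train_5[test_index]))  # accuracy on this fold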
0.9669
0.91625
0.96785
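The manual loop relies on StratifiedKFold, which gives every fold roughly the same ratio of 5s to non-5s as the whole training set. Here is a plain NumPy sketch of that idea. stratified_kfold_indices is a hypothetical helper, not Scikit-Learn's implementation, which assigns samples to folds differently.

```python
import numpy as np

def stratified_kfold_indices(y, n_splits=3, seed=42):
    """Yield (train_idx, test_idx) pairs whose test folds preserve the class proportions of y."""
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    fold_of = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        cls_idx = rng.permutation(np.flatnonzero(y == cls))
        # Deal this class's samples round-robin across the folds
        fold_of[cls_idx] = np.arange(len(cls_idx)) % n_splits
    for k in range(n_splits):
        yield np.flatnonzero(fold_of != k), np.flatnonzero(fold_of == k)

# Toy labels with 10% positives, like "is this a 5?"
y_toy = np.array([True] * 30 + [False] * 270)
for train_idx, test_idx in stratified_kfold_indices(y_toy):
    print(len(test_idx), y_toy[test_idx].mean())  # each fold: 100 samples, 10% positives
```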
from sklearn.model_selection import cross_val_score
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")
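Never5Classifier is used above without a definition. The sketch below assumes it simply answers "not a 5" for every image. The original may differ in details such as output shape. With roughly 10% 5s, this dumb baseline is already right about 90% of the time, which is why accuracy is misleading on skewed classes.

```python
import numpy as np

class Never5Classifier:
    """Dumb baseline: claims that no image is a 5."""
    def fit(self, X, y=None):
        return self  # nothing to learn

    def predict(self, X):
        return np.zeros(len(X), dtype=bool)

# Toy labels: 10% of the "images" are 5s, roughly like MNIST
y_is_5 = np.array([True] * 1000 + [False] * 9000)
X_dummy = np.zeros((len(y_is_5), 1))  # features are ignored anyway
accuracy = np.mean(Never5Classifier().fit(X_dummy).predict(X_dummy) == y_is_5)
print(accuracy)  # 0.9
```

To pass such a class to cross_val_score, it should also inherit from sklearn.base.BaseEstimator. Cross-validation clones the estimator, and cloning relies on get_params().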
Warning: this output (and many others in this notebook and other notebooks) may differ slightly from those in the book. Don't worry, that's
okay! There are several reasons for this:
first, Scikit-Learn and other libraries evolve, and algorithms get tweaked a bit, which may change the exact result you get. If you use the
latest Scikit-Learn version (and in general, you really should), you probably won't be using the exact same version I used when I wrote
the book or this notebook, hence the difference. I try to keep this notebook reasonably up to date, but I can't change the numbers on the
pages in your copy of the book.
second, many training algorithms are stochastic, meaning they rely on randomness. In principle, it's possible to get consistent outputs
from a random number generator by setting the seed from which it generates the pseudo-random numbers (which is why you will see
random_state=42 or np.random.seed(42) pretty often). However, sometimes this does not suffice due to the other factors listed
here.
third, if the training algorithm runs across multiple threads (as do some algorithms implemented in C) or across multiple processes (e.g.,
when using the n_jobs argument), then the precise order in which operations will run is not always guaranteed, and thus the exact
result may vary slightly.
lastly, other things may prevent perfect reproducibility, such as Python sets, whose iteration order is not guaranteed to be stable across
sessions, or the order of files in a directory, which is also not guaranteed.
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
confusion_matrix(y_train_5, y_train_pred)
array([[53892, 687],
[ 1891, 3530]])
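y_train_perfect_predictions = y_train_5  # pretend we reached perfection
confusion_matrix(y_train_5, y_train_perfect_predictions)
array([[54579, 0],
[ 0, 5421]])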
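Each row of a confusion matrix is an actual class and each column is a predicted class. For the binary case, the layout is therefore [[TN, FP], [FN, TP]], and a perfect classifier leaves only the diagonal. Here is a NumPy sketch of that bookkeeping. binary_confusion_matrix is a hypothetical helper; for boolean labels it should agree with confusion_matrix.

```python
import numpy as np

def binary_confusion_matrix(y_true, y_pred):
    """Rows: actual False/True; columns: predicted False/True -> [[TN, FP], [FN, TP]]."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tn = np.sum(~y_true & ~y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    tp = np.sum(y_true & y_pred)
    return np.array([[tn, fp], [fn, tp]])

print(binary_confusion_matrix([0, 0, 1, 1, 1], [0, 1, 0, 1, 1]))
print(binary_confusion_matrix([0, 1, 1], [0, 1, 1]))  # perfect predictions: diagonal only
```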
from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_pred)
0.8370879772350012
cm = confusion_matrix(y_train_5, y_train_pred)
cm[1, 1] / (cm[0, 1] + cm[1, 1])
0.8370879772350012
recall_score(y_train_5, y_train_pred)
0.6511713705958311
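cm[1, 1] / (cm[1, 0] + cm[1, 1])
0.6511713705958311
from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_pred)
0.7325171197343846
cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)
0.7325171197343847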
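All three scores follow directly from the four counts in the SGD confusion matrix shown earlier ([[53892, 687], [1891, 3530]]). Working them out by hand reproduces the values printed above. It also shows that F1, the harmonic mean of precision and recall, equals TP / (TP + (FN + FP) / 2).

```python
# Counts from the SGD confusion matrix: [[TN, FP], [FN, TP]]
tn, fp, fn, tp = 53892, 687, 1891, 3530

precision = tp / (tp + fp)          # of the images flagged as 5s, how many really are 5s
recall = tp / (tp + fn)             # of the actual 5s, how many were flagged
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two
f1_alt = tp / (tp + (fn + fp) / 2)  # equivalent form

print(round(precision, 4), round(recall, 4), round(f1, 4))  # 0.8371 0.6512 0.7325
```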
y_scores = sgd_clf.decision_function([some_digit])
y_scores
array([2164.22030239])
threshold = 0
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([ True])
threshold = 8000
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
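array([False])
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method="decision_function")
(y_train_pred == (y_scores > 0)).all()
True
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]  # lowest threshold with ≥90% precision
threshold_90_precision
3370.0194991439557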
y_train_pred_90 = (y_scores >= threshold_90_precision)
precision_score(y_train_5, y_train_pred_90)
0.9000345901072293
recall_score(y_train_5, y_train_pred_90)
0.4799852425751706
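threshold_90_precision is the lowest decision threshold whose precision reaches 90%. The same search can be written directly, without precision_recall_curve. Below is a NumPy sketch on made-up scores; lowest_threshold_for_precision is a hypothetical helper. Precision is not guaranteed to increase with the threshold. In the demo, raising the threshold from 3 to 4 drops precision from 0.80 to 0.75.

```python
import numpy as np

def lowest_threshold_for_precision(y_true, scores, target_precision=0.90):
    """Return the lowest threshold t such that predicting (scores >= t) reaches
    target_precision, or None if no threshold does. O(n^2): fine for a demo."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    for t in np.unique(scores):  # np.unique returns the scores sorted in ascending order
        y_pred = scores >= t
        precision = np.sum(y_pred & y_true) / np.sum(y_pred)
        if precision >= target_precision:
            return t
    return None

y_demo = [0, 0, 1, 0, 1, 1, 1]
scores_demo = [1, 2, 3, 4, 5, 6, 7]
print(lowest_threshold_for_precision(y_demo, scores_demo))        # 5.0 (precision 3/3)
print(lowest_threshold_for_precision(y_demo, scores_demo, 0.80))  # 3.0 (precision 4/5)
```

As with the MNIST numbers above, the higher threshold costs recall: at t = 5 only 3 of the 4 positives are caught.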
ROC curves
from sklearn.metrics import roc_curve, roc_auc_score
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
roc_auc_score(y_train_5, y_scores)
0.9604938554008616
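roc_auc_score measures the area under the ROC curve. That area has a handy probabilistic reading: it equals the probability that a randomly chosen positive instance gets a higher score than a randomly chosen negative one, with ties counting as half. Here is a brute-force NumPy sketch of that definition. roc_auc is a hypothetical helper, and the four-instance example is made up.

```python
import numpy as np

def roc_auc(y_true, scores):
    """Probability that a random positive outscores a random negative (ties count half).
    Compares all positive/negative pairs, so only practical for small arrays."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    greater = np.sum(pos[:, None] > neg[None, :])
    ties = np.sum(pos[:, None] == neg[None, :])
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

A perfect ranking scores 1.0, and a constant score gives 0.5. This is why a purely random classifier's ROC AUC is about 0.5.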
Note: we set n_estimators=100 to be future-proof since this will be the default value in Scikit-Learn 0.22.
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, "b:", linewidth=2, label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:")
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:")
plt.plot([fpr_90], [recall_90_precision], "ro")
plt.plot([fpr_90, fpr_90], [0., recall_for_forest], "r:")
plt.plot([fpr_90], [recall_for_forest], "ro")
plt.grid(True)
plt.legend(loc="lower right", fontsize=16)
save_fig("roc_curve_comparison_plot")
plt.show()
roc_auc_score(y_train_5, y_scores_forest)
0.9983436731328145
precision_score(y_train_5, y_train_pred_forest)
0.9905083315756169
recall_score(y_train_5, y_train_pred_forest)
0.8662608374838591
Multiclass classification
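from sklearn.svm import SVC
svm_clf = SVC(gamma="auto", random_state=42)
svm_clf.fit(X_train[:1000], y_train[:1000])  # y_train, not y_train_5; small subset for speed
svm_clf.predict([some_digit])
array([5], dtype=uint8)
some_digit_scores = svm_clf.decision_function([some_digit])
some_digit_scores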
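np.argmax(some_digit_scores)
svm_clf.classes_
from sklearn.multiclass import OneVsRestClassifier
ovr_clf = OneVsRestClassifier(SVC(gamma="auto", random_state=42))
ovr_clf.fit(X_train[:1000], y_train[:1000])
ovr_clf.predict([some_digit])
array([5], dtype=uint8)
len(ovr_clf.estimators_)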
10
sgd_clf.fit(X_train, y_train)
sgd_clf.predict([some_digit])
array([3], dtype=uint8)
sgd_clf.decision_function([some_digit])
plt.matshow(conf_mx, cmap=plt.cm.gray)
save_fig("confusion_matrix_plot", tight_layout=False)
plt.show()
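row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums  # error rates rather than absolute counts
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)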
save_fig("confusion_matrix_errors_plot", tight_layout=False)
plt.show()
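Dividing each row of the confusion matrix by the number of images in that class turns counts into error rates, so abundant classes do not look artificially bad. Zeroing the diagonal then leaves only the errors. Here is a NumPy sketch on a made-up 3-class matrix; error_rate_matrix is a hypothetical helper.

```python
import numpy as np

def error_rate_matrix(conf_mx):
    """Divide each row by its class size, then zero the diagonal to keep only the errors."""
    conf_mx = np.asarray(conf_mx, dtype=float)
    norm = conf_mx / conf_mx.sum(axis=1, keepdims=True)
    np.fill_diagonal(norm, 0)
    return norm

# Made-up 3-class confusion matrix (rows = actual, columns = predicted)
demo = [[50, 3, 7],
        [2, 90, 8],
        [10, 0, 30]]
errors = error_rate_matrix(demo)
worst = tuple(int(i) for i in np.unravel_index(np.argmax(errors), errors.shape))
print(worst)  # (2, 0): class 2 is most often mistaken for class 0 (25% of its instances)
```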
cl_a, cl_b = 3, 5
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]
plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
save_fig("error_analysis_digits_plot")
plt.show()
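from sklearn.neighbors import KNeighborsClassifier
y_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
KNeighborsClassifier()
knn_clf.predict([some_digit])
array([[False, True]])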
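Warning: the following cell may take a very long time (possibly hours depending on your hardware).
y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)
f1_score(y_multilabel, y_train_knn_pred, average="macro")
0.976410265560605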
Multioutput classification
noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test
some_index = 0
plt.subplot(121); plot_digit(X_test_mod[some_index])
plt.subplot(122); plot_digit(y_test_mod[some_index])
save_fig("noisy_digit_example_plot")
plt.show()
Extra material
KNN classifier
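knn_clf = KNeighborsClassifier(weights='distance', n_neighbors=4)
knn_clf.fit(X_train, y_train)
KNeighborsClassifier(n_neighbors=4, weights='distance')
y_knn_pred = knn_clf.predict(X_test)
accuracy_score(y_test, y_knn_pred)
0.9714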
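shift_digit is used below without a definition. Here is a NumPy-only sketch with the same call signature. It assumes dx shifts right, dy shifts down, and vacated pixels take the value new; it handles whole-pixel shifts only, which is how it is called here.

```python
import numpy as np

def shift_digit(digit_array, dx, dy, new=0):
    """Shift a flattened 28x28 image by dx columns (right if dx > 0) and dy rows
    (down if dy > 0), filling vacated pixels with `new`. Assumes |dx|, |dy| <= 28."""
    image = np.asarray(digit_array).reshape(28, 28)
    shifted = np.full_like(image, new)
    # Overlapping region: where pixels come from (src) and where they land (dst)
    src_rows = slice(max(0, -dy), 28 - max(0, dy))
    dst_rows = slice(max(0, dy), 28 - max(0, -dy))
    src_cols = slice(max(0, -dx), 28 - max(0, dx))
    dst_cols = slice(max(0, dx), 28 - max(0, -dx))
    shifted[dst_rows, dst_cols] = image[src_rows, src_cols]
    return shifted.reshape(784)

demo = np.arange(784)
moved = shift_digit(demo, 5, 1, new=-1).reshape(28, 28)
print(moved[0, :3], moved[1, 5], moved[27, 27])  # [-1 -1 -1] 0 750
```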
plot_digit(shift_digit(some_digit, 5, 1, new=100))
X_train_expanded = [X_train]
y_train_expanded = [y_train]
for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
    shifted_images = np.apply_along_axis(shift_digit, axis=1, arr=X_train, dx=dx, dy=dy)
    X_train_expanded.append(shifted_images)
    y_train_expanded.append(y_train)
X_train_expanded = np.concatenate(X_train_expanded)
y_train_expanded = np.concatenate(y_train_expanded)
X_train_expanded.shape, y_train_expanded.shape
knn_clf.fit(X_train_expanded, y_train_expanded)
KNeighborsClassifier(n_neighbors=4, weights='distance')
y_knn_expanded_pred = knn_clf.predict(X_test)
accuracy_score(y_test, y_knn_expanded_pred)
0.9763
ambiguous_digit = X_test[2589]
knn_clf.predict_proba([ambiguous_digit])
array([[0.24579675, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.75420325]])
plot_digit(ambiguous_digit)
Exercise solutions
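from sklearn.model_selection import GridSearchCV
param_grid = [{'n_neighbors': [3, 4, 5], 'weights': ['uniform', 'distance']}]
knn_clf = KNeighborsClassifier()
grid_search = GridSearchCV(knn_clf, param_grid, cv=5, verbose=3)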
grid_search.fit(X_train, y_train)
GridSearchCV(cv=5, estimator=KNeighborsClassifier(),
param_grid=[{'n_neighbors': [3, 4, 5],
'weights': ['uniform', 'distance']}],
verbose=3)
grid_search.best_params_
{'n_neighbors': 4, 'weights': 'distance'}
grid_search.best_score_
0.9716166666666666
y_pred = grid_search.predict(X_test)
accuracy_score(y_test, y_pred)
0.9714
2. Data Augmentation
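from scipy.ndimage import shift
def shift_image(image, dx, dy):
    return shift(image.reshape((28, 28)), [dy, dx], cval=0, mode="constant").reshape([-1])
image = X_train[1000]
shifted_image_down = shift_image(image, 0, 5)
shifted_image_left = shift_image(image, -5, 0)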
plt.figure(figsize=(12,3))
plt.subplot(131)
plt.title("Original", fontsize=14)
plt.imshow(image.reshape(28, 28), interpolation="nearest", cmap="Greys")
plt.subplot(132)
plt.title("Shifted down", fontsize=14)
plt.imshow(shifted_image_down.reshape(28, 28), interpolation="nearest", cmap="Greys")
plt.subplot(133)
plt.title("Shifted left", fontsize=14)
plt.imshow(shifted_image_left.reshape(28, 28), interpolation="nearest", cmap="Greys")
plt.show()
X_train_augmented = [image for image in X_train]
y_train_augmented = [label for label in y_train]
for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
    for image, label in zip(X_train, y_train):
        X_train_augmented.append(shift_image(image, dx, dy))
        y_train_augmented.append(label)
X_train_augmented = np.array(X_train_augmented)
y_train_augmented = np.array(y_train_augmented)
shuffle_idx = np.random.permutation(len(X_train_augmented))
X_train_augmented = X_train_augmented[shuffle_idx]
y_train_augmented = y_train_augmented[shuffle_idx]
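The augmentation loop above shifts one image at a time in Python. Every image is shifted by the same whole-pixel amount, so the shifts can also be applied to the whole array at once with NumPy slicing. Here is a sketch under that assumption. shift_batch and augment_with_shifts are hypothetical helpers, and vacated pixels are filled with 0.

```python
import numpy as np

def shift_batch(images, dx, dy, fill=0):
    """Shift every flattened 28x28 image in an (n, 784) array by dx columns and dy rows."""
    imgs = np.asarray(images).reshape(-1, 28, 28)
    out = np.full_like(imgs, fill)
    src_r, dst_r = slice(max(0, -dy), 28 - max(0, dy)), slice(max(0, dy), 28 - max(0, -dy))
    src_c, dst_c = slice(max(0, -dx), 28 - max(0, dx)), slice(max(0, dx), 28 - max(0, -dx))
    out[:, dst_r, dst_c] = imgs[:, src_r, src_c]
    return out.reshape(len(imgs), 784)

def augment_with_shifts(X, y, shifts=((1, 0), (-1, 0), (0, 1), (0, -1)), seed=42):
    """Return X and y plus one shifted copy per (dx, dy), shuffled with a shared permutation."""
    X_aug = np.concatenate([X] + [shift_batch(X, dx, dy) for dx, dy in shifts])
    y_aug = np.concatenate([y] * (len(shifts) + 1))
    perm = np.random.default_rng(seed).permutation(len(X_aug))
    return X_aug[perm], y_aug[perm]  # same permutation keeps images and labels aligned
```

On the full MNIST training set, this holds five 60,000 × 784 arrays in memory at once, trading memory for speed.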
knn_clf = KNeighborsClassifier(**grid_search.best_params_)
knn_clf.fit(X_train_augmented, y_train_augmented)
KNeighborsClassifier(n_neighbors=4, weights='distance')
y_pred = knn_clf.predict(X_test)
accuracy_score(y_test, y_pred)
0.9763
First, log in to Kaggle and go to the Titanic challenge to download train.csv and test.csv. Save them to the datasets/titanic
directory.
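import os
import pandas as pd
TITANIC_PATH = os.path.join("datasets", "titanic")
def load_titanic_data(filename, titanic_path=TITANIC_PATH):
    return pd.read_csv(os.path.join(titanic_path, filename))
train_data = load_titanic_data("train.csv")
test_data = load_titanic_data("test.csv")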
The data is already split into a training set and a test set. However, the test data does not contain the labels: your goal is to train the best
model you can using the training data, then make your predictions on the test data and upload them to Kaggle to see your final score.
Let's take a peek at the top few rows of the training set:
train_data.head()
   PassengerId  Survived  Pclass  Name                     Sex     Age   SibSp  Parch  Ticket            Fare    Cabin  Embarked
0  1            0         3       Braund, Mr. Owen Harris  male    22.0  1      0      A/5 21171         7.2500  NaN    S
2  3            1         3       Heikkinen, Miss. Laina   female  26.0  0      0      STON/O2. 3101282  7.9250  NaN    S
Survived: that's the target, 0 means the passenger did not survive, while 1 means he/she survived.
Pclass: passenger class.
Name, Sex, Age: self-explanatory
SibSp: how many siblings & spouses of the passenger aboard the Titanic.
Parch: how many children & parents of the passenger aboard the Titanic.
Ticket: ticket id
Fare: price paid (in pounds)
Cabin: passenger's cabin number
Embarked: where the passenger boarded the Titanic
train_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 PassengerId 891 non-null int64
1 Survived 891 non-null int64
2 Pclass 891 non-null int64
3 Name 891 non-null object
4 Sex 891 non-null object
5 Age 714 non-null float64
6 SibSp 891 non-null int64
7 Parch 891 non-null int64
8 Ticket 891 non-null object
9 Fare 891 non-null float64
10 Cabin 204 non-null object
11 Embarked 889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
Okay, the Age, Cabin and Embarked attributes are sometimes null (fewer than 891 non-null values), especially the Cabin (77% are null). We will
ignore the Cabin for now and focus on the rest. The Age attribute has about 20% null values (177 out of 891), so we will need to decide what
to do with them. Replacing null values with the median age seems reasonable.
The Name and Ticket attributes may have some value, but they will be a bit tricky to convert into useful numbers that a model can consume.
So for now, we will ignore them.
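Median imputation, which SimpleImputer(strategy="median") performs in num_pipeline, learns one number per column on the training set and reuses it for any later data. Here is a stdlib-only sketch of that fit/transform split for a single column. fit_median and fill_missing are hypothetical helpers, and the ages are made up.

```python
import math
from statistics import median

def fit_median(values):
    """Median of the known values (None and NaN are treated as missing)."""
    return median(v for v in values if v is not None and not math.isnan(v))

def fill_missing(values, fill_value):
    """Replace None/NaN entries with fill_value, keeping the others unchanged."""
    return [fill_value if v is None or math.isnan(v) else v for v in values]

train_ages = [22.0, None, 26.0, float("nan"), 35.0]
age_median = fit_median(train_ages)            # learned on the training set only
print(fill_missing(train_ages, age_median))    # [22.0, 26.0, 26.0, 26.0, 35.0]
print(fill_missing([None, 40.0], age_median))  # test data reuses the training median: [26.0, 40.0]
```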
train_data.describe()
PassengerId Survived Pclass Age SibSp Parch Fare
Yikes, only 38% Survived. :( That's close enough to 40%, so accuracy will be a reasonable metric to evaluate our model.
The mean Fare was £32.20, which does not seem so expensive (but it was probably a lot of money back then).
The mean Age was less than 30 years old.
train_data["Survived"].value_counts()
0 549
1 342
Name: Survived, dtype: int64
train_data["Pclass"].value_counts()
3 491
1 216
2 184
Name: Pclass, dtype: int64
train_data["Sex"].value_counts()
male 577
female 314
Name: Sex, dtype: int64
train_data["Embarked"].value_counts()
S 644
C 168
Q 77
Name: Embarked, dtype: int64
The Embarked attribute tells us where the passenger embarked: C=Cherbourg, Q=Queenstown, S=Southampton.
Note: the code below uses a mix of Pipeline, FeatureUnion and a custom DataFrameSelector to preprocess some columns
differently. Since Scikit-Learn 0.20, it is preferable to use a ColumnTransformer, like in the previous chapter.
Now let's build our preprocessing pipelines. We will reuse the DataFrameSelector we built in the previous chapter to select specific
attributes from the DataFrame:
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
num_pipeline = Pipeline([
    ("select_numeric", DataFrameSelector(["Age", "SibSp", "Parch", "Fare"])),
    ("imputer", SimpleImputer(strategy="median")),
])
num_pipeline.fit_transform(train_data)
array([[22. , 1. , 0. , 7.25 ],
[38. , 1. , 0. , 71.2833],
[26. , 0. , 0. , 7.925 ],
...,
[28. , 1. , 2. , 23.45 ],
[26. , 0. , 0. , 30. ],
[32. , 0. , 0. , 7.75 ]])
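We will also need an imputer for the string categorical columns. Here we use a custom MostFrequentImputer, which replaces missing values with
each column's most frequent value (SimpleImputer(strategy="most_frequent") also accepts string columns since Scikit-Learn 0.20):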
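from sklearn.preprocessing import OneHotEncoder
cat_pipeline = Pipeline([
    ("select_cat", DataFrameSelector(["Pclass", "Sex", "Embarked"])),
    ("imputer", MostFrequentImputer()),
    ("cat_encoder", OneHotEncoder(sparse=False)),  # Scikit-Learn ≥1.2: use sparse_output=False instead
])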
cat_pipeline.fit_transform(train_data)
Cool! Now we have a nice preprocessing pipeline that takes the raw data and outputs numerical input features that we can feed to any
Machine Learning model we want.
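from sklearn.pipeline import FeatureUnion
preprocess_pipeline = FeatureUnion(transformer_list=[("num_pipeline", num_pipeline), ("cat_pipeline", cat_pipeline)])
X_train = preprocess_pipeline.fit_transform(train_data)
X_train
y_train = train_data["Survived"]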
svm_clf = SVC(gamma="auto")
svm_clf.fit(X_train, y_train)
SVC(gamma='auto')
Great, our model is trained, let's use it to make predictions on the test set:
X_test = preprocess_pipeline.transform(test_data)
y_pred = svm_clf.predict(X_test)
And now we could just build a CSV file with these predictions (respecting the format expected by Kaggle), then upload it and hope for the
best. But wait! We can do better than hope. Why don't we use cross-validation to get an idea of how good our model is?
from sklearn.model_selection import cross_val_score
svm_scores = cross_val_score(svm_clf, X_train, y_train, cv=10)
svm_scores.mean()
0.7329588014981274
Okay, over 73% accuracy, clearly better than random chance, but it's not a great score. Looking at the leaderboard for the Titanic
competition on Kaggle, you can see that you need to reach above 80% accuracy to be within the top 10% of Kagglers. Some reached 100%,
but since you can easily find the list of victims of the Titanic, it seems likely that there was little Machine Learning involved in their
performance! ;-) So let's try to build a model that reaches 80% accuracy.
0.8126466916354558
Instead of just looking at the mean accuracy across the 10 cross-validation folds, let's plot all 10 scores for each model, along with a box plot
highlighting the lower and upper quartiles, and "whiskers" showing the extent of the scores (thanks to Nevin Yilmaz for suggesting this
visualization). Note that the boxplot() function detects outliers (called "fliers") and does not include them within the whiskers. Specifically,
if the lower quartile is $Q_1$ and the upper quartile is $Q_3$, then the interquartile range is $IQR = Q_3 - Q_1$ (this is the box's height), and any score
lower than $Q_1 - 1.5 \times IQR$ is a flier, and so is any score greater than $Q_3 + 1.5 \times IQR$.
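The flier rule is easy to check numerically. Here is a NumPy sketch that assumes quartiles are computed with np.percentile's default linear interpolation. boxplot_fliers is a hypothetical helper, and the ten accuracies are made up.

```python
import numpy as np

def boxplot_fliers(scores, whis=1.5):
    """Return the scores outside [Q1 - whis*IQR, Q3 + whis*IQR], plus those bounds."""
    scores = np.asarray(scores, dtype=float)
    q1, q3 = np.percentile(scores, [25, 75])
    iqr = q3 - q1  # the box's height
    low, high = q1 - whis * iqr, q3 + whis * iqr
    return scores[(scores < low) | (scores > high)], (low, high)

# Ten made-up cross-validation accuracies, one of them unusually low
demo_scores = [0.80, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.60]
fliers, (low, high) = boxplot_fliers(demo_scores)
print(fliers, round(low, 4), round(high, 4))  # [0.6] 0.745 0.925
```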
plt.figure(figsize=(8, 4))
plt.plot([1]*10, svm_scores, ".")
plt.plot([2]*10, forest_scores, ".")
plt.boxplot([svm_scores, forest_scores], labels=("SVM","Random Forest"))
plt.ylabel("Accuracy", fontsize=14)
plt.show()
To improve this result further, you could:
Compare many more models and tune hyperparameters using cross-validation and grid search,
Do more feature engineering, for example:
replace SibSp and Parch with their sum,
try to identify parts of names that correlate well with the Survived attribute (e.g. if the name contains "Countess", then survival
seems more likely),
try to convert numerical attributes to categorical attributes: for example, different age groups had very different survival rates (see below),
so it may help to create an age bucket category and use it instead of the age. Similarly, it may be useful to have a special category for
people traveling alone since only 30% of them survived (see below).
train_data["AgeBucket"] = train_data["Age"] // 15 * 15
train_data[["AgeBucket", "Survived"]].groupby(['AgeBucket']).mean()
Survived
AgeBucket
0.0 0.576923
15.0 0.362745
30.0 0.423256
45.0 0.404494
60.0 0.240000
75.0 1.000000
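train_data["RelativesOnboard"] = train_data["SibSp"] + train_data["Parch"]
train_data[["RelativesOnboard", "Survived"]].groupby(['RelativesOnboard']).mean()
Survived
RelativesOnboard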
0 0.303538
1 0.552795
2 0.578431
3 0.724138
4 0.200000
5 0.136364
6 0.333333
7 0.000000
10 0.000000
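Bucketing followed by a group mean is all the AgeBucket cell does. Here is the same computation in plain Python, which makes it easy to try other bucket widths. survival_rate_by_bucket is a hypothetical helper, and the passengers are made up.

```python
from collections import defaultdict

def survival_rate_by_bucket(ages, survived, width=15):
    """Group ages into [0, 15), [15, 30), ... buckets (like Age // 15 * 15) and
    return each bucket's survival rate; missing ages (None) are skipped."""
    totals = defaultdict(int)
    survivors = defaultdict(int)
    for age, s in zip(ages, survived):
        if age is None:
            continue
        bucket = age // width * width
        totals[bucket] += 1
        survivors[bucket] += s
    return {bucket: survivors[bucket] / totals[bucket] for bucket in sorted(totals)}

# Made-up passengers: ages and survival flags
print(survival_rate_by_bucket([2, 14, 22, 35, None, 70], [1, 0, 1, 0, 1, 0]))
# {0: 0.5, 15: 1.0, 30: 0.0, 60: 0.0}
```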
4. Spam classifier
First, let's fetch the data:
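import os
import tarfile
import urllib.request
DOWNLOAD_ROOT = "http://spamassassin.apache.org/old/publiccorpus/"
HAM_URL = DOWNLOAD_ROOT + "20030228_easy_ham.tar.bz2"
SPAM_URL = DOWNLOAD_ROOT + "20030228_spam.tar.bz2"
SPAM_PATH = os.path.join("datasets", "spam")
def fetch_spam_data(ham_url=HAM_URL, spam_url=SPAM_URL, spam_path=SPAM_PATH):
    os.makedirs(spam_path, exist_ok=True)
    for filename, url in (("ham.tar.bz2", ham_url), ("spam.tar.bz2", spam_url)):
        path = os.path.join(spam_path, filename)
        if not os.path.isfile(path):  # download each archive only once
            urllib.request.urlretrieve(url, path)
        with tarfile.open(path) as tar_bz2_file:
            tar_bz2_file.extractall(path=spam_path)
fetch_spam_data()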
len(ham_filenames)
2500
len(spam_filenames)
500
We can use Python's email module to parse these emails (this handles headers, encoding, and so on):
import email
import email.policy
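With policy=email.policy.default, Python's email parser returns email.message.EmailMessage objects. Those provide get_content() and get_content_type(), the methods used in the following cells. Here is a sketch of parsing one message. load_email_from_path is a hypothetical helper for reading a file from disk, and the in-memory example message is made up.

```python
import email
import email.parser
import email.policy

def load_email_from_path(path):
    """Hypothetical helper: parse one raw email file into an EmailMessage."""
    with open(path, "rb") as f:
        return email.parser.BytesParser(policy=email.policy.default).parse(f)

# Same parsing, applied to an in-memory, made-up message
raw = (b"From: [email protected]\r\n"
       b"Subject: Quick question\r\n"
       b"Content-Type: text/plain; charset=utf-8\r\n"
       b"\r\n"
       b"Are we still meeting tomorrow?\r\n")
msg = email.message_from_bytes(raw, policy=email.policy.default)
print(msg["Subject"])             # Quick question
print(msg.get_content_type())     # text/plain
print(msg.get_content().strip())  # Are we still meeting tomorrow?
```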
Let's look at one example of ham and one example of spam, to get a feel for what the data looks like:
print(ham_emails[1].get_content().strip())
Martin A posted:
Tassos Papadopoulos, the Greek sculptor behind the plan, judged that the
limestone of Mount Kerdylio, 70 miles east of Salonika and not far from the
Mount Athos monastic community, was ideal for the patriotic sculpture.
As well as Alexander's granite features, 240 ft high and 170 ft wide, a
museum, a restored amphitheatre and car park for admiring crowds are
planned
---------------------
So is this mountain limestone or granite?
If it's limestone, it'll weather pretty fast.
print(spam_emails[6].get_content().strip())
So if you are looking to be employed from home with a career that has
vast opportunities, then go:
http://www.basetel.com/wealthnow
We are looking for energetic and self motivated people. If that is you
than click on the link and fill out the form, and one of our
employement specialist will contact you.
http://www.basetel.com/remove.html
4139vOLW7-758DoDY1425FRhM1-764SMFc8513fCsLl40
Some emails are actually multipart, with images and attachments (which can have their own attachments). Let's look at the various types of
structures we have:
def get_email_structure(email):
    if isinstance(email, str):
        return email
    payload = email.get_payload()
    if isinstance(payload, list):
        return "multipart({})".format(", ".join([
            get_email_structure(sub_email)
            for sub_email in payload
        ]))
    else:
        return email.get_content_type()
from collections import Counter

def structures_counter(emails):
    structures = Counter()
    for email in emails:
        structure = get_email_structure(email)
        structures[structure] += 1
    return structures
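The structure logic can be tried without the dataset by building a multipart message in memory. In the sketch below, the message is made up and get_structure mirrors get_email_structure. It also shows the depth-first order of walk(), which email_to_text relies on later.

```python
from email.message import EmailMessage

def get_structure(part):
    """Same recursion as get_email_structure: nest multipart payloads, else the content type."""
    payload = part.get_payload()
    if isinstance(payload, list):
        return "multipart({})".format(", ".join(get_structure(p) for p in payload))
    return part.get_content_type()

# Made-up multipart/alternative email: a plain-text body plus an HTML version
msg = EmailMessage()
msg["Subject"] = "Demo"
msg.set_content("Plain-text body\n")
msg.add_alternative("<p>HTML body</p>\n", subtype="html")

print(get_structure(msg))                          # multipart(text/plain, text/html)
print([p.get_content_type() for p in msg.walk()])  # ['multipart/alternative', 'text/plain', 'text/html']
```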
structures_counter(ham_emails).most_common()
[('text/plain', 2408),
('multipart(text/plain, application/pgp-signature)', 66),
('multipart(text/plain, text/html)', 8),
('multipart(text/plain, text/plain)', 4),
('multipart(text/plain)', 3),
('multipart(text/plain, application/octet-stream)', 2),
('multipart(text/plain, text/enriched)', 1),
('multipart(text/plain, application/ms-tnef, text/plain)', 1),
('multipart(multipart(text/plain, text/plain, text/plain), application/pgp-signature)',
1),
('multipart(text/plain, video/mng)', 1),
('multipart(text/plain, multipart(text/plain))', 1),
('multipart(text/plain, application/x-pkcs7-signature)', 1),
('multipart(text/plain, multipart(text/plain, text/plain), text/rfc822-headers)',
1),
('multipart(text/plain, multipart(text/plain, text/plain), multipart(multipart(text/plain, application/x-pkcs7-signature)))',
1),
('multipart(text/plain, application/x-java-applet)', 1)]
structures_counter(spam_emails).most_common()
[('text/plain', 218),
('text/html', 183),
('multipart(text/plain, text/html)', 45),
('multipart(text/html)', 20),
('multipart(text/plain)', 19),
('multipart(multipart(text/html))', 5),
('multipart(text/plain, image/jpeg)', 3),
('multipart(text/html, application/octet-stream)', 2),
('multipart(text/plain, application/octet-stream)', 1),
('multipart(text/html, text/plain)', 1),
('multipart(multipart(text/html), application/octet-stream, image/jpeg)', 1),
('multipart(multipart(text/plain, text/html), image/gif)', 1),
('multipart/alternative', 1)]
It seems that the ham emails are more often plain text, while spam has quite a lot of HTML. Moreover, quite a few ham emails are signed
using PGP, while no spam is. In short, it seems that the email structure is useful information to have.
for header, value in spam_emails[0].items():
    print(header, ":", value)
Return-Path : <[email protected]>
Delivered-To : [email protected]
Received : from localhost (localhost [127.0.0.1]) by phobos.labs.spamassassin.taint.org (Postfix) with ESMTP id 136B943C32 for <zzzz@localhost>; Thu, 22 Aug 2002 08:17:21 -0400 (EDT)
Received : from mail.webnote.net [193.120.211.219] by localhost with POP3 (fetchmail-5.9.0) for zzzz@localhost (single-drop); Thu, 22 Aug 2002 13:17:21 +0100 (IST)
Received : from dd_it7 ([210.97.77.167]) by webnote.net (8.9.3/8.9.3) with ESMTP id NAA04623 for <[email protected]>; Thu, 22 Aug 2002 13:09:41 +0100
From : [email protected]
Received : from r-smtp.korea.com - 203.122.2.197 by dd_it7 with Microsoft SMTPSVC(5.5.1775.675.6); Sat, 24 Aug 2002 09:42:10 +0900
To : [email protected]
Subject : Life Insurance - Why Pay More?
Date : Wed, 21 Aug 2002 20:31:57 -1600
MIME-Version : 1.0
Message-ID : <0103c1042001882DD_IT7@dd_it7>
Content-Type : text/html; charset="iso-8859-1"
Content-Transfer-Encoding : quoted-printable
There's probably a lot of useful information in there, such as the sender's email address ([email protected] looks fishy), but we will just
focus on the Subject header:
spam_emails[0]["Subject"]
'Life Insurance - Why Pay More?'
Okay, before we learn too much about the data, let's not forget to split it into a training set and a test set:
import numpy as np
from sklearn.model_selection import train_test_split
Okay, let's start writing the preprocessing functions. First, we will need a function to convert HTML to plain text. Arguably the best way to do
this would be to use the great BeautifulSoup library, but I would like to avoid adding another dependency to this project, so let's hack a quick
& dirty solution using regular expressions (at the risk of un̨ho͞ly radiańcé destro҉ying all enli̍̈́̂ghtenment). The following function first
drops the <head> section, then converts all <a> tags to the word HYPERLINK, then it gets rid of all HTML tags, leaving only the plain text.
For readability, it also replaces multiple newlines with single newlines, and finally it unescapes HTML entities (such as &gt; or &nbsp;):
import re
from html import unescape
def html_to_plain_text(html):
    text = re.sub(r'<head.*?>.*?</head>', '', html, flags=re.M | re.S | re.I)  # drop the <head> section
    text = re.sub(r'<a\s.*?>', ' HYPERLINK ', text, flags=re.M | re.S | re.I)   # links become HYPERLINK
    text = re.sub(r'<.*?>', '', text, flags=re.M | re.S)                         # strip all remaining tags
    text = re.sub(r'(\s*\n)+', '\n', text, flags=re.M | re.S)                    # collapse blank lines
    return unescape(text)                                                        # decode HTML entities
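If regular expressions feel too fragile, Python's standard library also ships an event-driven HTML parser, html.parser.HTMLParser, which adds no dependency. Here is a sketch of the same conversion built on it. PlainTextExtractor and html_to_plain_text_parser are hypothetical names. Like the regex version, it keeps the text inside <script> or <style> elements.

```python
import re
from html.parser import HTMLParser

class PlainTextExtractor(HTMLParser):
    """Collect text outside <head>; emit HYPERLINK for each <a> tag."""
    def __init__(self):
        super().__init__(convert_charrefs=True)  # entities such as &gt; arrive already decoded
        self.parts = []
        self.in_head = False

    def handle_starttag(self, tag, attrs):
        if tag == "head":
            self.in_head = True
        elif tag == "a" and not self.in_head:
            self.parts.append(" HYPERLINK ")

    def handle_endtag(self, tag):
        if tag == "head":
            self.in_head = False

    def handle_data(self, data):
        if not self.in_head:
            self.parts.append(data)

def html_to_plain_text_parser(html):
    parser = PlainTextExtractor()
    parser.feed(html)
    parser.close()
    # Collapse runs of blank lines, like the regex version
    return re.sub(r"(\s*\n)+", "\n", "".join(parser.parts))

sample = ("<html><head><title>Ignored</title></head><body>"
          "<p>Stocks &gt; bonds</p>\n\n<a href='http://example.com'>click</a></body></html>")
print(html_to_plain_text_parser(sample))
```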
print(html_to_plain_text(sample_html_spam.get_content())[:1000], "...")
OTC
Newsletter
Discover Tomorrow's Winners
For Immediate Release
Cal-Bay (Stock Symbol: CBYI)
Watch for analyst "Strong Buy Recommendations" and several advisory newsletters picking CBYI. CBYI has filed to be traded on the OTCBB, share prices historically INCREASE when companies get listed on this larger trading exchange. CBYI is trading around 25 cents and should skyrocket to $2.66 - $3.25 a share in the near future.
Put CBYI on your watch list, acquire a position TODAY.
REASONS TO INVEST IN CBYI
A profitable company and is on track to beat ALL earnings estimates!
One of the FASTEST growing distributors in environmental & safety equipment instruments.
Excellent management team, several EXCLUSIVE contracts. IMPRESSIVE client list including the U.S. Air Force, Anheuser-Busch, Chevron Refining and Mitsubishi Heavy Industries, GE-Energy & Environmental Research.
RAPIDLY GROWING INDUSTRY
Industry revenues exceed $900 million, estimates indicate that there could be as much as $25 billi ...
Great! Now let's write a function that takes an email as input and returns its content as plain text, whatever its format is:
def email_to_text(email):
    html = None
    for part in email.walk():
        ctype = part.get_content_type()
        if ctype not in ("text/plain", "text/html"):
            continue
        try:
            content = part.get_content()
        except Exception:  # in case of encoding issues
            content = str(part.get_payload())
        if ctype == "text/plain":
            return content
        else:
            html = content
    if html:
        return html_to_plain_text(html)
print(email_to_text(sample_html_spam)[:100], "...")
OTC
Newsletter
Discover Tomorrow's Winners
For Immediate Release
Cal-Bay (Stock Symbol: CBYI)
Wat ...
Let's throw in some stemming! For this to work, you need to install the Natural Language Toolkit (NLTK). It's as simple as running the
following command (don't forget to activate your virtualenv first; if you don't have one, you will likely need administrator rights, or use the
--user option):
$ pip3 install nltk
try:
    import nltk
    stemmer = nltk.PorterStemmer()
    for word in ("Computations", "Computation", "Computing", "Computed", "Compute", "Compulsive"):
        print(word, "=>", stemmer.stem(word))
except ImportError:
    print("Error: stemming requires the NLTK module.")
    stemmer = None
We will also need a way to replace URLs with the word "URL". For this, we could use hardcore regular expressions but we will just use the
urlextract library. You can install it with the following command (don't forget to activate your virtualenv first; if you don't have one, you will
likely need administrator rights, or use the --user option):
$ pip3 install urlextract
try:
    import urlextract # may require an Internet connection to download root domain names
    url_extractor = urlextract.URLExtract()
    print(url_extractor.find_urls("Will it detect github.com and https://youtu.be/7Pq-S557XQU?t=3m32s"))
except ImportError:
    print("Error: replacing URLs requires the urlextract module.")
    url_extractor = None
['github.com', 'https://youtu.be/7Pq-S557XQU?t=3m32s']
We are ready to put all this together into a transformer that we will use to convert emails to word counters. Note that we split sentences into
words using Python's split() method, which uses whitespace for word boundaries. This works for many written languages, but not all.
For example, Chinese and Japanese scripts generally don't use spaces between words, and Vietnamese often uses spaces even between
syllables. It's okay in this exercise, because the dataset is (mostly) in English.
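EmailToWordCounterTransformer is used below without a definition. Its core idea, turning an email's plain text into a Counter of words, can be sketched with the standard library alone. The preprocessing steps here are assumptions: lowercasing, replacing URLs and numbers, and stripping punctuation. The sketch uses a crude regular expression instead of urlextract and skips stemming. email_text_to_word_counts is a hypothetical name.

```python
import re
from collections import Counter

# Crude stand-in for urlextract (assumption): only catches http(s):// and www. URLs
URL_PATTERN = re.compile(r"(?:https?://|www\.)\S+")
NUMBER_PATTERN = re.compile(r"\d+(?:\.\d*)?(?:[eE][+-]?\d+)?")

def email_text_to_word_counts(text, lower_case=True, replace_urls=True,
                              replace_numbers=True, remove_punctuation=True):
    """Turn an email's plain text into a Counter of words."""
    if lower_case:
        text = text.lower()
    if replace_urls:
        text = URL_PATTERN.sub(" URL ", text)  # before numbers, since URLs often contain digits
    if replace_numbers:
        text = NUMBER_PATTERN.sub("NUMBER", text)
    if remove_punctuation:
        text = re.sub(r"\W+", " ", text)
    return Counter(text.split())  # split() uses whitespace as word boundaries

print(email_text_to_word_counts("Visit http://example.com now! Only 3 days, only."))
```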
X_few = X_train[:3]
X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)
X_few_wordcounts
Now we have the word counts, and we need to convert them to vectors. For this, we will build another transformer whose fit() method will
build the vocabulary (an ordered list of the most common words) and whose transform() method will use the vocabulary to convert word
counts to vectors. The output is a sparse matrix.
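WordCounterToVectorTransformer is used below without a definition. Here is a sketch of its two steps, assuming the layout explained after the output: column 0 counts out-of-vocabulary words, and column i counts the i-th most common word. For simplicity, it returns a dense NumPy array rather than a csr_matrix. build_vocabulary and word_counters_to_vectors are hypothetical names.

```python
from collections import Counter

import numpy as np

def build_vocabulary(word_counters, vocabulary_size=1000):
    """fit step: map the most common words to column indices 1..vocabulary_size."""
    total = Counter()
    for word_count in word_counters:
        total.update(word_count)  # adds the counts
    most_common = total.most_common(vocabulary_size)
    return {word: index + 1 for index, (word, _) in enumerate(most_common)}

def word_counters_to_vectors(word_counters, vocabulary):
    """transform step: column 0 collects every word missing from the vocabulary."""
    X = np.zeros((len(word_counters), len(vocabulary) + 1), dtype=np.int64)
    for row, word_count in enumerate(word_counters):
        for word, count in word_count.items():
            X[row, vocabulary.get(word, 0)] += count
    return X

emails = [Counter({"the": 3, "spam": 1}), Counter({"the": 1, "of": 2, "buy": 5})]
vocab = build_vocabulary(emails, vocabulary_size=2)
print(vocab)                                             # {'buy': 1, 'the': 2}
print(word_counters_to_vectors(emails, vocab).tolist())  # [[1, 0, 3], [2, 5, 1]]
```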
from scipy.sparse import csr_matrix
vocab_transformer = WordCounterToVectorTransformer(vocabulary_size=10)
X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)
X_few_vectors
X_few_vectors.toarray()
array([[ 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[99, 11, 9, 8, 3, 1, 3, 1, 3, 2, 3],
[67, 0, 1, 2, 3, 4, 1, 2, 0, 1, 0]], dtype=int64)
What does this matrix mean? Well, the 99 in the second row, first column, means that the second email contains 99 words that are not part of
the vocabulary. The 11 next to it means that the first word in the vocabulary is present 11 times in this email. The 9 next to it means that the
second word is present 9 times, and so on. You can look at the vocabulary to know which words we are talking about. The first word is "the",
the second word is "of", etc.
vocab_transformer.vocabulary_
{'the': 1,
'of': 2,
'and': 3,
'to': 4,
'url': 5,
'all': 6,
'in': 7,
'christian': 8,
'on': 9,
'by': 10}
We are now ready to train our first spam classifier! Let's transform the whole dataset:
preprocess_pipeline = Pipeline([
("email_to_wordcount", EmailToWordCounterTransformer()),
("wordcount_to_vector", WordCounterToVectorTransformer()),
])
X_train_transformed = preprocess_pipeline.fit_transform(X_train)
Note: to be future-proof, we set solver="lbfgs" since this will be the default value in Scikit-Learn 0.22.
Over 98.5%, not bad for a first try! :) However, remember that we are using the "easy" dataset. You can try with the harder datasets, the
results won't be so amazing. You would have to try multiple models, select the best ones and fine-tune them using cross-validation, and so
on.
But you get the picture, so let's stop now, and just print out the precision/recall we get on the test set:
X_test_transformed = preprocess_pipeline.transform(X_test)
log_clf.fit(X_train_transformed, y_train)
y_pred = log_clf.predict(X_test_transformed)
Precision: 95.88%
Recall: 97.89%