Text Classification with Scikit-Learn

In a previous article I wrote about a recent request from a client to classify short pieces of text. We started out with the simplest thing possible, which in that case was to use a 3rd party API.

In this article we look at the next simplest approach: TF-IDF features combined with basic classifiers from Scikit-Learn (sklearn). We show that with minimal processing and no parameter tuning at all we get the following accuracies:

  • 68% with Naive Bayes
  • 78% with Support Vector Machine (w/ SGD)
  • 49% with a Random Forest
  • 75% with Logistic Regression
  • 53% with K-Nearest Neighbors

The main takeaway here is that, out of the box, the 3rd party API (80% on this particular task) did better than our best simple classifiers (78% for the SVM and 75% for Logistic Regression). However, each of these classifiers can be improved significantly with additional parameter tuning.

All of these algorithms will perform differently on your data, and whether tuning and hosting your own models is worth the improvement depends on your specific needs. Tuning and hosting will be the subject of future articles.

The code

Let's take a quick look at how we can use the various classifiers from sklearn. For background on the data set see this article.

We need to load the data without the headers, footers and quotes. We'll do a basic cleanup and remove posts of 50 characters or fewer, as those are likely too short to be useful. We don't truncate long texts, since these algorithms do not require a fixed input length.

import numpy as np
import pandas as pd

from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import SGDClassifier, LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier

def clean_line(t):
    # Turn newlines, carriage returns and tabs into spaces, then shorten doubled spaces
    return (t.replace('\n', ' ')
            .replace('\r', ' ')
            .replace('\t', ' ')
            .replace('  ', ' '))

def load_and_process_data():
    min_len = 50
    # Fetch all 20 Newsgroups posts, stripping headers, signature footers and quoted replies
    all_data = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))
    all_text = [clean_line(t) for t in all_data.data]
    all_data_df = pd.DataFrame({'text': all_text, 'topics': all_data.target})
    # Keep only posts longer than min_len characters
    cleaned_df = all_data_df[all_data_df.text.str.len() > min_len]

    X_raw = cleaned_df['text'].values
    y_raw = cleaned_df['topics'].values
    # Hold out 20% of the posts as a test set (fixed seed for reproducibility)
    X_train_raw, X_test_raw, y_train, y_test = train_test_split(
        X_raw, y_raw, random_state=42, test_size=0.20)

    # Learn the vocabulary and IDF weights on the training set only, then apply them to the test set
    tfidf = TfidfVectorizer()
    X_train_tfidf = tfidf.fit_transform(X_train_raw)
    X_test_tfidf = tfidf.transform(X_test_raw)

    return X_train_tfidf, X_test_tfidf, y_train, y_test


X_train, X_test, y_train, y_test = load_and_process_data()
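The clean_line helper makes a single pass of .replace('  ', ' '), which shortens runs of spaces rather than collapsing them (four spaces become two). If you want every whitespace run reduced to exactly one space, a regular expression is a common alternative. The sketch below is a hypothetical drop-in called clean_line_re; note that it also trims leading and trailing whitespace, which the original helper does not.

```python
import re


def clean_line_re(t: str) -> str:
    """Collapse every run of whitespace (spaces, tabs, carriage returns, newlines) into one space."""
    # \s+ matches any run of whitespace characters, including Unicode whitespace
    return re.sub(r"\s+", " ", t).strip()


sample = "Line one\r\n\tLine    two  "
print(repr(clean_line_re(sample)))  # 'Line one Line two'
```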

Now let's try a Naive Bayes classifier, which gets an accuracy of 0.68232662192393734.

nb_clf = MultinomialNB()
nb_clf.fit(X_train, y_train)
predicted = nb_clf.predict(X_test)
np.mean(predicted == y_test)
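The np.mean(predicted == y_test) line computes accuracy by hand: the fraction of test posts whose predicted topic matches the true one. sklearn.metrics.accuracy_score returns the same number, and classification_report adds per-class precision, recall and F1, which can reveal which topics a model confuses. The sketch below uses made-up label arrays in place of y_test and predicted.

```python
import numpy as np
from sklearn.metrics import accuracy_score, classification_report

# Made-up labels standing in for y_test and a classifier's predictions
y_true = np.array([0, 1, 2, 2, 1, 0])
y_pred = np.array([0, 2, 2, 2, 1, 1])

manual = np.mean(y_pred == y_true)       # the calculation used in this article
library = accuracy_score(y_true, y_pred)
print(manual, library)                   # 4 of 6 correct in both cases

# Per-class precision/recall/F1; more informative when there are 20 topics
print(classification_report(y_true, y_pred))
```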

SGDClassifier, which with its default hinge loss is a linear SVM trained with stochastic gradient descent (SGD), gets 0.78607382550335569.

sgd_clf = SGDClassifier()
sgd_clf.fit(X_train, y_train)
predicted = sgd_clf.predict(X_test)
np.mean(predicted == y_test)
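Because SGD shuffles the training data on every epoch, the SVM's accuracy can shift slightly from run to run unless the random seed is fixed. The sketch below, on made-up toy texts, spells out loss="hinge" (already the default) to document that this is a linear SVM, and sets random_state so that two fits produce identical weights. The value 42 is arbitrary.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier

# Made-up toy data: 0 = sports, 1 = computing
texts = [
    "the team won the match", "a late goal won the game",
    "the striker scored twice", "fans cheered the winning team",
    "the compiler reported an error", "install the new graphics driver",
    "the laptop needs more memory", "rebooting fixed the kernel bug",
]
labels = [0, 0, 0, 0, 1, 1, 1, 1]

X = TfidfVectorizer().fit_transform(texts)

# loss="hinge" is the default; random_state fixes the shuffling between epochs
clf_a = SGDClassifier(loss="hinge", random_state=42).fit(X, labels)
clf_b = SGDClassifier(loss="hinge", random_state=42).fit(X, labels)

# Same seed and same data give the same learned weights
print((clf_a.coef_ == clf_b.coef_).all())
```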

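The Random Forest classifier with only 10 trees (the default when this was written; since scikit-learn 0.22 the default is 100) gets 0.49189038031319909.

rf_clf = RandomForestClassifier(n_estimators=10)  # match the 10-tree setup behind the result above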
rf_clf.fit(X_train, y_train)
predicted = rf_clf.predict(X_test)
np.mean(predicted == y_test)
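The number of trees has a large effect on a Random Forest, so it is worth setting n_estimators explicitly instead of relying on a version-dependent default. The sketch below fits forests of 10 and 100 trees on made-up toy texts; with the real newsgroup features you would compare test accuracy for each setting. n_jobs=-1 (build trees on all CPU cores) and random_state=42 are illustrative choices, not settings from the original experiment.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier

# Made-up toy data: 0 = sports, 1 = computing
texts = [
    "the team won the match", "a late goal won the game",
    "the striker scored twice", "fans cheered the winning team",
    "the compiler reported an error", "install the new graphics driver",
    "the laptop needs more memory", "rebooting fixed the kernel bug",
]
labels = [0, 0, 0, 0, 1, 1, 1, 1]

X = TfidfVectorizer().fit_transform(texts)

models = {}
for n_trees in (10, 100):
    # n_jobs=-1 trains trees in parallel; random_state makes the runs repeatable
    rf = RandomForestClassifier(n_estimators=n_trees, n_jobs=-1, random_state=42)
    models[n_trees] = rf.fit(X, labels)
    # estimators_ holds one fitted decision tree per requested tree
    print(n_trees, len(models[n_trees].estimators_))
```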

The crowd favorite, Logistic Regression, gets 0.75531319910514538.

lr_clf = LogisticRegression()
lr_clf.fit(X_train, y_train)
predicted = lr_clf.predict(X_test)
np.mean(predicted == y_test)

And the simplest of all, the K-Nearest Neighbors classifier with the default of 5 neighbors, gets 0.53691275167785235.

kn_clf = KNeighborsClassifier()
kn_clf.fit(X_train, y_train)
predicted = kn_clf.predict(X_test)
np.mean(predicted == y_test)
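A natural idea for text is to give KNeighborsClassifier a cosine metric. However, TfidfVectorizer L2-normalizes each row by default, and for unit vectors the squared Euclidean distance equals twice the cosine distance. The two metrics therefore rank neighbors the same way (up to ties), so switching metrics alone is unlikely to change the result; parameters such as n_neighbors or weights="distance" would have to be explored instead. The sketch below checks the identity on made-up documents.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances

# Made-up documents, for illustration only
docs = [
    "the team won the match",
    "the striker scored a late goal",
    "the compiler reported an error",
    "install the new graphics driver",
]

# norm="l2" is the TfidfVectorizer default, so every row has unit length
X = TfidfVectorizer().fit_transform(docs)

euc = euclidean_distances(X)   # KNeighborsClassifier's default metric (Minkowski with p=2)
cos = cosine_distances(X)      # 1 - cosine similarity

# For unit vectors a, b: ||a - b||^2 = 2 - 2*cos(a, b) = 2 * cosine_distance(a, b)
print(np.allclose(euc ** 2, 2 * cos))
```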


We looked at the performance of five common classifiers from sklearn using the least amount of programming and tuning possible. Two of them (the SVM and Logistic Regression) come close to the 3rd party API, and all of them can be improved with further tuning.
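The five snippets above repeat the same fit/predict/score pattern, so they can be collapsed into one loop. The sketch below runs that loop on made-up toy texts so it executes without downloading the dataset; the scores it prints say nothing about the newsgroup results. The seeds and n_estimators=10 are illustrative choices.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import SGDClassifier, LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier

# Made-up toy data standing in for the newsgroup posts: 0 = sports, 1 = computing
train_texts = [
    "the team won the match", "a late goal won the game",
    "the striker scored twice", "fans cheered the winning team",
    "the compiler reported an error", "install the new graphics driver",
    "the laptop needs more memory", "rebooting fixed the kernel bug",
]
train_labels = [0, 0, 0, 0, 1, 1, 1, 1]
test_texts = ["the goal won the match", "the driver crashed the laptop",
              "fans cheered the striker", "memory error in the compiler"]
test_labels = [0, 1, 0, 1]

vec = TfidfVectorizer()
X_train = vec.fit_transform(train_texts)  # vocabulary and IDF learned on training data only
X_test = vec.transform(test_texts)

classifiers = {
    "Naive Bayes": MultinomialNB(),
    "SVM (SGD)": SGDClassifier(random_state=42),
    "Random Forest": RandomForestClassifier(n_estimators=10, random_state=42),
    "Logistic Regression": LogisticRegression(),
    "K-Nearest Neighbors": KNeighborsClassifier(),  # 5 neighbors by default
}

scores = {}
for name, clf in classifiers.items():
    clf.fit(X_train, train_labels)
    scores[name] = np.mean(clf.predict(X_test) == test_labels)
    print(f"{name:<20} {scores[name]:.2f}")
```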

Each classifier will work differently on your particular data and with different hyperparameters, so testing with your own use case is critical. In a future article we'll look at how to go about tuning these classifiers to get even better results.
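When testing on your own data, cross-validation gives a steadier estimate than a single 80/20 split. Wrapping the vectorizer and classifier in a Pipeline means TF-IDF is refit on the training folds only, so no vocabulary from a held-out fold leaks into training, and the same pipeline object can later be handed to GridSearchCV for tuning. The sketch below uses made-up toy texts and 3 folds; both are assumptions for illustration.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline

# Made-up toy data: 0 = sports, 1 = computing (6 examples per class)
texts = [
    "the team won the match", "a late goal won the game",
    "the striker scored twice", "fans cheered the winning team",
    "the referee stopped the match", "the coach praised the defence",
    "the compiler reported an error", "install the new graphics driver",
    "the laptop needs more memory", "rebooting fixed the kernel bug",
    "the script crashed with an exception", "upgrade the operating system today",
]
labels = [0] * 6 + [1] * 6

# The pipeline fits TF-IDF inside each training fold, then trains the SVM on it
pipe = make_pipeline(TfidfVectorizer(), SGDClassifier(random_state=42))

# cv=3 on a classifier uses stratified folds, so each fold keeps the class balance
scores = cross_val_score(pipe, texts, labels, cv=3, scoring="accuracy")
print(scores, scores.mean())
```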

Let me know if you have any questions or comments.

