Text Classification with IBM Watson

Recently I had a request from a client for help classifying short pieces of text. The exact nature of the text is confidential but the passages were similar to paragraphs from reviews and comments users write about products and services.

We wanted a quick and easy way to establish a baseline we could use to compare various approaches. This would help us decide, based on performance and expected necessary investment, if further efforts in research, development and operations were necessary. We looked at a couple of options and decided to investigate the IBM Watson NLP Text Classification API. In the end, though the API has a couple of limitations, we were very happy with the results.

What is text classification?

First, a little background. Text classification is the process of assigning classes to documents based on their content or meaning. You get to decide on the classes, also called categories or labels, based on your needs. A document can be any piece of text, such as a book, play, white paper, article, email, comment, review, support request or product description.

The Watson API lets you create a classifier service using labeled training data. The easiest way to do that is to POST a CSV file with one sample per line, where the first column is the text and the second is the label. One thing to note is that the Watson API seems to be targeted at short texts, perhaps for chatbots and similar applications, so each sample can be at most 1024 characters long. You also need to remove unquoted new lines, carriage returns and tabs.
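To make those constraints concrete, here is a minimal sketch of preparing samples with Python's standard csv module. The helper names (prepare_sample, write_training_csv) and the choice to truncate long samples rather than drop them are my own assumptions for illustration, not something the Watson API prescribes.

```python
import csv
import io

MAX_CHARS = 1024  # length limit for a single sample, as described above


def prepare_sample(text, max_chars=MAX_CHARS):
    """Collapse newlines, carriage returns and tabs into single spaces, then truncate."""
    flattened = " ".join(text.split())  # split() treats \n, \r and \t as whitespace
    return flattened[:max_chars]


def write_training_csv(samples, out):
    """Write (text, label) pairs as CSV rows: text first, label second, no header."""
    writer = csv.writer(out, quoting=csv.QUOTE_MINIMAL)
    for text, label in samples:
        writer.writerow([prepare_sample(text), label])


buffer = io.StringIO()
write_training_csv([("Great product,\n\twould buy again", "positive")], buffer)
print(buffer.getvalue())
```

The csv module quotes any field that contains a comma or a quote character, so you don't have to escape the text yourself.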

When you create the classifier service, it can take a little while to train, so you need to be patient and check its status periodically. Once it is ready, you can make requests to classify new texts one at a time or in batches of up to 30. The API returns a rich JSON object with the text, the most likely class and a confidence score for each of the top 10 classes. In this example we only examine the most likely class, but the response would let us take the top N and show them to our user.
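Checking the status periodically could look something like the sketch below. It takes any function that returns the status JSON as a dict, so it can be tried without touching the network. The 'status' key and the 'Training' and 'Available' values are my assumptions about the response, so check them against what your service actually returns.

```python
import time


def wait_until_available(fetch_status, poll_seconds=30, max_polls=60, sleep=time.sleep):
    """Poll fetch_status() until the classifier reports that it is ready.

    fetch_status is any callable returning the status JSON as a dict.
    The 'status' key and its 'Training'/'Available' values are assumptions.
    """
    for _ in range(max_polls):
        info = fetch_status()
        status = info.get("status")
        if status == "Available":
            return info
        if status != "Training":
            # Anything other than training or available is treated as a failure
            raise RuntimeError(f"Unexpected classifier status: {status!r}")
        sleep(poll_seconds)
    raise TimeoutError("Classifier was still training after the last poll")
```

With the get_status function defined later in this post, the call would simply be wait_until_available(get_status).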

Our example dataset

Since I can't show you the client data, let's do a walk-through with another common, easily accessible dataset you might be familiar with: the 20 Newsgroups dataset from Ken Lang. It is a collection of approximately 20,000 newsgroup posts from 20 different Usenet newsgroups.

The usual task is to label each post with the group it was posted to, based on the content of the post. There are a couple of things to keep in mind. First, the posts can be wildly off topic relative to the newsgroup title and/or can have vague and neutral text. For example, a post might just be "I agree" or something similar. So it is important to think about what our error/accuracy metric means and what performance level we expect or need.

Second, meta information from the post (headers, signatures) can give clues that we don't want to use in our example. This brings up the idea of data leakage, which means using information that you don't really want to use in the task at hand, usually because the information is in the training data but not in the test data. It is something to keep a close eye on.

Luckily, the loader we used for the example, sklearn.datasets.fetch_20newsgroups, gives you the option to omit headers, footers and quotes so we can focus more on the content of the text.
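To make the leakage risk concrete, here is a toy illustration with a made-up post. The header block literally names the newsgroup, so a classifier that sees it barely needs to read the body. Dropping everything before the first blank line is only a simplified sketch of header removal, not the library's exact implementation.

```python
def strip_headers(post):
    """Drop everything before the first blank line, i.e. the header block.

    Returns an empty string if the post contains no blank line at all.
    """
    _headers, _blank, body = post.partition("\n\n")
    return body


# Invented example post; the Newsgroups header gives the label away
raw_post = (
    "From: someone@example.com\n"
    "Newsgroups: rec.autos\n"
    "Subject: Re: brakes\n"
    "\n"
    "I agree, the pads were worn."
)
print(strip_headers(raw_post))
```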

The code

Let's take a quick look at the result we get back from the API and the code we used to get that result. Here we asked Watson to classify a piece of text, and it believes the text is in class 16 with 67% confidence. The second most likely class, 18, has only 14% confidence. Depending on your use case, this can be a pretty handy piece of information to have available.

{
    "text": "Your text here.",
    "top_class": "16",
    "classes": [
        {
            "class_name": "16",
            "confidence": 0.6727412592135023
        },
        {
            "class_name": "18",
            "confidence": 0.14018533056228114
        },
        {
            "class_name": "8",
            "confidence": 0.0352981572387507
        }
        // 7 other classes deleted to save space
    ]
}
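If you wanted to show the user more than the single best guess, pulling the top N classes out of a result like this one is straightforward. The sketch below is my own helper (top_n_classes and its min_confidence cutoff are hypothetical, not part of the API) and works on the fields shown in the response above.

```python
def top_n_classes(result, n=3, min_confidence=0.0):
    """Return up to n (class_name, confidence) pairs, highest confidence first."""
    ranked = sorted(result["classes"], key=lambda c: c["confidence"], reverse=True)
    return [
        (c["class_name"], c["confidence"])
        for c in ranked[:n]
        if c["confidence"] >= min_confidence
    ]


# The response shown above, with the elided classes left out
sample = {
    "text": "Your text here.",
    "top_class": "16",
    "classes": [
        {"class_name": "16", "confidence": 0.6727412592135023},
        {"class_name": "18", "confidence": 0.14018533056228114},
        {"class_name": "8", "confidence": 0.0352981572387507},
    ],
}
print(top_n_classes(sample, n=2))
```

Setting min_confidence lets you hide low-confidence suggestions instead of always showing a fixed number of them.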


This sample code uses Python along with some common data science libraries.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import json
import requests
import pickle

import sklearn
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split

We need to load the data without the headers, footers and quotes. We'll also remove posts of 50 characters or fewer, as those are likely to be too short for us to use. We then truncate all long texts to 1024 characters to meet the API requirements and split the data into train and test sets.

def clean_line(t):
    # Replace newlines, carriage returns and tabs with spaces, as the API requires
    return (t.replace('\n',' ')
            .replace('\r',' ')
            .replace('\t',' ')
            .replace('  ',' '))

def load_and_clean_data():
    min_len = 50  # posts of min_len characters or fewer are dropped
    all_data = fetch_20newsgroups(subset='all',remove=('headers', 'footers', 'quotes'))
    all_text = [clean_line(t) for t in all_data.data]
    all_data_df = pd.DataFrame({'text' : all_text, 'topics' : all_data.target})
    cleaned_df = all_data_df[all_data_df.text.str.len() > min_len]

    return cleaned_df

newsgroup_df = load_and_clean_data()

X = np.array([t[:1024] for t in newsgroup_df['text'].values])
y = newsgroup_df['topics'].values

X_train, X_test, y_train, y_test = train_test_split(X,y,random_state = 42, test_size=0.20)

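The simplest way to write a CSV is probably with Pandas, so we'll do that.

train_df = pd.DataFrame({'text' : X_train, 'topics' : y_train})
test_df = pd.DataFrame({'text' : X_test, 'topics' : y_test})

# Text first, label second; no header row or index column (assumed to be what the API expects)
train_df.to_csv('train.csv', columns=['text', 'topics'], header=False, index=False)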


You could use Python to POST the CSV, which would be consistent :) but I ended up using curl for other reasons.

curl -i --user your_user_id:your_password \
     -F training_data=@./train.csv \
     -F training_metadata="{\"language\":\"en\",\"name\":\"TextClassificationTest\"}" \
     "https://gateway.watsonplatform.net/natural-language-classifier/api/v1/classifiers"
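For the curious, the same upload from Python might look roughly like this with requests. Treat it as a sketch: the function names are mine, and I'm assuming the creation endpoint is the classifiers URL used for the status request in this post, without a classifier id on the end. The two form fields mirror the -F arguments in the curl command.

```python
import json

# Assumed creation endpoint: the classifiers URL without a classifier id
CLASSIFIERS_URL = "https://gateway.watsonplatform.net/natural-language-classifier/api/v1/classifiers"


def build_training_fields(csv_bytes, name="TextClassificationTest", language="en"):
    """Build the two multipart form fields that the curl command sends with -F."""
    metadata = json.dumps({"language": language, "name": name})
    return {
        "training_data": ("train.csv", csv_bytes, "text/csv"),
        "training_metadata": (None, metadata),  # plain form field, no filename
    }


def create_classifier(csv_path, user, pwd):
    """POST the training file and return the parsed JSON response."""
    import requests  # imported here so the helper above works without it

    with open(csv_path, "rb") as f:
        fields = build_training_fields(f.read())
    r = requests.post(CLASSIFIERS_URL, auth=(user, pwd), files=fields)
    r.raise_for_status()
    return r.json()
```

Passing the dictionary through the files argument makes requests send a multipart form, which is what curl does with -F.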

Then back in Python you can check the status and classify a batch of texts. When you create a service in the Watson console you're given a username and password to use with the API. These credentials are for the API only and are not the ones you use to log in.

user = 'your project specific user id'
pwd = 'your project specific password'
classifier_id = 'your project specific classifier id'

def get_status():
    r = requests.get("https://gateway.watsonplatform.net/natural-language-classifier/api/v1/classifiers/" + classifier_id,
                     auth=(user, pwd))
    return r.json()

def ask_watson(texts):
    payload = {'collection' : [{'text' : t} for t in texts]}
    r = requests.post('https://gateway.watsonplatform.net/natural-language-classifier/api/v1/classifiers/' + classifier_id + '/classify_collection',
                      auth=(user, pwd),
                      headers={'Content-Type': 'application/json'},
                      data=json.dumps(payload))
    status = r.status_code
    # Returns (True, parsed JSON) on success, (False, error message) otherwise
    return (status == 200, r.json() if status == 200 else f'Request status {status}: {r.text}')

For this example, we need predictions for the test set so we request them all in batches of 30.

batch_size = 30
responses = []

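for i in range(0,len(X_test),batch_size):
    ok, response = ask_watson(X_test[i:i+batch_size])
    if ok:
        # Each batch response is assumed to hold one result per text under 'collection'
        responses.extend(response['collection'])
    else:
        print('Unexpected status',response)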

preds = [int(r['top_class']) for r in responses]

print('Accuracy: ',np.mean(preds == y_test))
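A single accuracy number can hide a lot, especially with vague or off-topic posts in the mix, so it can be worth breaking the result down per newsgroup. Here is a small standard-library sketch. The per_class_accuracy helper is my own addition and the labels in the demo call are toy values, not results from the experiment.

```python
from collections import Counter


def per_class_accuracy(y_true, y_pred):
    """Map each true class to the fraction of its samples predicted correctly."""
    totals = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {label: correct[label] / totals[label] for label in sorted(totals)}


# Toy labels just to show the shape of the output
print(per_class_accuracy([0, 0, 1, 1, 1, 2], [0, 1, 1, 1, 0, 2]))
```

For the real data the call would be per_class_accuracy(y_test, preds), which shows at a glance which groups the classifier struggles with.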

And that's it! This super simple approach gives us 80.8% accuracy on our baseline example without any tweaking from us.


An accuracy of 80.8% for so little effort is impressive as a baseline. A custom approach should be able to improve on this quite a bit, but you'd need to invest in building it, deploying it and keeping it running in production. Whether the ROI is there depends on your particular business use case.

I hope you found this useful and that it sparked some ideas for tasks you can intelligently automate with text classification, such as routing and organizing support requests and emails. I hope to describe some other approaches and results in future articles.

Let me know if you have any questions or comments.


