Build a multi-label documentation classification model for SHARE using one vs the rest classifiers


Jiankun Liu

A recent post talked about how we can label documents on SHARE with Natural Language Processing models. In this post I’m going to include more detail on how it was done. If you are interested in reading further, I recommend you read the previous post (link) first, which introduced the problem, the data set, and the workflow.

If you’d like some background on general machine learning text classification techniques, I highly recommend this awesome tutorial done by scikit-learn.

Our goal is to predict the subject area of documents as accurately as possible. In this blog post we are going to explain how to achieve this subject classification using machine learning models, how to build a framework that is scalable to large data, and the fun we can have using deep learning with TensorFlow.

Data and preprocessing

The data set I used is from the PLOS API, which I introduced in the previous post (link). I reused the code from here with a few changes to harvest all of the data from PLOS with subject areas. Although the taxonomy has a tree structure, the API does not provide a full list of terms for each document. The only way to get the labels for documents is to specify the subject area in each query, which makes it very inconvenient to fetch all possible terms for each document. Therefore I decided to start with the top-tier subject areas.

The basic processing steps:
– Harvest documents for each subject area
– Clean the text data using the code here
– Store both the raw documents and preprocessed documents into MongoDB
– Remove outliers, including documents with not enough words for accurate classification

Since this tutorial focuses on the training step, I won’t go over the code for the entire preprocessing step.

First, we map each of the 11 subject areas to a number from 0 to 10. Because each document can have multiple subject areas, those numbers will later be used to create binary classifications in the training step. In other words, we’ll create 11 different classifications for each document, one for each of the subject areas, and specify whether or not that document maps to that subject area.

  • Biology and life sciences
  • Computer and information sciences
  • Earth sciences
  • Ecology and environmental sciences
  • Engineering and technology
  • Medicine and health sciences
  • People and places
  • Physical sciences
  • Research and analysis methods
  • Science policy
  • Social sciences
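As a minimal sketch of this mapping, the snippet below numbers the subject areas in the order listed above and turns one document's subjects into 11 binary flags. The ordering, the helper name `to_binary_labels`, and the example document are assumptions made for illustration; the post does not say which number each subject area received.

```python
# Subject areas in the order listed above; the list index is the class number.
# The ordering is an assumption made for this sketch.
SUBJECT_AREAS = [
    "Biology and life sciences",
    "Computer and information sciences",
    "Earth sciences",
    "Ecology and environmental sciences",
    "Engineering and technology",
    "Medicine and health sciences",
    "People and places",
    "Physical sciences",
    "Research and analysis methods",
    "Science policy",
    "Social sciences",
]
SUBJECT_TO_ID = {name: i for i, name in enumerate(SUBJECT_AREAS)}


def to_binary_labels(subjects):
    """Return 11 zeros and ones, one flag per subject area (hypothetical helper)."""
    ids = {SUBJECT_TO_ID[s] for s in subjects}
    return [1 if i in ids else 0 for i in range(len(SUBJECT_AREAS))]


# Hypothetical document tagged with two subject areas.
doc_subjects = ["Computer and information sciences", "Physical sciences"]
binary = to_binary_labels(doc_subjects)
print(binary)  # [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0]
```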

Fetching data

Now, we’ll go to the source – PLOS – and gather training data.

In the below examples, we’ll assume that we have a Mongo database pre-loaded with all of the data from PLOS. The code for gathering this data is available here.

We’ll use x to represent explanatory variables, and y to represent response variables.

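# Initialization.
from pymongo import MongoClient
from sklearn.model_selection import train_test_split

# settings and MongoProcessor come from the project's own framework code.

# Default values
STOP_WORDS = 'english'
NGRAM_RANGE = (1, 2)
TEST_SIZE = 0.2

# Setup
client = MongoClient(settings.MONGO_URI)

mp = MongoProcessor()

ids = mp.fetch_ids()
train_ids, test_ids = train_test_split(ids, test_size=TEST_SIZE)
# batch_xy yields batches; with batch_size set to the full length, the first batch holds everything.
(X_train_text, y_train_all) = next(mp.batch_xy(train_ids, batch_size=len(train_ids), epochs=1, collection='input'))
(X_test_text, y_test_all) = next(mp.batch_xy(test_ids, batch_size=len(test_ids), epochs=1, collection='input'))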

This code fetches all the document identifiers, splits them into a training group and a test group, then uses the identifiers to fetch the document text from MongoDB. The code is specific to the framework we are using, so it can’t be reused as is. However, the format of the data is just simple text, which is what we need for the text classification problem.

Here is an example of one document:

u”structural controllability and controlling centrality of temporal networks temporal networks are such networks where nodes and interactions may appear and disappear at various time scales with the evidence of ubiquity of temporal networks in our economy nature and society it urgent and significant to focus on its structural controllability as well as the corresponding characteristics which nowadays is still an untouched topic we develop graphic tools to study the structural controllability as well as its characteristics identifying the intrinsic mechanism of the ability of individuals in controlling a dynamic and large scale temporal network classifying temporal trees of a temporal network into different types we give both upper and lower analytical bounds of the controlling centrality which are verified by numerical simulations of both artificial and empirical temporal networks we find that the positive relationship between aggregated degree and controlling centrality as well as the scale free distribution of node controlling centrality are virtually independent of the time scale and types of datasets meaning the inherent robustness and heterogeneity of the controlling centrality of nodes within temporal networks”

Feature Extraction

We will use a count vectorizer in scikit-learn to extract features, that is, a detailed count of the words and word combinations that appear in each document. Scikit-learn has a good tutorial of the common text feature extraction process here.

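from sklearn.feature_extraction.text import CountVectorizer

# Count single words and two-word terms, skipping English stop words.
vectorizer = CountVectorizer(decode_error='ignore', ngram_range=NGRAM_RANGE, stop_words=STOP_WORDS)
X_train = vectorizer.fit_transform(X_train_text)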

X_train holds the vector representations of the documents. Each column is a term, which can be either a single word or a two-word term, and each row records how often each term occurs in one document. These are the features we will use in the classification model.
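To make the shape of this matrix concrete, here is a small sketch on two made-up sentences, using the same n-gram range and stop-word setting as above. The toy documents are assumptions for illustration, not PLOS data.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Two toy documents (made up for this example).
docs = [
    "temporal networks control temporal networks",
    "the control of networks",
]

toy_vectorizer = CountVectorizer(ngram_range=(1, 2), stop_words="english")
X = toy_vectorizer.fit_transform(docs)

# vocabulary_ maps each term (a word or a two-word term) to its column index.
vocab = toy_vectorizer.vocabulary_
row0 = X[0].toarray().ravel()

print(sorted(vocab))
print(X.shape)                          # (number of documents, number of terms)
print(row0[vocab["temporal networks"]])  # how often the two-word term occurs in document 0
```

Stop words such as "the" and "of" are dropped before the two-word terms are formed, which is why the second document produces the term "control networks".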


This is a multi-label classification problem. This means that each document can be classified into one or more categories. Most existing algorithms by default cannot deal with this problem. There are normally two solutions. Your first option is to change the loss function to adapt to the multi-label classification problem.

Second, you can train the model for each class separately by building one-vs-the-rest classifiers. The advantage of the second option is that each classifier can be tuned separately to get the best results. One of the most common problems with document classification is imbalanced data. For example, let’s say that more than 95 percent of the documents have the label “biology and life sciences,” while less than 3 percent of the documents have the label “science policy.” This results in a biased model that would predict almost all documents as “biology and life sciences” and none as “science policy,” which seems accurate at first glance, but in reality ignores important documents. By building separate binary classifiers, each one can be fine-tuned to tackle the imbalanced data problem.
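The imbalance problem can be illustrated with a few lines of code. This sketch uses made-up labels (3 positives out of 100 documents, roughly the "science policy" situation described above) and a degenerate classifier that never predicts the rare label.

```python
from sklearn.metrics import accuracy_score, recall_score

# Hypothetical one-vs-rest labels for a rare subject: 3 positives in 100 documents.
y_true = [1] * 3 + [0] * 97
# A degenerate classifier that never predicts the rare label.
y_pred = [0] * 100

accuracy = accuracy_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)

print(accuracy)  # high, because the negatives dominate
print(recall)    # zero, because no positive document is ever found
```

Accuracy alone hides the problem, which is one reason to look at per-class precision and recall. If you want to counter the imbalance inside a single binary classifier, scikit-learn's SGDClassifier accepts a `class_weight` argument (for example `class_weight='balanced'`); whether that helps on this data set would need to be tested.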

The classifier we use from scikit-learn is the SGDClassifier, which supports several loss functions, including hinge loss, which is the loss function of SVM, and huber loss, which is for the regression model. We use hinge loss for demonstration.

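from sklearn.linear_model import SGDClassifier
cls_base_name = 'SGD'
cls_base = SGDClassifier
all_classes = range(11)  # class numbers 0 to 10, one per subject area
# Build a separate instance per class; [cls_base()] * n would repeat one shared object.
cls_list = [cls_base() for _ in all_classes]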

The above code created 11 classifiers, each for classifying one subject area.


Now, we’ll start using the data to train the model.

In the below example, we’ll transform each of the multi class variables to a binary response variable. In other words, we’ll loop through each of the subject areas, and decide if the document can be classified using that subject.

import numpy as np

def OVR_transformer(classes, pos):
    """Transform multi-labels to binary classes for OneVsRest classifiers."""
    return 1 if pos in classes else 0

for j in range(len(all_classes)):
    # Name used to identify this classifier's metrics.
    cls_name = cls_base_name + str(j)

    # Transform y: 1 if the document has subject j, otherwise 0.
    y_train = np.asarray([OVR_transformer(labels, j) for labels in y_train_all])
    cls_list[j].fit(X_train, y_train)

To build the one vs rest classifiers we need to transform the labels into 0 and 1 for each of the subject areas. If the label is 0, the document can’t be classified with the subject, and a label of 1 means that it can.

The OVR_transformer function transforms the labels into 0 and 1 based on which subject area we are considering as the true label.
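For comparison, scikit-learn also ships a ready-made wrapper for this pattern. The sketch below is not the code used in this post; it shows how the same one-vs-the-rest setup could be expressed with `MultiLabelBinarizer` and `OneVsRestClassifier` on a handful of made-up documents with three classes. The hand-written loop above gives more room to tune each classifier separately.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer

# Toy documents and their class numbers (made up for this example).
texts = [
    "gene protein cell",
    "network graph node",
    "gene network model",
    "policy funding research",
    "protein cell policy",
    "graph node funding",
]
labels = [[0], [1], [0, 1], [2], [0, 2], [1, 2]]

# Turn the label lists into a 0/1 matrix with one column per class.
mlb = MultiLabelBinarizer(classes=[0, 1, 2])
Y = mlb.fit_transform(labels)

X = CountVectorizer().fit_transform(texts)

# One binary SGD classifier (hinge loss) is fitted per column of Y.
ovr = OneVsRestClassifier(SGDClassifier(loss="hinge", random_state=0))
ovr.fit(X, Y)

pred = ovr.predict(X)
print(Y[2])                  # the third document carries classes 0 and 1
print(len(ovr.estimators_))  # one fitted classifier per class
```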


Now that we’ve trained the model, it’s time to test.

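from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

X_test = vectorizer.transform(X_test_text)
for j in range(len(all_classes)):
    cls_name = cls_base_name + str(j)
    # Test
    print("Test results")
    # Binary ground truth for subject j, built the same way as in training.
    y_test = np.asarray([OVR_transformer(labels, j) for labels in y_test_all])
    y_pred = cls_list[j].predict(X_test)
    # Per-class precision, recall, F-score and support (index 1 is the positive class).
    test_stats = precision_recall_fscore_support(y_test, y_pred, pos_label=1)
    accuracy = cls_list[j].score(X_test, y_test)
    print("Test metrics for {}: {}".format(cls_name, test_stats))
    print(confusion_matrix(y_test, y_pred))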

The function precision_recall_fscore_support gives you the metrics needed to evaluate the models. In our case, precision is the most important one. It is the ratio of true positives to all positive predictions, which include misclassifications (false positives). We want to avoid including wrong information in our metadata, while still tolerating some documents not being labelled at all. Another important metric is recall, or the ratio of true positives to all documents that actually carry the label. If this ratio is too low, the classifier is basically doing nothing. The tradeoff between precision and recall is always a challenge. In our case, we would like to maximize precision while keeping a reasonable recall.
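As a small worked example, precision is TP / (TP + FP) and recall is TP / (TP + FN), where TP, FP and FN are the counts of true positives, false positives and false negatives. The labels below are made up; the sketch computes both metrics by hand from the confusion matrix and checks them against scikit-learn.

```python
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

# Made-up ground truth and predictions for one binary classifier.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 0, 0, 0, 0, 0, 0, 0, 0, 1]

# For binary labels, ravel() returns the counts in the order tn, fp, fn, tp.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

precision = tp / (tp + fp)  # share of positive predictions that are correct
recall = tp / (tp + fn)     # share of actual positives that were found

# The same numbers from scikit-learn, restricted to the positive class.
p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")

print(precision, recall)  # 0.5 0.25
```

Here the classifier is right in half of its positive predictions but finds only one of the four documents that carry the label.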

In our test, the classifiers achieved 60 to 90 percent precision. However, the recall of some classifiers is extremely low due to the imbalanced data. This is acceptable for the simplified model, but there is clearly room for improvement.

How to improve the model

The above code is the most simplified version of a multi-label classification model. The biggest problem with the model is imbalanced data, which we explained above. To tackle the problem, we can consider training the classifiers separately using the undersampling strategy, or other methods.
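One way to sketch the undersampling idea is shown below: for a single binary classifier, keep every example of the smaller class and a random sample of equal size from the larger class. The function name `undersample_majority` and the toy data are assumptions for illustration; this was not part of the original experiment.

```python
import numpy as np


def undersample_majority(X, y, seed=0):
    """Balance a binary data set by randomly dropping majority-class rows.

    Hypothetical helper: X can be a NumPy array or a SciPy CSR matrix, y holds 0/1 labels.
    """
    y = np.asarray(y)
    rng = np.random.RandomState(seed)
    pos = np.flatnonzero(y == 1)
    neg = np.flatnonzero(y == 0)
    # Decide which class is the minority and which is the majority.
    minority, majority = (pos, neg) if len(pos) <= len(neg) else (neg, pos)
    # Sample as many majority rows as there are minority rows, without replacement.
    kept = rng.choice(majority, size=len(minority), replace=False)
    idx = np.sort(np.concatenate([minority, kept]))
    return X[idx], y[idx]


# Toy data: 5 positives and 95 negatives.
X_toy = np.arange(100).reshape(100, 1)
y_toy = [1] * 5 + [0] * 95

X_bal, y_bal = undersample_majority(X_toy, y_toy)
print(X_bal.shape, int(y_bal.sum()))  # (10, 1) 5
```

Each of the 11 classifiers would get its own balanced sample, since the minority class differs from subject to subject.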

In terms of feature engineering, adding Term Frequency – Inverse Document Frequency (TF-IDF) is the first obvious choice. TF-IDF can tone down the influence of high-frequency words that don’t help with the classification, like “to,” “and,” “the,” etc. It is a must-have improvement unless you are dealing with very large data. Another method worth trying is Word2Vec, which transforms each word into a high-dimensional vector. You can then take the mean of all the word vectors in a document and use it as a feature together with TF-IDF. In fact, many papers nowadays use the combination of TF-IDF and Word2Vec with an SVM model as the baseline.
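The following sketch shows both ideas on toy data: TF-IDF weighting with scikit-learn's `TfidfVectorizer`, and a document feature built from the mean of word vectors. The three sentences and the 3-dimensional `toy_vectors` dictionary are made up and only stand in for a trained Word2Vec model; `mean_vector` is a hypothetical helper.

```python
import numpy as np
from scipy.sparse import csr_matrix, hstack
from sklearn.feature_extraction.text import TfidfVectorizer

docs = [
    "the network controls nodes",
    "the gene regulates proteins",
    "the policy funds research",
]

# No stop-word list here, so the down-weighting of "the" is visible.
tfidf = TfidfVectorizer()
X_tfidf = tfidf.fit_transform(docs)

vocab = tfidf.vocabulary_
row0 = X_tfidf[0].toarray().ravel()
weight_the = row0[vocab["the"]]          # appears in every document
weight_network = row0[vocab["network"]]  # appears in one document only

# Toy 3-dimensional word vectors standing in for a trained Word2Vec model.
toy_vectors = {
    "network": np.array([1.0, 0.0, 0.0]),
    "nodes": np.array([0.0, 1.0, 0.0]),
    "gene": np.array([0.0, 0.0, 1.0]),
}


def mean_vector(text, vectors, dim=3):
    """Average the vectors of the known words; zeros if none are known."""
    found = [vectors[w] for w in text.split() if w in vectors]
    return np.mean(found, axis=0) if found else np.zeros(dim)


X_w2v = np.vstack([mean_vector(d, toy_vectors) for d in docs])

# Put the TF-IDF columns and the mean-vector columns side by side.
X_combined = hstack([X_tfidf, csr_matrix(X_w2v)]).tocsr()
print(weight_the < weight_network, X_combined.shape)
```

Both words occur once in the first document, yet "the" ends up with the lower weight because it occurs in every document.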

The framework

Here are some lessons I learned while working on this problem.

  • Be as generalized as possible to use in other cases
  • Make it scalable
  • Make sure your code is lightweight and simple

The training data we are dealing with is nowhere near “big data” on the TB or PB level, but storing it as flat files would definitely not help us move on. Therefore, we used MongoDB instead of flat files to store the data.
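To illustrate what a scalable training loop might look like, here is a sketch that reads documents in batches and updates one binary classifier incrementally. It is an assumption-heavy illustration rather than the framework's code: `iter_batches` and `fake_store` are hypothetical (an in-memory dictionary stands in for the database), and `HashingVectorizer` with `SGDClassifier.partial_fit` is swapped in because neither needs to see the whole data set at once.

```python
import numpy as np
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import SGDClassifier


def iter_batches(ids, fetch, batch_size):
    """Yield (texts, labels) for consecutive slices of ids (hypothetical helper)."""
    for start in range(0, len(ids), batch_size):
        chunk = ids[start:start + batch_size]
        texts, labels = zip(*(fetch(i) for i in chunk))
        yield list(texts), np.asarray(labels)


# In-memory stand-in for a database collection: id -> (text, binary label).
fake_store = {
    i: ("gene protein cell" if i % 2 else "network graph node", i % 2)
    for i in range(20)
}

# HashingVectorizer keeps no vocabulary, so it can transform one batch at a time.
hasher = HashingVectorizer(n_features=2 ** 10)
clf = SGDClassifier(random_state=0)

n_batches = 0
for texts, y in iter_batches(list(fake_store), fake_store.get, batch_size=5):
    # partial_fit needs the full list of classes, since a batch may miss one.
    clf.partial_fit(hasher.transform(texts), y, classes=np.array([0, 1]))
    n_batches += 1

print(n_batches, clf.coef_.shape)
```

With a real database, `fetch` would be replaced by a query for one document, and the loop would be repeated once per subject area.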


Bonus: Text classification with Convolutional Neural Networks

Following the trend of deep learning in recent years, many papers have researched text classification with Convolutional Neural Networks, which focus on learning models that are similar in structure to neurons in the brain. I decided to give it a shot, but the performance was simply lacking. In case you are interested in how to do it, I suggest reading this awesome tutorial first, then reading the code of the prototype here.

In the tutorial, the author trained the word vectors starting from scratch. In our case, we used pre-trained word vectors from the GoogleNews corpus, which you can find online. Therefore, we need to change the code in the embedding layer to use them properly. You can find the code for this part here.
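The general idea behind that change can be sketched without any deep learning library: build a matrix with one row per vocabulary word, copy in the pre-trained vector where one exists, and fall back to small random values otherwise. The function `build_embedding_matrix`, the 4-dimensional `toy_pretrained` vectors, and the vocabulary below are all assumptions for illustration and are not the prototype's code.

```python
import numpy as np


def build_embedding_matrix(vocab, pretrained, dim, seed=0):
    """Row i holds the vector for the word whose index in vocab is i (hypothetical helper)."""
    rng = np.random.RandomState(seed)
    # Words without a pre-trained vector keep a small random initialization.
    matrix = rng.uniform(-0.25, 0.25, size=(len(vocab), dim)).astype(np.float32)
    for word, idx in vocab.items():
        if word in pretrained:
            matrix[idx] = pretrained[word]
    return matrix


# Toy stand-in for pre-trained vectors (real ones have far more dimensions).
toy_pretrained = {
    "network": np.array([0.1, 0.2, 0.3, 0.4]),
    "gene": np.array([0.5, 0.6, 0.7, 0.8]),
}
# Word -> row index, as a text preprocessing step would produce.
toy_vocab = {"<pad>": 0, "network": 1, "gene": 2, "controllability": 3}

embedding_matrix = build_embedding_matrix(toy_vocab, toy_pretrained, dim=4)
print(embedding_matrix.shape)  # (4, 4)
```

A matrix like this would then serve as the initial value of the embedding layer's weights, in place of the random initialization used when training word vectors from scratch.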

The conclusion, to put it simply, is that it costs you a significant amount of time and resources in exchange for little or no gain in precision. As fancy as Convolutional Neural Networks sound, you might want to try traditional machine learning techniques before jumping into more complicated and hyped methods.

All my code is available online at

Thanks to Erin Braswell for the editing work!
