
Dealing with class imbalance - Part 1

Learning with unbalanced datasets

Gustavo Cid

6 minute read · March 22, 2022, 12:01 AM

Imagine you were reading recent studies about a rare disease and, suddenly, you spot a pattern! You notice that the disease seems to be associated with markers detected on common blood tests and you figure that if people were automatically alerted early enough, they could seek treatment and avoid most of the severe symptoms.

Equipped with your machine learning (ML) knowledge, you partner with a local hospital to collect data and train a model. You evaluate your model's performance on a validation set and see an accuracy of 99.9% blinking in your terminal window. You did it!

Not so fast...

When working with ML, practitioners are often in a curious spot. On the one hand, most learning algorithms and evaluation metrics implicitly assume balanced datasets, i.e., datasets where every class has roughly the same number of examples. On the other hand, the world is not always balanced and, more often than not, the rarest classes are the ones we care about most.

In the example above, the binary classifier has to identify patients with a rare disease. If the validation set had 1000 samples, 999 of which came from patients without the disease, a model that always predicts that an individual is disease-free would also reach 99.9% accuracy. How useful would such a model be?
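The arithmetic behind that 99.9% is easy to reproduce. The sketch below uses hypothetical labels matching the example (one sick patient out of 1000) and a "model" that always predicts disease-free; alongside accuracy, it reports the fraction of sick patients the model actually catches.

```python
import numpy as np

# Hypothetical validation labels: 1 = has the disease, 0 = disease-free.
y_true = np.zeros(1000, dtype=int)
y_true[0] = 1  # a single positive case among 1000 samples

# A "model" that ignores its input and always predicts disease-free.
y_pred = np.zeros_like(y_true)

accuracy = (y_pred == y_true).mean()
# Fraction of the sick patients that the model flags as sick.
recall = y_pred[y_true == 1].mean()

print(f"accuracy = {accuracy:.3f}")  # accuracy = 0.999
print(f"recall   = {recall:.3f}")    # recall   = 0.000
```

The model looks excellent by accuracy while missing every single patient it was built to find.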

Such scenarios are common in ML. A disease-detection model must cope with mostly healthy samples. A credit card fraud classifier sees many more normal transactions than fraudulent ones. An e-mail spam classifier must correctly separate spam from legitimate mail even though the two are far from evenly split: by some estimates, around 85% of all e-mails are spam.

This blog post is the first of a two-part series on class imbalance. In the series, we will explain why dealing with imbalanced data is challenging and present some ways to overcome it. Today, we start almost at the beginning of the ML pipeline: training models on an unbalanced dataset.



Learning algorithms and unbalanced datasets

In the context of supervised learning, learning often means finding the model’s parameters that minimize a loss function on a training set, which mathematically translates to:

$$w^* = \arg\min_{w} \sum_{i=1}^{n} l(x_i, y_i; w),$$

where the training dataset is given by input-output pairs {(x_1, y_1), ..., (x_n, y_n)} and l is the loss function, which measures how well the model with parameters w is doing for the sample (x_i, y_i).

This approach generally works well and is based on empirical risk minimization (ERM), which rests on the idea that minimizing a loss function evaluated on the training set (called the empirical risk) will lead us closer to what we truly want: a good model that generalizes well to new data.
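To make the argmin concrete, here is a minimal sketch of ERM, assuming a logistic-regression model and the log loss as l (both are illustrative choices, not something the post prescribes). The function `empirical_risk` is exactly the sum in the formula above, and a few full-batch gradient steps drive it down on hypothetical toy data.

```python
import numpy as np

def per_sample_loss(X, y, w):
    """Log loss l(x_i, y_i; w) for every row of X, computed stably as
    log(1 + exp(z)) - y * z with z = x_i . w."""
    z = X @ w
    return np.logaddexp(0.0, z) - y * z

def empirical_risk(X, y, w):
    """The quantity inside the argmin: the sum of per-sample losses."""
    return per_sample_loss(X, y, w).sum()

# Hypothetical toy data: 200 points in 2-D, labeled by which side of a line they fall on.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(float)

w = np.zeros(2)
risk_init = empirical_risk(X, y, w)
for _ in range(500):
    p = 1.0 / (1.0 + np.exp(-(X @ w)))  # predicted P(y = 1 | x)
    w -= 0.01 * X.T @ (p - y)           # gradient step on the summed log loss
risk_trained = empirical_risk(X, y, w)

print(f"empirical risk before: {risk_init:.2f}, after: {risk_trained:.2f}")
```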

Now, let’s think about what happens when we have to work with an unbalanced dataset.

When a model learns from an unbalanced dataset, the samples from the minority class contribute little to the quantity we are trying to minimize. This is simply because the training set contains few (x_i, y_i) examples from the minority class, so, looking at the summation above, their combined contribution to the empirical risk is much smaller than that of the majority-class samples.
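A quick way to see this imbalance in the objective is to measure what share of the summed loss comes from the minority class. The sketch assumes a hypothetical 990/10 split and an untrained model that outputs 0.5 for every sample, so each sample contributes the same log loss, log(2); the minority share of the empirical risk then simply equals its share of the data.

```python
import numpy as np

n_majority, n_minority = 990, 10
y = np.array([0] * n_majority + [1] * n_minority)

# Assumption: an untrained model that predicts p = 0.5 for every sample,
# so every sample contributes the same log loss, log(2).
p = np.full(y.shape, 0.5)
losses = -(y * np.log(p) + (1 - y) * np.log(1 - p))

minority_share = losses[y == 1].sum() / losses.sum()
print(f"minority share of empirical risk: {minority_share:.1%}")  # 1.0%
```

Only 1% of the optimization signal comes from the class we may care about most.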

The problem is that the less these samples contribute to the empirical risk, the more likely we are to end up with a model that is biased towards the majority class and prone to errors on the minority class, simply because those samples barely affect our optimization objective.

To make matters worse, in practice it is common to use learning algorithms such as stochastic gradient descent, which approximate the empirical risk by sampling individual points or mini-batches from the training set at random instead of summing over the whole dataset. If the dataset is heavily unbalanced, how likely are these samples to come from the minority class? Not very likely, which results in models that are not prepared to handle the minority classes.
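How unlikely? The sketch below assumes the 1-in-1000 prevalence from the disease example and a hypothetical mini-batch size of 32, with samples drawn uniformly at random with replacement. It computes the probability that a batch contains no minority example at all, then checks it with a quick simulation.

```python
import numpy as np

minority_fraction = 0.001  # 1 minority sample per 1000, as in the disease example
batch_size = 32            # hypothetical mini-batch size

# Exact probability that a uniformly sampled batch has zero minority examples.
p_no_minority = (1 - minority_fraction) ** batch_size
print(f"P(batch has no minority sample) = {p_no_minority:.3f}")  # 0.968

# Monte Carlo check: number of minority samples in each of 100,000 random batches.
rng = np.random.default_rng(0)
minority_counts = rng.binomial(batch_size, minority_fraction, size=100_000)
simulated = (minority_counts == 0).mean()
print(f"simulated: {simulated:.3f}")
```

Roughly 97% of gradient steps would then be computed without seeing a single minority example.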

But are models always doomed to fail in such cases?

Fortunately, there are a few solutions that can greatly help models learn from unbalanced datasets.

Solution #1: Undersampling

If we have a lot of samples from one class and not so many from the others, why don’t we throw away some of the samples from the majority class to balance out our dataset? This is the first, and most obvious, solution to handle unbalanced datasets, called undersampling.
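Random undersampling fits in a few lines of NumPy. The sketch below, with a hypothetical function name and a hypothetical 95/5 dataset, keeps a random subset of every class equal in size to the smallest class. The imbalanced-learn library offers a ready-made `RandomUnderSampler` for the same job.

```python
import numpy as np

def random_undersample(X, y, seed=0):
    """Randomly drop samples from the larger classes until every class
    has as many samples as the smallest one."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    n_min = counts.min()
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=n_min, replace=False)
        for c in classes
    ])
    rng.shuffle(keep)  # avoid returning the data sorted by class
    return X[keep], y[keep]

# Usage on a hypothetical 95/5 dataset with 3 features.
rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 3))
y = np.array([0] * 950 + [1] * 50)

X_bal, y_bal = random_undersample(X, y)
print(np.bincount(y_bal))  # [50 50]
```

Note how much data is discarded here: 900 of the 1000 samples, which is exactly the cost discussed below.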

Even though a lot of practitioners may feel wary after reading this, undersampling is indeed a possible solution. Sometimes, throwing away samples might help boost the model’s performance, particularly if the data is somewhat noisy. Remember, good data is much better than big data.

Undersampling also comes with other pros. It is quite simple to implement and models are potentially trained faster since they have to go through a smaller dataset.

There is, however, a caveat that fully justifies ML practitioners' wariness toward undersampling. Obtaining data is costly, and throwing samples away means wasting these precious resources. This is the biggest con of undersampling: the risk of discarding valuable data.

Therefore, before you start randomly deleting samples from your dataset, think critically about the problem that you have and if undersampling is a good option.

Solution #2: Oversampling

In undersampling, we throw away samples from the majority class. Now, in oversampling, we create more samples from the minority classes to balance out our training set. This is where it is possible to leverage the power of synthetic data.
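The simplest form of oversampling does not create new data at all: it duplicates existing minority samples, drawn with replacement, until the classes are balanced. The sketch below (hypothetical function name and data) mirrors the undersampling example; imbalanced-learn's `RandomOverSampler` implements the same idea.

```python
import numpy as np

def random_oversample(X, y, seed=0):
    """Resample every smaller class with replacement until all classes
    match the size of the largest one."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    n_max = counts.max()
    idx = []
    for c, count in zip(classes, counts):
        c_idx = np.flatnonzero(y == c)
        extra = rng.choice(c_idx, size=n_max - count, replace=True)  # duplicates
        idx.append(np.concatenate([c_idx, extra]))
    idx = np.concatenate(idx)
    rng.shuffle(idx)
    return X[idx], y[idx]

# Usage on the same hypothetical 95/5 dataset.
rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 3))
y = np.array([0] * 950 + [1] * 50)

X_bal, y_bal = random_oversample(X, y)
print(np.bincount(y_bal))  # [950 950]
```

Exact duplicates add no new information, which is why synthetic variations of the minority samples are often preferred.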

Generating synthetic data to augment underrepresented portions of your training data is a great way to increase your model’s robustness, ensure important invariances, and further explore specific model failure modes.

There are various ways of generating synthetic data, depending on the data type of interest. It can be as simple as adding a little bit of noise to the data that you have or as complex as using generative adversarial networks (GANs).
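At the simple end of that spectrum, a minimal sketch for tabular data copies random minority rows and adds small Gaussian noise, scaled per feature by that feature's standard deviation. The function name, the `noise_scale` value of 0.05, and the data are all hypothetical choices for illustration.

```python
import numpy as np

def jitter_oversample(X_minority, n_new, noise_scale=0.05, seed=0):
    """Create n_new synthetic rows by copying random minority rows and
    adding Gaussian noise scaled per feature by that feature's std."""
    rng = np.random.default_rng(seed)
    base = X_minority[rng.integers(0, len(X_minority), size=n_new)]
    feature_std = X_minority.std(axis=0)
    noise = rng.normal(0.0, 1.0, size=base.shape) * noise_scale * feature_std
    return base + noise

# Hypothetical 95/5 dataset with 3 numeric features.
rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 3))
y = np.array([0] * 950 + [1] * 50)

X_min = X[y == 1]
X_syn = jitter_oversample(X_min, n_new=900)
X_aug = np.vstack([X, X_syn])
y_aug = np.concatenate([y, np.ones(len(X_syn), dtype=int)])
print(np.bincount(y_aug))  # [950 950]
```

The noise scale is a trade-off: too small and the new rows are near-duplicates, too large and they may no longer look like the minority class.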

Original data samples can also be perturbed to augment the dataset. In natural language processing (NLP), this can be done, for instance, by introducing small typos (which encourages robustness to typos) or by replacing word tokens with synonyms. In computer vision applications, this can be done by adding noise to images or changing their orientation, among many other techniques. Each area has its own idiosyncratic methods of data perturbation for augmentation.
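The two NLP perturbations mentioned above can be sketched with the standard library alone. The typo function swaps two adjacent characters inside one random word; the synonym function uses a tiny, hand-written, hypothetical `SYNONYMS` table, where a real pipeline would draw on a proper thesaurus or lexical database.

```python
import random

# Hypothetical synonym table, for illustration only.
SYNONYMS = {
    "quick": ["fast", "rapid"],
    "happy": ["glad", "cheerful"],
    "big": ["large", "huge"],
}

def add_typo(text, rng):
    """Swap two adjacent characters inside one randomly chosen word."""
    words = text.split()
    candidates = [i for i, w in enumerate(words) if len(w) >= 2]
    if not candidates:
        return text
    i = rng.choice(candidates)
    w = words[i]
    j = rng.randrange(len(w) - 1)
    words[i] = w[:j] + w[j + 1] + w[j] + w[j + 2:]
    return " ".join(words)

def replace_synonyms(text, rng, p=0.5):
    """Replace each word found in SYNONYMS by a random synonym with probability p."""
    out = []
    for w in text.split():
        if w.lower() in SYNONYMS and rng.random() < p:
            out.append(rng.choice(SYNONYMS[w.lower()]))
        else:
            out.append(w)
    return " ".join(out)

rng = random.Random(0)
sentence = "the quick fox was happy"
typo_sentence = add_typo(sentence, rng)
syn_sentence = replace_synonyms(sentence, rng, p=1.0)
print(typo_sentence)
print(syn_sentence)
```

Both functions keep the label of the original sentence, which is what makes them usable for oversampling a minority class.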

With Unbox, there are multiple ways to generate synthetic data to augment your training set with just a few clicks. If this is something you are interested in learning more about, feel free to check out our white paper and one of our recent blog posts, where we talk about synthetic data in the context of ML model testing.

Solution #3: Cost-sensitive loss

The two approaches presented above mitigate the problem of class imbalance by modifying the data itself. However, this is not the only category of possible approaches.

Another way to address this issue is to modify the loss function so that mistakes on the minority class are punished more severely than mistakes on the majority class. By doing so, we push the model to focus more closely on learning to predict these samples.

A first implementation of this approach is to assign each class a weight inversely proportional to its number of samples. Letting n be the total number of training samples in our dataset, we might use weights such as

$$c_i = \frac{n}{\text{number of samples from class } i \text{ in the training set}}$$
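In a cost-sensitive version of the objective, each sample's loss l(x_i, y_i; w) is multiplied by the weight c_{y_i} of its class before summing. The sketch below (hypothetical function names, log loss assumed) computes these inverse-frequency weights and revisits the 990/10 example with a constant 0.5 prediction: with the weights applied, the minority class now carries half of the objective instead of 1%. For reference, scikit-learn's `class_weight="balanced"` uses n / (k · n_i) for k classes, which differs from c_i only by the constant factor k.

```python
import numpy as np

def inverse_frequency_weights(y):
    """c_i = n / (number of samples of class i), for integer labels 0..k-1."""
    counts = np.bincount(y)
    return len(y) / counts

def weighted_log_loss(y, p, class_weights, eps=1e-12):
    """Summed binary log loss with each sample's term scaled by c_{y_i}."""
    p = np.clip(p, eps, 1 - eps)
    per_sample = -(y * np.log(p) + (1 - y) * np.log(1 - p))
    return np.sum(class_weights[y] * per_sample)

y = np.array([0] * 990 + [1] * 10)
c = inverse_frequency_weights(y)  # c[0] = 1000/990 ≈ 1.0101, c[1] = 1000/10 = 100.0

# Same assumption as before: a model predicting p = 0.5 everywhere.
p = np.full(y.shape, 0.5)
minority = y == 1
share = weighted_log_loss(y[minority], p[minority], c) / weighted_log_loss(y, p, c)
print(f"minority share of weighted risk: {share:.1%}")  # 50.0%
```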

Another approach is to use focal loss, a technique introduced by Facebook AI researchers that handles class imbalance cleverly. The idea is that models generally learn to classify some instances much faster and more easily than others, and in unbalanced datasets these easy instances are likely to come from the majority class. Focal loss therefore pushes the model to progressively focus on the examples it still has trouble classifying: the loss on each sample is multiplied by a factor (1 - p_t)^γ, where p_t is the probability the model assigns to the correct class and γ ≥ 0 is a tunable focusing parameter. The more confident the model is of being right, the smaller this weight becomes.
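A minimal NumPy sketch of the binary focal loss, FL(p_t) = -(1 - p_t)^γ log(p_t), is shown below next to plain cross-entropy. It uses γ = 2, the value the original paper reports working well, and leaves out the α class-balancing factor the paper also includes. Dividing the two losses recovers the modulating factor: an easy sample (p_t = 0.95) keeps only 0.25% of its cross-entropy, while a hard one (p_t = 0.30) keeps 49%. In PyTorch projects, `torchvision.ops.sigmoid_focal_loss` provides a ready-made version.

```python
import numpy as np

def binary_cross_entropy(y, p, eps=1e-12):
    """Per-sample -log(p_t), where p_t is the probability of the true class."""
    p = np.clip(p, eps, 1 - eps)
    p_t = np.where(y == 1, p, 1 - p)
    return -np.log(p_t)

def focal_loss(y, p, gamma=2.0, eps=1e-12):
    """Per-sample focal loss -(1 - p_t)**gamma * log(p_t)."""
    p = np.clip(p, eps, 1 - eps)
    p_t = np.where(y == 1, p, 1 - p)
    return -((1 - p_t) ** gamma) * np.log(p_t)

# An easy majority-class sample (p_t = 0.95) and a hard minority-class sample (p_t = 0.30).
y = np.array([0, 1])
p = np.array([0.05, 0.30])  # predicted P(y = 1)

ce = binary_cross_entropy(y, p)
fl = focal_loss(y, p)
ratio = fl / ce  # equals (1 - p_t)**gamma: about 0.0025 (easy) and 0.49 (hard)
print(ratio)
```

With γ = 0 the modulating factor is 1 and focal loss reduces to ordinary cross-entropy.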

The world is often unbalanced and, even though many ML models and learning algorithms are not designed with unbalanced datasets in mind, practitioners need to be aware of this fact and work around it accordingly. In this post, we explored how class imbalance affects the learning algorithms behind many models, along with a few solutions that can mitigate the problem. In part 2, we will focus on model evaluation in the presence of class imbalance.