Isaac's Blog

Dealing with Class Imbalance

2019/08/10

Class imbalance is a common problem in machine learning: one class has far fewer observations than another. It undermines both training and evaluation, because many learning algorithms optimize overall accuracy, which a model can achieve by mostly predicting the majority class, and many common evaluation metrics, accuracy in particular, implicitly assume a balanced class distribution. In this post, I share a few simple yet effective methods for handling imbalanced datasets in R.

Weighting

The most common method is to weight each observation by the inverse frequency of its class, so that misclassifying a rare-class example costs more than misclassifying a common one. Consider a training set named data_training whose binary target variable target takes the value Y (positive) or N (negative). The class weights can be computed as follows.

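# Count the observations in each class
class_counts <- table(data_training$target)

# Inverse-frequency weights; look the counts up by name because table()
# follows the factor level order (alphabetical by default, i.e. N before Y)
weights <- ifelse(data_training$target == "Y",
                  1 / class_counts[["Y"]],
                  1 / class_counts[["N"]])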
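As a quick check of what these weights do, the base-R sketch below builds a small simulated dataset (the data frame toy, its predictor x, and the 900/100 class split are made up for illustration), computes the same inverse-frequency weights, and passes them to a logistic regression. Each class ends up with a total weight of 1, so the minority class is no longer outweighed in the fit. caret's train() also has a weights argument, although only models that support case weights use it.

```r
set.seed(42)

# Hypothetical imbalanced data: 900 negatives ("N") and 100 positives ("Y")
toy <- data.frame(
  x = c(rnorm(900, mean = 0), rnorm(100, mean = 1.5)),
  target = factor(rep(c("N", "Y"), times = c(900, 100)))
)

# Inverse class frequency, looked up by class name rather than by position
class_counts <- table(toy$target)
w <- as.numeric(1 / class_counts[as.character(toy$target)])

# Each class now carries the same total weight (1 per class)
tapply(w, toy$target, sum)

# Weighted logistic regression; quasibinomial() avoids the warning that
# binomial() emits for non-integer weights
fit <- glm(target ~ x, data = toy, family = quasibinomial(), weights = w)
coef(fit)
```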

Resampling

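The simplest resampling methods are downsampling, which randomly discards majority-class rows until every class is as small as the minority class, and upsampling, which randomly duplicates minority-class rows (sampling with replacement) until they match the majority class. Both are built into the caret package.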

library(caret)

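We only need to set the sampling argument of the trainControl() object that we pass to train(). caret then subsamples the training part of each resampling iteration, so the held-out folds keep the original class distribution. For downsampling: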

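# 10-fold cross-validation (trainControl's default number of folds)
# repeated 5 times; twoClassSummary reports ROC AUC, sensitivity and
# specificity, and needs class probabilities to compute the AUC
ctrl <- trainControl(method = "repeatedcv",
                     repeats = 5,
                     classProbs = TRUE,
                     summaryFunction = twoClassSummary,
                     sampling = "down")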
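To make concrete what random downsampling does to one training set, here is a minimal base-R sketch. The helper down_sample_sketch and the toy data are hypothetical and only illustrate the idea; caret's own helper for doing this outside of train() is downSample().

```r
# Hypothetical helper: randomly keep only as many rows from every class as
# the smallest class has (random downsampling)
down_sample_sketch <- function(df, target) {
  y <- df[[target]]
  n_min <- min(table(y))
  keep <- unlist(lapply(split(seq_len(nrow(df)), y),
                        function(idx) idx[sample.int(length(idx), n_min)]))
  df[sort(keep), , drop = FALSE]
}

set.seed(1)
toy <- data.frame(x = rnorm(1000),
                  target = factor(rep(c("N", "Y"), times = c(900, 100))))
down <- down_sample_sketch(toy, "target")
table(down$target)   # N: 100, Y: 100
```

The price of downsampling is that 800 of the 900 majority rows are never seen by the model in this iteration, which is why it works best when the majority class is large.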

And for upsampling:

ctrl <- trainControl(method = "repeatedcv",
                     repeats = 5,
                     classProbs = TRUE,
                     summaryFunction = twoClassSummary,
                     sampling = "up")
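Upsampling can be sketched the same way in base R. The helper up_sample_sketch below is hypothetical (caret's own function is upSample()); it keeps every original row and adds random duplicates to each smaller class until all classes match the largest one.

```r
# Hypothetical helper: keep all rows and add duplicates drawn with
# replacement until every class matches the largest one (random upsampling)
up_sample_sketch <- function(df, target) {
  y <- df[[target]]
  n_max <- max(table(y))
  keep <- unlist(lapply(split(seq_len(nrow(df)), y), function(idx) {
    # Draw the shortfall for this class with replacement
    extra <- idx[sample.int(length(idx), n_max - length(idx), replace = TRUE)]
    c(idx, extra)
  }))
  df[sort(keep), , drop = FALSE]
}

set.seed(1)
toy <- data.frame(x = rnorm(1000),
                  target = factor(rep(c("N", "Y"), times = c(900, 100))))
up <- up_sample_sketch(toy, "target")
table(up$target)   # N: 900, Y: 900
```

Because the added rows are exact copies, upsampling keeps all the information but makes overfitting to the duplicated minority rows easier, which is the motivation for the synthetic methods that follow.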

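There are also hybrid methods, such as random over-sampling examples (ROSE) and the synthetic minority over-sampling technique (SMOTE), which generate new synthetic data points instead of only dropping or copying existing rows. ROSE draws a new, roughly balanced sample from smoothed (kernel density) estimates of both classes, while SMOTE creates synthetic minority examples by interpolating between neighboring minority points and, in the DMwR implementation used below, also downsamples the majority class. To use ROSE, we need to load the ROSE package.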

library(ROSE)

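Next, we wrap the ROSE() function in a list that trainControl() accepts as a custom sampling method: name labels the method, func receives the predictors x and the outcome y and must return the resampled data as list(x, y), and first = TRUE tells caret to subsample before any pre-processing. (caret also accepts sampling = "rose" directly; the wrapper lets us choose the arguments passed to ROSE().)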

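rosest <- list(name = "ROSE",
               func = function(x, y) {
                 dat <- if (is.data.frame(x)) x else as.data.frame(x)
                 dat$.y <- y
                 # Draw a synthetic, roughly balanced sample; hmult.majo and
                 # hmult.mino = 1 keep the default kernel bandwidths
                 dat <- ROSE(.y ~ ., data = dat, hmult.majo = 1, hmult.mino = 1)$data
                 # Split back into predictors and outcome; drop = FALSE keeps a
                 # data frame even when there is only one predictor
                 list(x = dat[, colnames(dat) != ".y", drop = FALSE],
                      y = dat$.y)
               },
               first = TRUE)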
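Because caret only requires func to accept x and y and return list(x = ..., y = ...), one way to sanity-check a wrapper before handing it to trainControl() is to call its func directly. The sketch below does this with a hypothetical wrapper, downst, that has the same shape as rosest but uses plain random downsampling, so it runs without the ROSE package; rosest$func can be checked the same way once ROSE is installed.

```r
# Hypothetical custom sampling list with the same shape as rosest, using plain
# random downsampling so the example runs without extra packages
downst <- list(name = "base-R downsampling",
               func = function(x, y) {
                 dat <- if (is.data.frame(x)) x else as.data.frame(x)
                 n_min <- min(table(y))
                 # Keep n_min randomly chosen row indices from every class
                 keep <- unlist(lapply(split(seq_along(y), y),
                                       function(idx) idx[sample.int(length(idx), n_min)]))
                 list(x = dat[keep, , drop = FALSE], y = y[keep])
               },
               first = TRUE)

set.seed(1)
x <- data.frame(a = rnorm(500), b = runif(500))
y <- factor(rep(c("N", "Y"), times = c(450, 50)))
out <- downst$func(x, y)
str(out)       # a list holding a data frame x and a factor y
table(out$y)   # N: 50, Y: 50
```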

We then pass the wrapper to the sampling argument of the control object.

ctrl <- trainControl(method = "repeatedcv",
                     repeats = 5,
                     classProbs = TRUE,
                     summaryFunction = twoClassSummary,
                     sampling = rosest)
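Unlike upsampling, ROSE does not copy existing rows; it draws synthetic ones from a smoothed bootstrap. Roughly: choose a class with probability 1/2, pick a random row of that class, and add Gaussian noise whose scale depends on the class size and on each feature's spread. The base-R sketch below illustrates that idea for numeric predictors only. The function rose_sketch, the toy data, and the rule-of-thumb bandwidth (4 / ((d + 2) * n))^(1 / (d + 4)) are assumptions for illustration, not the ROSE package's code.

```r
# Simplified sketch of ROSE's smoothed bootstrap (numeric predictors only).
# Assumptions: each synthetic row's class is drawn with probability 1/2, and
# the Gaussian kernel uses a normal-reference rule-of-thumb bandwidth.
rose_sketch <- function(X, y, n_out = nrow(X), hmult = 1) {
  X <- as.matrix(X)
  d <- ncol(X)
  # Class label of each synthetic row, 50/50 between the two classes
  y_new <- factor(sample(levels(y), n_out, replace = TRUE), levels = levels(y))
  X_new <- matrix(NA_real_, nrow = n_out, ncol = d,
                  dimnames = list(NULL, colnames(X)))
  for (cl in levels(y)) {
    rows_cl <- which(y == cl)      # original rows of this class
    out_cl <- which(y_new == cl)   # synthetic rows to generate for it
    if (length(out_cl) == 0) next
    n_cl <- length(rows_cl)
    # Bandwidth shrinks with class size and scales with each feature's spread
    h <- hmult * (4 / ((d + 2) * n_cl))^(1 / (d + 4))
    sds <- apply(X[rows_cl, , drop = FALSE], 2, sd)
    # Pick random seed rows from the class and jitter them with Gaussian noise
    seeds <- rows_cl[sample.int(n_cl, length(out_cl), replace = TRUE)]
    noise <- matrix(rnorm(length(out_cl) * d), ncol = d) %*% diag(h * sds, nrow = d)
    X_new[out_cl, ] <- X[seeds, , drop = FALSE] + noise
  }
  list(x = as.data.frame(X_new), y = y_new)
}

set.seed(7)
X_toy <- data.frame(a = c(rnorm(450), rnorm(50, mean = 2)),
                    b = c(rnorm(450), rnorm(50, mean = -2)))
y_toy <- factor(rep(c("N", "Y"), times = c(450, 50)))
res <- rose_sketch(X_toy, y_toy)
table(res$y)   # roughly 250 of each; exact counts depend on the seed
```

The output has the same number of rows as the input, so relative to the original data the majority class shrinks and the minority class grows, with every row synthetic.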

Similarly, to use SMOTE, we first load the DMwR package.

library(DMwR)

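Then we wrap the SMOTE() function in the same way. Here perc.over = 100 creates one synthetic minority example per original minority example, and k = 5 is the number of nearest neighbors used for the interpolation. SMOTE()'s perc.under argument, left at its default of 200, keeps two majority examples per synthetic one, so the resampled training set is balanced.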

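smotest <- list(name = "SMOTE",
                func = function(x, y) {
                  dat <- if (is.data.frame(x)) x else as.data.frame(x)
                  dat$.y <- y
                  # Oversample the minority class with synthetic points and
                  # undersample the majority class (perc.under defaults to 200)
                  dat <- SMOTE(.y ~ ., data = dat, perc.over = 100, k = 5)
                  # Split back into predictors and outcome; drop = FALSE keeps a
                  # data frame even when there is only one predictor
                  list(x = dat[, colnames(dat) != ".y", drop = FALSE],
                       y = dat$.y)
                },
                first = TRUE)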

Finally, we specify the resampling method in the control object.

ctrl <- trainControl(method = "repeatedcv",
                     repeats = 5,
                     classProbs = TRUE,
                     summaryFunction = twoClassSummary,
                     sampling = smotest)
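SMOTE's synthetic points come from interpolation: for a minority row, pick one of its k nearest minority-class neighbors and place a new point at a random position on the segment between them. The base-R sketch below shows only that oversampling step for numeric predictors; the helper smote_sketch is hypothetical, treats perc_over as a multiple of 100 (an assumption), and, unlike DMwR's SMOTE(), does no undersampling of the majority class.

```r
# Simplified SMOTE-style oversampling sketch (numeric predictors only).
# Hypothetical helper: it only creates synthetic minority rows.
smote_sketch <- function(X_min, perc_over = 100, k = 5) {
  X_min <- as.matrix(X_min)
  n <- nrow(X_min)
  n_per_row <- perc_over %/% 100     # synthetic rows per minority row
  D <- as.matrix(dist(X_min))        # pairwise Euclidean distances
  diag(D) <- Inf                     # a row is not its own neighbor
  synth <- vector("list", n)
  for (i in seq_len(n)) {
    nn <- order(D[i, ])[seq_len(min(k, n - 1))]   # k nearest minority neighbors
    picks <- nn[sample.int(length(nn), n_per_row, replace = TRUE)]
    origin <- X_min[rep(i, n_per_row), , drop = FALSE]
    gap <- runif(n_per_row)          # random position along each segment
    # New point = x_i + gap * (neighbor - x_i), which lies between the two
    synth[[i]] <- origin + gap * (X_min[picks, , drop = FALSE] - origin)
  }
  do.call(rbind, synth)
}

set.seed(3)
X_min <- matrix(rnorm(20 * 2), ncol = 2, dimnames = list(NULL, c("a", "b")))
new_pts <- smote_sketch(X_min, perc_over = 100, k = 5)
dim(new_pts)   # 20 synthetic rows in 2 columns
```

Since every synthetic point is a convex combination of two real minority points, SMOTE fills in the region the minority class already occupies rather than duplicating rows, which reduces the overfitting risk of plain upsampling.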