What is an imbalanced dataset?
An imbalanced dataset refers to a situation where there are a disproportionate number of cases of one class compared to another.
For example here, in a 2k images dataset, there are approximately 20 times more cars than buses inside the training set. ( see below ) That’s what we call an Imbalanced Dataset, when all classes are not equally represented.
Researchers often use imbalanced datasets when designing experiments, since they allow them to more easily control for experimental bias. That said, there are situations where an imbalanced dataset can pose a serious problem for machine learning models. In such situations, the models will be prone to overfitting and low accuracy. In other words, imbalanced datasets are problematic because they can lead to biased and inaccurate results.
Why are imbalanced datasets a problem?
Computer vision models rely on a large amount of data for training, but for certain types of images there aren’t enough examples to provide a robust representation of the data.
For example, if we are training a model to identify images of traffic signs, there are many images of stop signs but relatively few images of other types of traffic signs. Let’s say the split is 80% stop signs, and 20% traffic signs.
Deep Learning models are just very sophisticated idiots.
So during the training process, the model will infer that if it predicts stop signs all the time, it will be right 80% of the time. Quite cool, right? In fact it’s not!
This means that the training process may lead to an inefficient representation of the data, resulting in models that are highly specific and less robust. In our case, the models will be highly accurate for the stop sign, but less accurate for all the other traffic signs.
3 strategies to overcome imbalanced datasets
We can tackle the problem in 3 ways.
The first strategy is data augmentation. This involves creating additional examples from the original dataset, but with slight alterations such as flipping images vertically or creating mirrored images. This will increase the number of images in the underrepresentated classes and yield a more balanced dataset.
However, data augmentation is not enough sometimes. This is when class weighting comes into play. Adding weights to every class can be a useful strategy. I will dive into it shortly using some examples :)
The third option entails performing a hierarchical classification. The main idea is to build multiple binary datasets to ensure our datasets are always balanced. The downside is that you will accumulate errors along the way.
Data augmentation is a technique that can be used to improve the robustness and accuracy of a machine learning model by creating additional examples from the original dataset. Data augmentation is helpful in situations where the original dataset is imbalanced. There are a number of techniques that can be used for data augmentation, including flipping examples, creating mirrored examples, rotating examples, coloring images in different ways, and adding noise. The goal of data augmentation is to increase the number of examples for each class, which can be helpful in situations where the original dataset has a disproportionate number of data points for one class compared to another. This can create issues during the training process and cause the model to be less accurate and more prone to overfitting.
If you want to experiment with data augmentation, here is a cool demo for the Albumentations devs: https://demo.albumentations.ai/
Class weighting involves assigning higher weights to examples of a particular class. It can help balance out the data and reduce the bias that is associated with an imbalanced dataset. Class weighting can be helpful in situations where there are significantly fewer examples from one class compared to another. This can be problematic in machine learning models that rely heavily on examples for training. Class weighting are helpful in situations where a model is being used to identify a particular class, such as the type of flower in an image. There may be significantly fewer images of a specific type of flower - such as orchids - compared to other types of flowers. When training a model for this type of classification, it may be helpful to increase the weight of examples associated with the orchid class. This can help balance out the data and reduce the bias that is associated with an imbalanced dataset.
Hierarchical classification is a machine learning technique that uses a hierarchy of categories to classify data. It is particularly useful for computer vision applications, as well as for classifying imbalanced datasets. This type of classification involves training multiple models, each on a different level of the hierarchy. For example, if you have a dataset of animals, the highest level of the hierarchy could be "animal," followed by "mammal," and then "dog." Each model is trained on the data at that level, with the results then being combined to form the overall prediction. Hierarchical classification can be used to accurately classify data even when there is a large amount of variation or when the dataset is imbalanced. It is a powerful tool that can be used to improve the accuracy of machine learning models.
Computer vision can be challenging to implement, especially when dealing with imbalanced datasets. An imbalanced dataset is one in which the number of data points for one class is significantly lower than the number of data points for the other classes. This can lead to problems with accuracy, overfitting, and bias. Fortunately, there are some strategies that can help overcome these issues. It is crucial to understand how to properly use techniques such as data augmentation, class weighting, and oversampling. A computer vision model trained on balanced datasets will allow you to improve accuracy and reduce bias.