Diffusion Driven Balancing

The University of Melbourne
ICCV 2025

Abstract

Deep neural networks trained with Empirical Risk Minimization (ERM) perform well when training and test data come from the same domain, but they often fail to generalize to out-of-distribution samples. In image classification, these models may rely on spurious correlations that often exist between labels and irrelevant features of images, making predictions unreliable when those features are absent. We propose a Diffusion Driven Balancing (DDB) technique that generates training samples with text-to-image diffusion models to address the spurious correlation problem. First, we compute the token that best describes the visual features of the causal components of samples via a textual inversion mechanism. Then, leveraging a language-based segmentation method and a diffusion model, we generate new samples by combining the causal component of one class with elements from other classes. We also meticulously prune the generated samples based on the prediction probabilities and attribution scores of the ERM model to ensure they are correctly composed for our objective. Finally, we retrain the ERM model on our augmented dataset. This process reduces the model's reliance on spurious correlations by learning from carefully crafted samples in which these correlations do not exist. Our experiments show that, across different benchmarks, our technique achieves better worst-group accuracy than existing state-of-the-art methods.
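To make the textual inversion step concrete, below is a minimal sketch of how the class token can be learned; it follows the standard diffusers textual-inversion recipe, and the checkpoint, initializer word, prompt template, and `dog_dataloader` are illustrative assumptions rather than the paper's exact setup.

```python
# Minimal textual-inversion sketch: optimize a new token embedding <C_dog> so
# that prompts containing it reconstruct Dog-class images under a frozen
# Stable Diffusion model. Illustrative only; hyperparameters are assumptions.
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
unet, vae, scheduler = pipe.unet, pipe.vae, pipe.scheduler

# Register <C_dog> and initialize its embedding from the word "dog".
tokenizer.add_tokens(["<C_dog>"])
text_encoder.resize_token_embeddings(len(tokenizer))
emb = text_encoder.get_input_embeddings().weight
new_id = tokenizer.convert_tokens_to_ids("<C_dog>")
init_id = tokenizer.encode("dog", add_special_tokens=False)[0]
with torch.no_grad():
    emb[new_id] = emb[init_id].clone()

# Freeze the diffusion model; only the token-embedding table is optimized.
for p in [*unet.parameters(), *vae.parameters()]:
    p.requires_grad_(False)
optimizer = torch.optim.AdamW([emb], lr=5e-4)

for images in dog_dataloader:  # assumed: Dog-class image batches in [-1, 1]
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor
    noise = torch.randn_like(latents)
    t = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],))
    noisy = scheduler.add_noise(latents, noise, t)
    ids = tokenizer(["a photo of a <C_dog>"] * latents.shape[0],
                    padding="max_length", truncation=True,
                    max_length=tokenizer.model_max_length,
                    return_tensors="pt").input_ids
    pred = unet(noisy, t, encoder_hidden_states=text_encoder(ids)[0]).sample
    loss = F.mse_loss(pred, noise)  # standard denoising objective
    loss.backward()
    # Keep every embedding row except <C_dog> fixed.
    emb.grad.data[torch.arange(emb.shape[0]) != new_id] = 0
    optimizer.step()
    optimizer.zero_grad()
```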

Overview


Figure 1. Examples of images generated by the proposed approach. We automatically modify majority-group samples (top) to generate high-quality minority-group samples (bottom), with the alterations occurring precisely in the causal features so that the spurious correlations are resolved. (a) For the Waterbirds dataset, whose majority samples are landbirds on land backgrounds and waterbirds on water backgrounds, our method generates new images with landbirds on water backgrounds and waterbirds on land backgrounds while retaining the majority-sample backgrounds. (b) For CelebA, where the majority consists of non-blond males and blond females, our approach generates images of blond males and non-blond females. (c) For the MetaShift dataset, where dogs and cats are correlated with specific objects in the background, our technique generates samples that break this correlation by swapping the dogs and cats.

Method


Figure 2. An overview of DDB, with the Dog/Cat classes used for illustration. (a) First, the token C_dog is optimized by reconstructing samples from the Dog class using prompts that include the trainable C_dog. Then, for each batch in D_cat, masks M are generated by LangSAM with the prompt Class Name = 'animal'. These masks, along with the input images and a textual prompt, are fed into the inpainting model to generate new images. The token C_dog is incorporated to generate the causal features of the Dog class. (b) The generated samples are then passed through the ERM model and the integrated gradients (IG) module to compute the scores used to filter out undesired generated images. The algorithm for this stage is detailed in Alg. 1. (c) The final set of generated samples is used to retrain the ERM model.
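As a concrete reference for stage (a), the sketch below combines LangSAM masking with a Stable Diffusion inpainting model, assuming the publicly available `lang-sam` and `diffusers` packages; the checkpoint names, the token file `c_dog.bin`, and the prompt wording are assumptions, and exact call signatures may differ across library versions.

```python
# Sketch of stage (a): mask the causal component with LangSAM, then inpaint it
# with the learned token <C_dog>, so the (spuriously correlated) Cat-class
# background is kept while Dog-class causal features are generated.
import numpy as np
import torch
from PIL import Image
from lang_sam import LangSAM                        # language-prompted SAM
from diffusers import StableDiffusionInpaintPipeline

segmenter = LangSAM()
inpainter = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")
# Assumption: "c_dog.bin" stores the embedding learned in the textual
# inversion step sketched earlier.
inpainter.load_textual_inversion("c_dog.bin", token="<C_dog>")

def cat_to_dog(image: Image.Image) -> Image.Image:
    # Segment the animal with the class-agnostic prompt 'animal'.
    masks, _, _, _ = segmenter.predict(image, "animal")
    mask = Image.fromarray((masks[0].cpu().numpy() * 255).astype(np.uint8))
    # Inpaint the masked region with Dog-class causal features.
    return inpainter(prompt="a photo of a <C_dog>",
                     image=image, mask_image=mask).images[0]
```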

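Stage (b) can be sketched in a similarly hedged way with captum's integrated gradients; the confidence threshold, the mask-overlap score, and the helper `keep_sample` are illustrative stand-ins for the precise criteria of Alg. 1.

```python
# Sketch of stage (b): keep a generated image only if the ERM model is
# confident about the target class AND its integrated-gradients evidence is
# concentrated inside the inpainted (causal) region rather than the background.
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients

def keep_sample(erm_model, x, target_label, causal_mask,
                p_min=0.5, attr_min=0.6):
    """x: (1, C, H, W) generated image; causal_mask: (H, W) bool tensor."""
    erm_model.eval()
    with torch.no_grad():
        probs = F.softmax(erm_model(x), dim=1)
    # Reject images the ERM model does not recognize as the target class.
    if probs[0, target_label] < p_min:
        return False
    # Attribution map: sum absolute attributions over channels.
    ig = IntegratedGradients(erm_model)
    attr = ig.attribute(x, target=target_label).abs().sum(dim=1)[0]  # (H, W)
    # Fraction of the evidence that lies inside the causal region.
    ratio = attr[causal_mask].sum() / (attr.sum() + 1e-8)
    # Reject images whose evidence leaks into the (spurious) background.
    return bool(ratio >= attr_min)
```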
Qualitative samples


BibTeX