最近参加kaggle比赛,才发现对于图像分割损失函数有各种形式。同时,关于如何实现这些损失函数,尤其是加权的损失函数,之前并没有研究过。但是在实际应用中,应该还是挺常见的,毕竟样本不均衡问题时有发生。好了,废话不多说了, 进入正题。下面的内容均以二分类问题为例。
cross entropy
图像分割任务的本质为对于像素点的分类,通常称为密集预测(dense prediction)。分类问题自然可以使用cross entropy(交叉熵损失函数)。
设真实情况下$\mathbf{P}(Y = 0) = p$,$\mathbf{P}(Y = 1) = 1 - p$。通过 logistic/sigmoid 函数得到的预测$\mathbf{P}(\hat{Y} = 0) = \frac{1}{1 + e^{-x}} = \hat{p}$,$\mathbf{P}(\hat{Y} = 1) = 1 - \frac{1}{1 + e^{-x}} = 1 - \hat{p}$,则交叉熵损失函数CE为
在keras中,对应函数为binary_crossentropy(y_true, y_pred)
,在TensorFlow中,对应函数为softmax_cross_entropy_with_logits_v2
,在Pytorch中,对应的损失函数为torch.nn.BCEWithLogitsLoss()
。
Weighted cross entropy
Weighted cross entropy是cross entropy的一种变体,具体体现在所有的正例损失前均有一个系数。主要用于类别不平衡的问题,例如当图像中只有10%的正样本,而有90%的负样本的时候,常规的cross entropy不能正常的work。
如果想减少false negatives(漏报),即增加recall,则设置$\beta>1$;若想减少false positives(误报),则增加precision,则设置$\beta<1$。这个可以这么理解:
- 当$\beta>1$的时候,$p_{i,j} \log\left(\hat{p}_{i,j}\right)$的系数较大,所谓false negatives(漏报),就是指预测错了,预测为了负样本,实际类别为正样本,此时$p_{i,j}=1$,为了使得损失尽可能的小,会导致$\hat{p}_{i,j}$尽可能大,模型更加倾向于尽可能的减少漏报;
- 当$\beta<1$的时候,$(1-p_{i,j}) \log\left(1 - \hat{p}_{i,j}\right)$的系数较大,所谓false positives(误报),就是指预测错了,预测为了正样本,实际类别为负样本,此时$1-p_{i,j}=1$,为了使得损失尽可能的小,会导致$\hat{p}_{i,j}$尽可能小,模型更加倾向于尽可能的减少误报。
例如,当数据集中含有100个正例,300个负例的时候,Pytorch中的torch.nn.BCEWithLogitsLoss()
函数中的pos_weight
参数需要为$\frac{300}{100}=3$。此时的loss相当于有关$100\times3=300$个样本。
Balanced cross entropy
该损失函数和WCE基本一致,不同点在于该损失函数对负样本也进行了加权。
上面的公式均是针对每个样本均有一个权重。对于图像分割任务,相当于对所有样本的所有像素点均有一个权重。且该公式中,不管是正样本还是负样本的损失,均要除以$batch_size \times image_size$来得到均值。
除此之外,在遇到类别不均衡的时候,当计算正负样本损失的时候,分别所以各自的总数,然后加权。这样做的好处是,防止正样本数目过少导致求和后除以$batch_size \times image_size$值很小。当正样本的权值为0.25,负样本的权值为0.75的时候,具体公式可以描述如下:
其中,$p_{i,j}$为一个batch内所有样本所有像素点是否为正样本,为正样本为1,不为正样本为0;$n_{i,j}$为一个batch内所有样本所有像素点是否为负样本,为正样本为0,不为正样本为1;$loss_{i,j}$为一个batch内所有样本所有像素点的损失值。
正样本和负样本权重分别为0.25和0.75是针对SIIM-ACR Pneumothorax Segmentation比赛的。在该比赛中,有掩模的样本总数和无掩模的样本总数大概为1:3,也就相当于0.25:0.75。若不进行加权,则正样本和负样本的损失值基本相同,这不符合实际的数据分布,会导致最终可能出现没有掩模的也预测出了掩模的情况。PS:实际使用的时候,效果特别差。具体原因未知。
具体代码如下:
1 | # reference: https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/101429 |
DiceLoss
DICE与IOU很相似,具体两者的区别如下:
从中可以看出,$\text{DC} \geq \text{IoU}$(两者相减得到的式子中分子为$|X| + |Y| - 2|X \cap Y|$,显然分子大于0)。
DICE也可以作为loss使用,具体代码如下:
1 | # reference: https://github.com/asanakoy/kaggle_carvana_segmentation |
这里解释下dice_loss
函数内部的加权。preds.size(0)
得到batch_size
大小,w
的大小为batch_size*image_size
,则可以得到下式:
其中,$\hat p_{i,j}$为一个batch第$i$个样本第$j$个像素的预测值,而$p_{i,j}$为一个batch第$i$个样本第$j$个像素的真实值。
所以,这里的加权就相当于对一个batch内的所有样本的loss进行加权,和Pytorch中的BCEWithLogitsLoss中的weight
参数含义一致。对于图像分割任务,相当于对一个batch内的所有样本的所有像素点进行加权。
一方面,这样的加权方式不经常使用,因为我们经常会遇到正样本和负样本比例失衡问题,对于所有样本的所有像素点均要设置一个权值,在实现上不如直接设置正样本和负样本的权值方便,类似于Pytorch中的BCEWithLogitsLoss中的pos_weight
参数含义。PS:暂时没有实现,所以还是老老实实没一个样本设置一个权值吧。
另一方面,值得注意的是,在图像分割任务中,会碰到样本mask中没有正样本的情况。例如在SIIM-ACR Pneumothorax Segmentation比赛中,就会出现大部分图像中并没有目标,mask也就全部为负样本。对于mask全部为负样本的数据,若预测出mask也没正样本,上面的dice_loss
函数分子接近0,导致最终的loss很大,然而真实情况应该为此时loss应该很小。因此可以考虑下面的dice函数:
1 | # dice for threshold selection |
那么如何实现对所有样本的所有像素点分配权重呢?这需要引入下面的SoftDICELoss
。
SoftDICELoss
这个loss是一个kaggle的大神提出来的。该损失函数克服了上面DiceLoss
损失函数没有考虑
DICE还有另外一个形式:
其中,$\mathbf{p} \in \{0,1\}^n$,$\mathbf{\hat p} \in [0,1]^n$。
1 | # reference https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/101429#latest-588288 |
解释下上面代码,因为考虑到数据集中可能存在某些样本的掩模均为负样本,没有正样本的情况,所以需要将类标从[0,1]变为[-1,1]。此时,若真实掩模没有mask,预测出来也全部没有mask,不会因为全部值为0,导致dice的分子为0,loss为1。相反此时全部值为-1,dice的值为1,loss为0,更符合我们的实际需求。
另外,所谓的使用加权的损失函数解决样本不均衡问题,是指对于每一个正样本和负样本均有对应的加权系数。上面代码可以总结为公式:
其中,$t_{i,j} \in {-1,1}$,而$p_{i,j} \in [-1,1]$。若正样本的系数$w_{i,j}$为0.8,而负样本的系数$w_{i,j}$为0.2,则正样本对dice的影响更大,负样本对dice的影响更小。从而让网络更加关注正样本。
FocalLoss
该损失函数降低easy examples
的权重,使得模型更加关注hard examples
。
其中$\gamma$为超参数,当$\gamma = 0$的时候,我们得到标准BCE。我们这里关注的为当$\gamma \not= 0$的时候,对于这个公式的理解如下。
当$\gamma>1$时:
- 当样本为正样本时,此时上式右边只有第一项不为0。若$\hat{p}$较大的时候,意味着网络对该数据的分类效果较好,$(1 - \hat{p})^{\gamma}$值较小,意味着该数据的loss更小,网络接下来对于该数据的关注会更小;反之,当$\hat{p}$较小的时候,意味着网络对该数据的分类效果较差,$(1 - \hat{p})^{\gamma}$值较大,意味着该数据的loss更大,网络接下来对于该数据的关注会更大。
- 当样本为负样本时,此时上式右边只有第二项不为0。若$\hat{p}$较大的时候,意味着网络对该数据的分类效果较差,$\hat{p}^{\gamma}$值较大,意味着该数据的loss更大,网络接下来对于该数据的关注会更大;反之,当$\hat{p}$较小的时候,意味着网络对该数据的分类效果较好,$\hat{p}^{\gamma}$值较小,意味着该数据的loss更小,网络接下来对于该数据的关注会更小。
当$\gamma<1$的时候,此时损失函数会越加关注容易分的样本,而越加不关注难分的样本,与该损失函数的设计初衷背道而驰。所以实际使用的时候,$\gamma\geq1$。
因为我们这里使用的是logistic/sigmoid 函数预测的,所以继续进行推导,可以得到
参考
Losses for Image Segmentation
BCEWithLogitsLoss
losses.py
some workable loss function
How to apply weighted loss to a binary segmentation problem?