Pytorch踩坑记录

ToTensor

Pytorch中有ToTensor()函数,经常用在加载数据的时候。注意该函数的API文档是这样说的:

1
2
3
4
5
6
7
8
9
 """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8

In the other cases, tensors are returned without scaling.
"""

也就是说,当该函数的输入为numpy.ndarray (H x W x C) np.uint8或者PIL Image(L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)的时候,会对数据进行归一化处理。对于图像分割任务,一二分类为例,当类标为{0,1}的时候,满足以上条件,照样会对数据进行预处理。此时,类标就会变成很小的数,导致网络预测出的结果全为0。

值得注意的是,torchvision0.2.0并没有上面输入的前提,全部会归一化,也就是说下面的代码是错误的,而在torchvision0.2.0没有问题。

1
2
3
4
5
6
7
8
9
10
11
12
mask = Image.open(mask_path)
mask = mask.resize((224, 224))
# 将255转换为1, 0转换为0
mask = np.around(np.array(mask.convert('L'))/256.)
# mask = mask[:, :, np.newaxis] # Wrong, will convert range
mask = np.reshape(mask, (np.shape(mask)[0],np.shape(mask)[1],1)).astype("float32")
to_tensor = transforms.ToTensor()

transform_compose = transforms.Compose([to_tensor])
mask = transform_compose(mask)
mask = torch.squeeze(mask)
return mask.float()

对于这种情况,可以使用torch.from_numpy用法。避免不同torchvision版本的不同造成错误的影响。

1
2
3
4
5
6
mask = Image.open(mask_path)
mask = mask.resize((224, 224))
# 将255转换为1, 0转换为0
mask = np.around(np.array(mask.convert('L'))/256.)
mask = torch.from_numpy(mask)
return mask.float()

同样的,对于输入数据而言,使用ToTensor()的时候,也会归一化。例如下面代码:

1
2
3
4
5
6
7
8
image = self.image_transform(img)

resize = transforms.Resize(224)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
transform_compose = transforms.Compose([resize, to_tensor, normalize])

image = transform_compose(image)

这里的数据首先经过resize函数进行resize,接着使用ToTensor()函数归一化。最终使用transforms.Normalize函数标准化。可以看到这里使用transforms.Normalize函数的失活,均值和方差均小于1,这也是因为ToTensor()函数会对数据归一化的体现。而我之前使用Tensorflow的时候,因为没有经过归一化,所以标准化的时候,数据各个通道的均值为[103.939, 116.779, 123.68]

------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!

欢迎关注我的其它发布渠道