注意这篇文章来源于Pytorch中的仿射变换(affine_grid) ,感觉写的很好,所以转运过来了。
在看 pytorch 的 Spatial Transformer Network 教程  时,在 stn 层中的 affine_grid 与 grid_sample 函数上卡住了,不知道这两个函数该如何使用,经过一些实验终于搞清楚了其作用。
参考:详细解读 Spatial Transformer Networks (STN) ,该文章与李宏毅的课程一样,推荐听李老师的 STN 这一课,讲的比较清楚;
假设我们有这么一张图片:
下面我们将通过分别通过手动编码和 pytorch 方式对该图片进行平移、旋转、转置、缩放等操作,这些操作的数学原理在本文中不会详细讲解。
实现载入图片 (注意,下面的代码都是在 jupyter 中进行):
1 2 3 4 5 6 7 8 9 10 11 from  torchvision import  transformsfrom  PIL import  Imageimport  matplotlib.pyplot as  plt%matplotlib inline img_path = "图片文件路径"  img_torch = transforms.ToTensor()(Image.open(img_path)) plt.imshow(img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
平移操作 普通方式 例如我们需要向右平移 50px,向下平移 100px。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 import  numpy as  npimport  torchtheta = np.array([     [1 ,0 ,50 ],     [0 ,1 ,100 ] ]) t1 = theta[:,[0 ,1 ]] t2 = theta[:,[2 ]] _, h, w = img_torch.size() new_img_torch = torch.zeros_like(img_torch, dtype=torch.float) for  x in  range(w):    for  y in  range(h):         pos = np.array([[x], [y]])         npos = t1@pos+t2         nx, ny = npos[0 ][0 ], npos[1 ][0 ]         if  0 <=nx<w and  0 <=ny<h:             new_img_torch[:,ny,nx] = img_torch[:,y,x] plt.imshow(new_img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
图片变为:
pytorch 方式 向右移动 0.2,向下移动 0.4:
1 2 3 4 5 6 7 8 9 10 11 from  torch.nn import  functional as  Ftheta = torch.tensor([     [1 ,0 ,-0.2 ],     [0 ,1 ,-0.4 ] ], dtype=torch.float) grid = F.affine_grid(theta.unsqueeze(0 ), img_torch.unsqueeze(0 ).size()) output = F.grid_sample(img_torch.unsqueeze(0 ), grid) new_img_torch = output[0 ] plt.imshow(new_img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
得到的图片为:
总结:
要使用 pytorch 的平移操作,只需要两步:创建 grid:grid = torch.nn.functional.affine_grid(theta, size),其实我们可以通过调节 size 设置所得到的图像的大小 (相当于 resize); grid_sample 进行重采样:outputs = torch.nn.functional.grid_sample(inputs, grid, mode='bilinear')  theta 的第三列为平移比例,向右为负,向下为负; 我们通过设置 size 可以将图像 resize:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 from  torch.nn import  functional as  Ftheta = torch.tensor([     [1 ,0 ,-0.2 ],     [0 ,1 ,-0.4 ] ], dtype=torch.float) N, C, W, H = img_torch.unsqueeze(0 ).size() size = torch.Size((N, C, W//2 , H//3 )) grid = F.affine_grid(theta.unsqueeze(0 ), size) output = F.grid_sample(img_torch.unsqueeze(0 ), grid) new_img_torch = output[0 ] plt.imshow(new_img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
缩放操作 普通方式 放大 1 倍:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 import  numpy as  npimport  torchtheta = np.array([     [2 ,0 ,0 ],     [0 ,2 ,0 ] ]) t1 = theta[:,[0 ,1 ]] t2 = theta[:,[2 ]] _, h, w = img_torch.size() new_img_torch = torch.zeros_like(img_torch, dtype=torch.float) for  x in  range(w):    for  y in  range(h):         pos = np.array([[x], [y]])         npos = t1@pos+t2         nx, ny = npos[0 ][0 ], npos[1 ][0 ]         if  0 <=nx<w and  0 <=ny<h:             new_img_torch[:,ny,nx] = img_torch[:,y,x] plt.imshow(new_img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
结果为:
由于没有使用插值算法,所以中间有很多部分是黑色的。
pytorch 方式 1 2 3 4 5 6 7 8 9 10 11 from  torch.nn import  functional as  Ftheta = torch.tensor([     [0.5 , 0   , 0 ],     [0   , 0.5 , 0 ] ], dtype=torch.float) grid = F.affine_grid(theta.unsqueeze(0 ), img_torch.unsqueeze(0 ).size()) output = F.grid_sample(img_torch.unsqueeze(0 ), grid) new_img_torch = output[0 ] plt.imshow(new_img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
结果为:
结论:可以看到,affine_grid的放大操作是以图片中心为原点的。
旋转操作 普通操作 将图片旋转 30 度:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 import  numpy as  npimport  torchimport  mathangle = 30 *math.pi/180  theta = np.array([     [math.cos(angle),math.sin(-angle),0 ],     [math.sin(angle),math.cos(angle) ,0 ] ]) t1 = theta[:,[0 ,1 ]] t2 = theta[:,[2 ]] _, h, w = img_torch.size() new_img_torch = torch.zeros_like(img_torch, dtype=torch.float) for  x in  range(w):    for  y in  range(h):         pos = np.array([[x], [y]])         npos = t1@pos+t2         nx, ny = int(npos[0 ][0 ]), int(npos[1 ][0 ])         if  0 <=nx<w and  0 <=ny<h:             new_img_torch[:,ny,nx] = img_torch[:,y,x] plt.imshow(new_img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
结果为:
pytorch 操作 1 2 3 4 5 6 7 8 9 10 11 12 13 from  torch.nn import  functional as  Fimport  mathangle = -30 *math.pi/180  theta = torch.tensor([     [math.cos(angle),math.sin(-angle),0 ],     [math.sin(angle),math.cos(angle) ,0 ] ], dtype=torch.float) grid = F.affine_grid(theta.unsqueeze(0 ), img_torch.unsqueeze(0 ).size()) output = F.grid_sample(img_torch.unsqueeze(0 ), grid) new_img_torch = output[0 ] plt.imshow(new_img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
结果为:
pytorch 以图片中心为原点进行旋转,并且在旋转过程中会发生图片缩放,如果选择角度变为 90°,图片为:
转置操作 普通操作 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 import  numpy as  npimport  torchtheta = np.array([     [0 ,1 ,0 ],     [1 ,0 ,0 ] ]) t1 = theta[:,[0 ,1 ]] t2 = theta[:,[2 ]] _, h, w = img_torch.size() new_img_torch = torch.zeros_like(img_torch, dtype=torch.float) for  x in  range(w):    for  y in  range(h):         pos = np.array([[x], [y]])         npos = t1@pos+t2         nx, ny = npos[0 ][0 ], npos[1 ][0 ]         if  0 <=nx<w and  0 <=ny<h:             new_img_torch[:,ny,nx] = img_torch[:,y,x] plt.imshow(new_img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
结果为:
pytorch 操作 我们可以通过 size 大小,保存图片不被压缩:
1 2 3 4 5 6 7 8 9 10 11 12 from  torch.nn import  functional as  Ftheta = torch.tensor([     [0 , 1 , 0 ],     [1 , 0 , 0 ] ], dtype=torch.float) N, C, H, W = img_torch.unsqueeze(0 ).size() grid = F.affine_grid(theta.unsqueeze(0 ), torch.Size((N, C, W, H))) output = F.grid_sample(img_torch.unsqueeze(0 ), grid) new_img_torch = output[0 ] plt.imshow(new_img_torch.numpy().transpose(1 ,2 ,0 )) plt.show() 
结果为:
上面就是 affine_grid + grid_sample 的大致用法,如果你在看 STN 时有相同的用法,希望可以帮助到你。
注意事项 需要注意的是,即使是平移操作,再经过平移之后,也会有损失,例如将全红色图案经过平移变换后,得到的结果如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 img_path = "4.png"  img_ori = Image.open(img_path) img_ori_np = np.array(img_ori) img_torch = transforms.ToTensor()(img_ori) theta = torch.tensor([     [1.0 , 0   , 0.0 ],     [0   , 1.0 , -0.2 ] ], dtype=torch.float) grid = F.affine_grid(theta.unsqueeze(0 ), img_torch.unsqueeze(0 ).size()) output = F.grid_sample(img_torch.unsqueeze(0 ), grid) new_img_torch = output[0 ] new_img = new_img_torch.numpy().transpose(1 ,2 ,0 )*255  unique,count = np.unique(new_img, return_counts=True ) data_count = dict(zip(unique,count)) print(data_count) plt.imshow(new_img) plt.show() 
输出结果为:
1 {0.0: 4352640, 0.0037693689: 1, 0.0037693977: 1919, 254.99805: 972, 255.0: 1865268}