PyTorch中的仿射变换(affine_grid)

注意这篇文章来源于Pytorch中的仿射变换(affine_grid),感觉写的很好,所以转运过来了。

在看 pytorch 的 Spatial Transformer Network 教程 时,在 stn 层中的 affine_gridgrid_sample 函数上卡住了,不知道这两个函数该如何使用,经过一些实验终于搞清楚了其作用。

参考:详细解读 Spatial Transformer Networks (STN),该文章与李宏毅的课程一样,推荐听李老师的 STN 这一课,讲的比较清楚;

假设我们有这么一张图片:

下面我们将通过分别通过手动编码和 pytorch 方式对该图片进行平移、旋转、转置、缩放等操作,这些操作的数学原理在本文中不会详细讲解。

实现载入图片 (注意,下面的代码都是在 jupyter 中进行):

1
2
3
4
5
6
7
8
9
10
11
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

%matplotlib inline

img_path = "图片文件路径"
img_torch = transforms.ToTensor()(Image.open(img_path))

plt.imshow(img_torch.numpy().transpose(1,2,0))
plt.show()

平移操作

普通方式

例如我们需要向右平移 50px,向下平移 100px。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
import torch

theta = np.array([
[1,0,50],
[0,1,100]
])
# 变换1:可以实现缩放/旋转,这里为 [[1,0],[0,1]] 保存图片不变
t1 = theta[:,[0,1]]
# 变换2:可以实现平移
t2 = theta[:,[2]]

_, h, w = img_torch.size()
new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
for x in range(w):
for y in range(h):
pos = np.array([[x], [y]])
npos = t1@pos+t2
nx, ny = npos[0][0], npos[1][0]
if 0<=nx<w and 0<=ny<h:
new_img_torch[:,ny,nx] = img_torch[:,y,x]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

图片变为:

pytorch 方式

向右移动 0.2,向下移动 0.4:

1
2
3
4
5
6
7
8
9
10
11
from torch.nn import functional as F

theta = torch.tensor([
[1,0,-0.2],
[0,1,-0.4]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

得到的图片为:

总结:

  • 要使用 pytorch 的平移操作,只需要两步:
    • 创建 grid:grid = torch.nn.functional.affine_grid(theta, size),其实我们可以通过调节 size 设置所得到的图像的大小 (相当于 resize);
    • grid_sample 进行重采样:outputs = torch.nn.functional.grid_sample(inputs, grid, mode='bilinear')
  • theta 的第三列为平移比例,向右为负,向下为负;

我们通过设置 size 可以将图像 resize:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torch.nn import functional as F

theta = torch.tensor([
[1,0,-0.2],
[0,1,-0.4]
], dtype=torch.float)
# 修改size
N, C, W, H = img_torch.unsqueeze(0).size()
size = torch.Size((N, C, W//2, H//3))
grid = F.affine_grid(theta.unsqueeze(0), size)
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

缩放操作

普通方式

放大 1 倍:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
import torch

theta = np.array([
[2,0,0],
[0,2,0]
])
t1 = theta[:,[0,1]]
t2 = theta[:,[2]]

_, h, w = img_torch.size()
new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
for x in range(w):
for y in range(h):
pos = np.array([[x], [y]])
npos = t1@pos+t2
nx, ny = npos[0][0], npos[1][0]
if 0<=nx<w and 0<=ny<h:
new_img_torch[:,ny,nx] = img_torch[:,y,x]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

结果为:

由于没有使用插值算法,所以中间有很多部分是黑色的。

pytorch 方式

1
2
3
4
5
6
7
8
9
10
11
from torch.nn import functional as F

theta = torch.tensor([
[0.5, 0 , 0],
[0 , 0.5, 0]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

结果为:

结论:可以看到,affine_grid的放大操作是以图片中心为原点的。

旋转操作

普通操作

将图片旋转 30 度:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
import torch
import math

angle = 30*math.pi/180
theta = np.array([
[math.cos(angle),math.sin(-angle),0],
[math.sin(angle),math.cos(angle) ,0]
])
t1 = theta[:,[0,1]]
t2 = theta[:,[2]]

_, h, w = img_torch.size()
new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
for x in range(w):
for y in range(h):
pos = np.array([[x], [y]])
npos = t1@pos+t2
nx, ny = int(npos[0][0]), int(npos[1][0])
if 0<=nx<w and 0<=ny<h:
new_img_torch[:,ny,nx] = img_torch[:,y,x]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

结果为:

pytorch 操作

1
2
3
4
5
6
7
8
9
10
11
12
13
from torch.nn import functional as F
import math

angle = -30*math.pi/180
theta = torch.tensor([
[math.cos(angle),math.sin(-angle),0],
[math.sin(angle),math.cos(angle) ,0]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

结果为:

pytorch 以图片中心为原点进行旋转,并且在旋转过程中会发生图片缩放,如果选择角度变为 90°,图片为:

转置操作

普通操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
import torch

theta = np.array([
[0,1,0],
[1,0,0]
])
t1 = theta[:,[0,1]]
t2 = theta[:,[2]]

_, h, w = img_torch.size()
new_img_torch = torch.zeros_like(img_torch, dtype=torch.float)
for x in range(w):
for y in range(h):
pos = np.array([[x], [y]])
npos = t1@pos+t2
nx, ny = npos[0][0], npos[1][0]
if 0<=nx<w and 0<=ny<h:
new_img_torch[:,ny,nx] = img_torch[:,y,x]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

结果为:

pytorch 操作

我们可以通过 size 大小,保存图片不被压缩:

1
2
3
4
5
6
7
8
9
10
11
12
from torch.nn import functional as F

theta = torch.tensor([
[0, 1, 0],
[1, 0, 0]
], dtype=torch.float)
N, C, H, W = img_torch.unsqueeze(0).size()
grid = F.affine_grid(theta.unsqueeze(0), torch.Size((N, C, W, H)))
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1,2,0))
plt.show()

结果为:

上面就是 affine_grid + grid_sample 的大致用法,如果你在看 STN 时有相同的用法,希望可以帮助到你。

注意事项

需要注意的是,即使是平移操作,再经过平移之后,也会有损失,例如将全红色图案经过平移变换后,得到的结果如下。

image-20211221171505853

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
img_path = "4.png"
img_ori = Image.open(img_path)
img_ori_np = np.array(img_ori)
img_torch = transforms.ToTensor()(img_ori)

theta = torch.tensor([
[1.0, 0 , 0.0],
[0 , 1.0, -0.2]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())
output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
new_img = new_img_torch.numpy().transpose(1,2,0)*255
unique,count = np.unique(new_img, return_counts=True)
data_count = dict(zip(unique,count))
print(data_count)

plt.imshow(new_img)
plt.show()

输出结果为:

1
{0.0: 4352640, 0.0037693689: 1, 0.0037693977: 1919, 254.99805: 972, 255.0: 1865268}
------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!

欢迎关注我的其它发布渠道