一、背景
BEV方案中,将图像视角转换到BEV视角的方法对模型性能影响较大,FastBEV的速度较快,但投影效果上限不高,LSS投影上限较高,但速度较慢 (耗时相对较高)。是否有折中的方案,在耗时增加相对较少的情况下,提升模型的上限(中高算力平台下,提升模型能力)?
二、视角转换关键算子-----gridsample
这是pytorch官网对gridsample算子使用方法说明,其支持4-D(FastBEV/IMP)和5-D(LSS)采样,将图像特征提取到对应的BEV特征中,完成相机视角转换:https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
5-D gridsample相比4-D gridsample耗时剧增,假如在某智驾芯片上,4-D gridsample耗时是2ms,相同条件下5-D gridsample的耗时可能是200ms(具体耗时受特征图通道数影响),这种耗时急剧上升的方案,很难在智驾中落地应用。
三、LSS投影优化
1.先来对比4-D gridsample和5-D gridsample的输入输出关系:
4-D gridsample
input: (N, C, H_in, W_in);
bev_grid: (N, H_out, W_out, 2), 这里的2表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y);
output: (N, C, H_out, W_out)
5-D gridsample
input: (N, C, H_in, W_in);
for循环提取每个C通道的输入特征进行softmax处理input_i:(N, D, H_in, W_in),按照dim=1堆叠起来,得到深度输入input_2:(N, C, D, H_in, W_in), 这里的D表示深度估计的通道数;
bev_grid: (N, Z_out, H_out, W_out, 3), 这里的3表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y,d), d为深度估计;
output: (N, C, Z_out, H_out, W_out);
由于获取深度信息需要用到5-D gridsample,想要降低耗时,考虑减少特征图通道对耗时的影响,即做5-D gridsample时,将通道C设为1;
2.具体方法-----拆解5-D gridsample
将5-D gridsample拆解为一个4-D gridsample和一个单通道(C=1)的5-D gridsample,4-D gridsample负责提取多通道特征信息,单通道5-D gridsample负责提取深度特征信息,最后将两个特征信息相乘,得到多通道下的深度信息,等效变换过程如下:
step1:
4-D gridsample
input: (N, C, H_in, W_in);
bev_grid: (N, Z_out, H_out, W_out, 2), 这里的2表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y);
for循环提取每个Z_out下的bev_grid_i: (N, Z_out, H_out, W_out, 2),通过4-D gridsample分别得到输出特征图output_i: (N, C, H_out, W_out),按照dim=2堆叠起来,得到最终的BEV特征图output_1(没有深度概率信息):
output_1: (N, C, Z_out, H_out, W_out)
step2:
单通道5-D gridsample
input: (N, C, H_in, W_in);
input经过softmax处理后的特征图input_2: (N, D, H_in, W_in),这里的D表示深度估计的通道数;将input_2在dim=1上扩展一个维度,得到input_3:(N, 1, D, H_in, W_in)
bev_grid: (N, Z_out, H_out, W_out, 3), 这里的3表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y,d), d为深度估计;
output_2: (N, 1, Z_out, H_out, W_out);
step3:
将output_1和output_2相乘得到有深度概率信息的BEV特征图
output = outptu_1 * output_2 = (N, C, Z_out, H_out, W_out) * (N, 1, Z_out, H_out, W_out) = (N, 1, Z_out, H_out, W_out)
四、部分代码
1.IPM的BEV网格坐标索引
class UpdateIndicesIPM:def __init__(self, height, range, voxel_size, feature_size, downsample):self.height = heightself.range = rangeself.voxel_size = voxel_sizeself.feature_size = feature_sizeself.ds_matrix = np.eye(4)self.ds_matrix[:2] /= downsampledef __call__(self, data):num = len(data["cam2egoes"])ego2feats = torch.zeros((num, 4, 4), dtype=torch.float32)for i in range(num):ego2cam = np.linalg.inv(data["cam2egoes"][i])tmp = np.eye(4)tmp[:3, :3] = data["cam_intrinsics"][i]ego2feats[i] = torch.tensor(self.ds_matrix @ tmp @ ego2cam)grid = torch.stack(torch.meshgrid([torch.arange(self.range[0], self.range[3], self.voxel_size[0]),torch.arange(self.range[1], self.range[4], self.voxel_size[1]),torch.tensor(self.height), torch.tensor(1.0)], indexing="ij")) # [4, 188, 64, 4, 1]grid_h, grid_w = grid.shape[1:3]grid = grid.view(1, 4, -1).expand(num, 4, -1) # [7, 4, 192512] points_2d = torch.bmm(ego2feats[:, :3, :], grid)x = (points_2d[:, 0] / points_2d[:, 2]).round().long() y = (points_2d[:, 1] / points_2d[:, 2]).round().long() z = points_2d[:, 2]valid = ~((x >= 0) & (y >= 0) & (x < self.feature_size[1]) & (y < self.feature_size[0]) & (z > 0))x[valid] = 0y[valid] = 0x = (x.float() / self.feature_size[1] * 2.) - 1.0y = (y.float() / self.feature_size[0] * 2.) - 1.0indices = torch.cat([x.unsqueeze(2), y.unsqueeze(2)], dim=2)indices = indices.reshape(-1, grid_h, grid_w, len(self.height), 2) # batch, num_img, bev_w, bev_h, num_height, 2data["indices"] = indicesreturn data
2.FastBEV
class FastBevTransform(nn.Module):def __init__(self, feats_channels, num_height):super().__init__()self._num_height = num_heightself._conv = nn.Conv2d(feats_channels * num_height, feats_channels, kernel_size=1)self._grid_sample = GridSample(mode="nearest",padding_mode="zeros",align_corners=True)self._cat = Concat(dim=1)def forward(self, feats, indices):# feats: (7B, C, H, W), indices: (7B, Hg, Wg, Z, 2)bev_feats = []for i in range(self._num_height):output = self._grid_sample(feats, indices[:,:,:,i])bev_feats.append(output)bev_feats = self._cat(bev_feats) # (7B, Z*C, Hg, Wg)bev_feats = self._conv(bev_feats) # (7B, C, Hg, Wg)return bev_feats
3.LSS的BEV网格坐标索引
class UpdateIndicesLSS:def __init__(self, height, range, voxel_size, feature_size,resolution, max_num_depth, downsample):self.height = heightself.range = rangeself.voxel_size = voxel_sizeself.feature_size = feature_sizeself.resolution = resolutionself.max_num_depth = max_num_depthself.ds = np.eye(3)self.ds[:2] /= downsampledef __call__(self, data):num = len(data["cam2egoes"])ego2cams = torch.zeros((num, 4, 4), dtype=torch.float32)cam2feats = torch.zeros((num, 3, 3), dtype=torch.float32)for i in range(num):ego2cams[i] = torch.tensor(np.linalg.inv(data["cam2egoes"][i]))cam2feats[i] = torch.tensor(self.ds @ data["cam_intrinsics"][i])grid = torch.stack(torch.meshgrid([torch.arange(self.range[0], self.range[3], self.voxel_size[0]),torch.arange(self.range[1], self.range[4], self.voxel_size[1]),torch.tensor(self.height), torch.tensor(1.0)], indexing="ij")) # [4, 188, 64, 4, 1]grid_h, grid_w = grid.shape[1:3]grid4 = grid.view(1, 4, -1).expand(num, 4, -1) # [7, 4, 192512] points_2d = torch.bmm(ego2cams[:, :3, :], grid4)x = (points_2d[:, 0] / points_2d[:, 2]) # [7, 48128]y = (points_2d[:, 1] / points_2d[:, 2]) # [7, 48128]z = points_2d[:, 2] # [7, 48128]r = points_2d.norm(dim=1) # [B*N, Hg*Wg]d = torch.floor(r / self.resolution)distortions = torch.tensor(np.array(data["cam_distortions"]).T)k1,k2,k3,p1,p2,k4,k5,k6 = distortions[:,:,None]fovs = torch.tensor(data['crop_fovs']).unsqueeze(-1) / 2.0in_fov = np.abs(np.arctan2(points_2d[:, 0], z)) < fovsr2 = x**2 + y**2ratio = (1 + k1 * r2 + k2 * r2**2 + k3 * r2**3) / (1 + k4 * r2 + k5 * r2**2 + k6 * r2**3)x_undist = x * ratio + 2 * p1 * x * y + p2 * (r2 + 2 * x**2)y_undist = y * ratio + p1 * (r2 + 2 * y**2) + 2 * p2 * x * yx = cam2feats[:, 0, [0]] * x_undist + cam2feats[:, 0, [2]]y = cam2feats[:, 1, [1]] * y_undist + cam2feats[:, 1, [2]]valid = ~((x >= 0) & (y >= 0) & (x < self.feature_size[1]) & \(y < self.feature_size[0]) & (z > 0) & in_fov & \(d >= 0) & (d < self.max_num_depth)) # [7, 48128]x[valid], y[valid], d[valid] = -1, -1, -1x = (x.float() / self.feature_size[1] * 2.) - 1.0y = (y.float() / self.feature_size[0] * 2.) - 1.0d = (d.float() / self.max_num_depth * 2.) - 1.0indices = torch.cat([x[:,:,None], y[:,:,None], d[:,:,None]], dim=2) # [7, 48128, 3]indices = indices.reshape(-1, grid_h, grid_w, len(self.height), 3) # batch*num_img, bev_w, bev_h, num_height, 3(x, y, d)data["indices"] = indices.permute(0, 3, 1, 2, 4) # batch*num_img, num_height, bev_w, bev_h, 3(x, y, d)return data
4.LSS的BEV投影
class LssBevTransform(nn.Module):def __init__(self, feats_channels, num_height, max_num_depth):super().__init__()self._num_height = num_heightself._max_num_depth = max_num_depthself.ms_cam = MS_CAM(feats_channels * num_height)self._depth_proj = nn.Sequential(nn.Conv2d(feats_channels, max_num_depth, kernel_size=3, padding=1),nn.Softmax(dim=1))self._grid_sample = GridSample(mode="nearest",padding_mode="zeros",align_corners=True)self._cat = Concat(dim=1)self._blocks = nn.Sequential(nn.Conv2d(feats_channels * num_height, feats_channels, kernel_size=1),nn.BatchNorm2d(feats_channels),nn.ReLU(inplace=True))def simplify_bev(self, feats, indices):depths = self._depth_proj(feats)[:, None]import ipdbipdb.set_trace()passdef forward(self, feats, indices):# feats: (B*N, C, H, W)# indices: (B*N, Z, X, Y, 3) where 3 dims represent (w, h, d).bev_feats = self._sample_bev_feats(feats, indices[..., :2]) # (B*N, C, Z, X, Y)depth_feats = self._sample_depth_feats(feats, indices) # (B*N, 1, Z, X, Y)final_feats = bev_feats * depth_feats # (B*N, C, Z, Y, X)N, C, Z, Y, X = final_feats.shapefinal_feats = final_feats.view(N, C * Z, Y, X) # (B*N, Z*C, Hg, Wg)final_feats = final_feats*self.ms_cam(final_feats)final_feats = self._blocks(final_feats) # (B*N, C, Hg, Wg) return final_featsdef _sample_bev_feats(self, feats, indices):bev_feats = [self._grid_sample(feats, indices[:, i]) for i in range(self._num_height)]return torch.stack(bev_feats, dim=2) # (B*N, C, Z, Y, X) def _sample_depth_feats(self, feats, indices):depths = self._depth_proj(feats)[:, None] # (B*N, 1, D, H, W)return self._grid_sample(depths, indices) # (B*N, 1, Z, X, Y)
五、展望
LSS投影时将input_3:(N, 1, D, H_in, W_in)中D和H_in进行reshape合并后得(N, 1, D*H_in, W_in),可以完全通过4-D gridsample提取特征,耗时进一步降低,等效替代测试代码如下:
#!/usr/bin/env python3
import unittestimport torch
import torch.nn.functional as Fclass GridSampleTest(unittest.TestCase):def test_grid_sample_equivalence(self):D, H, W = 100, 144, 256Y, X = 64, 128# Generate random features.feats_5d = torch.randn(1, 1, D, H, W)# Generate random indices.d = torch.randint(high=D, size=(Y, X))h = torch.randint(high=H, size=(Y, X))w = torch.randint(high=W, size=(Y, X))# Prepare grid for 5D grid_sample.indices_5d = torch.stack([2.0 * w / (W - 1) - 1.0,2.0 * h / (H - 1) - 1.0,2.0 * d / (D - 1) - 1.0], dim=-1).view(1, 1, Y, X, 3)bev_feats_5d = F.grid_sample(feats_5d, indices_5d, mode="nearest", align_corners=True).view(Y, X)# Flatten D and H dimensions and prepare grid for 4D grid_sample.dh = d * H + hindices_4d = torch.stack([2.0 * w / (W - 1) - 1.0,2.0 * dh / (D * H - 1) - 1.0], dim=-1).view(1, Y, X, 2)feats_4d = feats_5d.view(1, 1, D * H, W)bev_feats_4d = F.grid_sample(feats_4d, indices_4d, mode="nearest", align_corners=True).view(Y, X)# Check if the results are close.self.assertTrue(torch.allclose(bev_feats_5d, bev_feats_4d, atol=1e-6))if __name__ == "__main__":unittest.main()