Code Walkthrough
sixwalter

Baseline Code Walkthrough

Data Loading and Preprocessing

Personally, I find the data handling here quite clever; the same idea could be applied to RGB inputs as well.

Data preparation:

  • the mapping from group activity class to id
  • group activity class weights
  • individual action class weights
  • all individual action class labels
  • all group activity labels
  • all individual action labels
  • file paths of every clip's keypoint data
  • all (video, clip) tuples
  • ball trajectories
  • bounding boxes

Steps:

  1. store the dict {(video, clip): group action id} in annotations_thisdatasetdir
  2. store all of the current split's joint data paths in clip_joints_path
  3. collect the current split's (video, clip) pairs into clips
  4. get the corresponding group labels into annotations
  5. get the corresponding person labels into person_actions_all
  6. count the number of clips
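A minimal sketch of these six steps, assuming a volleyball-style annotation layout; the file names, label format, and helper below are assumptions, not the repo's actual code:

```python
import os

def prepare_split(dataset_dir, joints_dir, split_videos, class_to_id):
    # 1. {(video, clip): group action id}
    annotations_thisdatasetdir = {}
    for video in split_videos:
        ann_path = os.path.join(dataset_dir, str(video), 'annotations.txt')
        with open(ann_path) as f:
            for line in f:
                clip, group_label = line.split()[:2]  # assumed file format
                annotations_thisdatasetdir[(video, clip)] = class_to_id[group_label]

    clip_joints_paths = []  # 2. joint data paths for this split
    clips = []              # 3. (video, clip) tuples
    annotations = []        # 4. group labels aligned with clips
    for (video, clip), label in annotations_thisdatasetdir.items():
        clip_joints_paths.append(os.path.join(joints_dir, str(video), clip + '.pickle'))
        clips.append((video, clip))
        annotations.append(label)

    # 5. person labels would be collected analogously into person_actions_all
    # 6. the dataset length is simply the number of clips
    return annotations, clips, clip_joints_paths, len(clips)
```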

Data augmentation:

  • Keep a copy of the original (real) data, and always augment the dataset from that real copy:
```python
# Augmentation is only needed during training.
# if horizontal flip augmentation and is training
if self.args.horizontal_flip_augment and self.split == 'train':
    # mask entries are 0 for samples that keep the original data, 1 for flipped copies
    self.horizontal_flip_mask = list(np.zeros(len(self.annotations))) + list(np.ones(true_data_size))
    self.annotations += true_annotations
    self.annotations_each_person += true_annotations_each_person
    self.clip_joints_paths += true_clip_joints_paths
    self.clips += true_clips
    if self.args.ball_trajectory_use:
        self.clip_ball_paths += true_clip_ball_paths
```
  • Set up the flipped-label mapping and the perturbation randomness:
```python
if self.args.horizontal_flip_augment and self.split == 'train':
    # group activity labels must be adjusted after a horizontal flip
    # (left/right class pairs swap)
    self.classidx_horizontal_flip_augment = {
        0: 1,
        1: 0,
        2: 3,
        3: 2,
        4: 5,
        5: 4,
        6: 7,
        7: 6
    }
    if self.args.horizontal_flip_augment_purturb:
        self.horizontal_flip_augment_joint_randomness = dict()
```

Data statistics:

  • We need to run a statistical analysis over the keypoint data. If no statistics file exists yet, we collect the following:
```python
joint_xcoords = []
joint_ycoords = []
joint_dxcoords = []
joint_dycoords = []
```

Note that the statistics step must also cover the augmented data.

For each clip's joint file, we first read it into joint_raw.

  • Next, we sample T frames:
```python
frames = sorted(joint_raw.keys())[self.args.frame_start_idx:self.args.frame_end_idx+1:self.args.frame_sampling]
```
  • If this sample is flip-augmented, and perturbation is required, the randomness must be initialized first:
```python
if self.args.horizontal_flip_augment:
    if index < len(self.horizontal_flip_mask):
        if self.horizontal_flip_mask[index]:
            if self.args.horizontal_flip_augment_purturb:
                self.horizontal_flip_augment_joint_randomness[index] = defaultdict()
                # only change the joint info of the chosen frames
                joint_raw = self.horizontal_flip_augment_joint(
                    joint_raw, frames,
                    add_purturbation=True, randomness_set=False, index=index)
            else:
                joint_raw = self.horizontal_flip_augment_joint(joint_raw, frames)

            if self.args.ball_trajectory_use:
                ball_trajectory_data = self.horizontal_flip_ball_trajectory(ball_trajectory_data)
```
  • There is also a somewhat special agent-dropout augmentation; at this stage we only need to set its randomness:
```python
# To compute statistics, no need to consider the random agent dropout augmentation,
# but we can set the randomness here.
# if random agent dropout augmentation and is training
if self.args.agent_dropout_augment:
    if index < len(self.agent_dropout_mask):
        if self.agent_dropout_mask[index]:
            chosen_frame = random.choice(frames)
            chosen_person = random.choice(range(self.args.N))
            self.agent_dropout_augment_randomness[index] = (chosen_frame, chosen_person)
```
  • Because pose estimation makes mistakes, implausible joints need to be corrected:
```python
joints_sanity_fix()  # helper function
```
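The function body isn't shown in the post. A minimal sketch of what such a fix might do, assuming it only clamps out-of-range coordinates back into the image (the signature and the clamping rule are assumptions):

```python
import numpy as np

def joints_sanity_fix(joint_raw, frames, image_w=1280, image_h=720):
    # joint_raw[frame] is assumed to be an (N, J, 3) array of (x, y, joint_type);
    # clamp every coordinate into the valid image range.
    for frame in frames:
        joints = joint_raw[frame]
        joints[:, :, 0] = np.clip(joints[:, :, 0], 0, image_w - 1)
        joints[:, :, 1] = np.clip(joints[:, :, 1], 0, image_h - 1)
        joint_raw[frame] = joints
    return joint_raw
```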
  • After that, we can update the coordinate lists above:
```python
# update the joint coordinate lists
for tidx, frame in enumerate(frames):
    joint_xcoords.extend(joint_raw[frame][:,:,0].flatten().tolist())
    joint_ycoords.extend(joint_raw[frame][:,:,1].flatten().tolist())

    if tidx != 0:
        pre_frame = frames[tidx-1]
        joint_dxcoords.extend((joint_raw[frame][:,:,0]-joint_raw[pre_frame][:,:,0]).flatten().tolist())
        joint_dycoords.extend((joint_raw[frame][:,:,1]-joint_raw[pre_frame][:,:,1]).flatten().tolist())
    else:
        joint_dxcoords.extend((np.zeros((self.args.N, self.args.J))).flatten().tolist())
        joint_dycoords.extend((np.zeros((self.args.N, self.args.J))).flatten().tolist())
```
  • With all of the above, we can compute the means and standard deviations:
```python
joint_xcoords_mean, joint_xcoords_std = np.mean(joint_xcoords), np.std(joint_xcoords)
joint_ycoords_mean, joint_ycoords_std = np.mean(joint_ycoords), np.std(joint_ycoords)
joint_dxcoords_mean, joint_dxcoords_std = np.mean(joint_dxcoords), np.std(joint_dxcoords)
joint_dycoords_mean, joint_dycoords_std = np.mean(joint_dycoords), np.std(joint_dycoords)

self.stats = {
    'joint_xcoords_mean': joint_xcoords_mean, 'joint_xcoords_std': joint_xcoords_std,
    'joint_ycoords_mean': joint_ycoords_mean, 'joint_ycoords_std': joint_ycoords_std,
    'joint_dxcoords_mean': joint_dxcoords_mean, 'joint_dxcoords_std': joint_dxcoords_std,
    'joint_dycoords_mean': joint_dycoords_mean, 'joint_dycoords_std': joint_dycoords_std
}
```
  • Once computed, the statistics need to be saved:
```python
# save the statistics
with open(os.path.join('datasets', self.args.dataset_name, self.args.joints_folder_name, 'stats_train.pickle'), 'wb') as f:
    pickle.dump(self.stats, f)
```
  • If perturbation was used, its randomness also needs to be saved:
```python
if self.args.horizontal_flip_augment and self.args.horizontal_flip_augment_purturb:
    with open(os.path.join('datasets', self.args.dataset_name, self.args.joints_folder_name,
                           'horizontal_flip_augment_joint_randomness.pickle'), 'wb') as f:
        pickle.dump(self.horizontal_flip_augment_joint_randomness, f)
```

Data retrieval:

  • Get person_labels:
```python
person_labels = torch.LongTensor(person_labels[frames[0]].squeeze())
# the person actions remain the same across all frames
# person_labels: (N, )
```
  • Apply data augmentation:
```python
# if vertical move augmentation and is training
if self.args.vertical_move_augment and self.split == 'train':
    if index < len(self.vertical_mask):
        if self.vertical_mask[index]:
            if self.args.ball_trajectory_use:
                if self.args.vertical_move_augment_purturb:
                    joint_raw, ball_trajectory_data = self.vertical_move_augment_joint(
                        joint_raw, frames, add_purturbation=True,
                        randomness_set=True, index=index,
                        ball_trajectory=ball_trajectory_data)
                else:
                    joint_raw, ball_trajectory_data = self.vertical_move_augment_joint(
                        joint_raw, frames, ball_trajectory=ball_trajectory_data)
            else:
                if self.args.vertical_move_augment_purturb:
                    joint_raw = self.vertical_move_augment_joint(
                        joint_raw, frames, add_purturbation=True,
                        randomness_set=True, index=index)
                else:
                    joint_raw = self.vertical_move_augment_joint(joint_raw, frames)
```
  • and run the sanity check again (screenshot omitted; presumably the same joints_sanity_fix step shown earlier)
  • Obtain the 4 types of joint features and concatenate them:
```python
joint_feats = torch.cat((torch.Tensor(np.array(joint_feats_basic)),
                         torch.Tensor(np.array(joint_feats_metrics)).permute(1,2,0,3),
                         torch.Tensor(np.array(joint_feats_advanced)),
                         torch.Tensor(np.array(joint_coords_all))), dim=-1)
```
  • joint_coords_all is there for the image-coordinate embedding later:
```python
joint_coords_all = []  # (N, J, T, 2)
for n in range(self.args.N):
    joint_coords_n = []

    for j in range(self.args.J):
        joint_coords_j = []

        for tidx, frame in enumerate(frames):
            joint_x, joint_y, joint_type = joint_raw[frame][n,j,:]

            joint_x = min(joint_x, self.args.image_w-1)
            joint_y = min(joint_y, self.args.image_h-1)
            joint_x = max(0, joint_x)
            joint_y = max(0, joint_y)

            joint_coords = []
            joint_coords.append(joint_x)  # width axis
            joint_coords.append(joint_y)  # height axis

            joint_coords_j.append(joint_coords)
        joint_coords_n.append(joint_coords_j)
    joint_coords_all.append(joint_coords_n)
```
  • joint_feats_basic standardizes the keypoint coordinates:
```python
joint_feats_basic = []  # (N, J, T, d_0_v1)
for n in range(self.args.N):
    joint_feats_n = []
    for j in range(self.args.J):
        joint_feats_j = []
        for tidx, frame in enumerate(frames):
            joint_x, joint_y, joint_type = joint_raw[frame][n,j,:]

            joint_feat = []

            joint_feat.append((joint_x-self.stats['joint_xcoords_mean'])/self.stats['joint_xcoords_std'])
            joint_feat.append((joint_y-self.stats['joint_ycoords_mean'])/self.stats['joint_ycoords_std'])

            if tidx != 0:
                pre_frame = frames[tidx-1]
                pre_joint_x, pre_joint_y, pre_joint_type = joint_raw[pre_frame][n,j,:]
                joint_dx, joint_dy = joint_x - pre_joint_x, joint_y - pre_joint_y
            else:
                joint_dx, joint_dy = 0, 0

            joint_feat.append((joint_dx-self.stats['joint_dxcoords_mean'])/self.stats['joint_dxcoords_std'])
            joint_feat.append((joint_dy-self.stats['joint_dycoords_mean'])/self.stats['joint_dycoords_std'])
            joint_feats_j.append(joint_feat)
        joint_feats_n.append(joint_feats_j)
    joint_feats_basic.append(joint_feats_n)
```
  • joint_feats_advanced normalizes the keypoint information (not shown in the post; a rough sketch follows below)

  • joint_feats_metrics (the post doesn't describe or show this one)
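Neither feature is quoted in the post. As a purely illustrative sketch, assuming joint_feats_advanced re-normalizes each joint relative to its person's bounding box (only the name comes from the post; the formula here is a guess):

```python
import numpy as np

def joint_feats_advanced_sketch(joint_raw, frames, N, J):
    # Hypothetical: normalize each joint into its person's bounding box,
    # making the feature translation- and scale-invariant per person.
    feats = np.zeros((N, J, len(frames), 2))
    for tidx, frame in enumerate(frames):
        joints = joint_raw[frame][:, :, :2].astype(float)  # (N, J, 2)
        for n in range(N):
            mins, maxs = joints[n].min(axis=0), joints[n].max(axis=0)
            feats[n, :, tidx, :] = (joints[n] - mins) / np.maximum(maxs - mins, 1e-6)
    return feats
```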

  • Next, if we are currently training and agent-dropout augmentation is used:

```python
# if random agent dropout augmentation and is training
if self.args.agent_dropout_augment and self.split == 'train':
    if index < len(self.agent_dropout_mask):
        if self.agent_dropout_mask[index]:
            joint_feats = self.agent_dropout_augment_joint(
                joint_feats, frames, index=index, J=self.args.J)

def agent_dropout_augment_joint(self, joint_feats, frames, index=0, J=17):
    # joint_feats: (N, J, T, d)
    chosen_frame = self.agent_dropout_augment_randomness[index][0]
    chosen_person = self.agent_dropout_augment_randomness[index][1]
    feature_dim = joint_feats.shape[3]

    joint_feats[chosen_person, :, frames.index(chosen_frame), :] = torch.zeros(J, feature_dim)
    return joint_feats
```
  • Finally, return the data:
```python
return joint_feats, label, video, clip, person_labels  # , ball_feats
```
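For context, a hypothetical sketch of how these returns would be consumed on the training side (the dataset class name and constructor are assumptions):

```python
from torch.utils.data import DataLoader

dataset = KeypointClipDataset(args, split='train')  # hypothetical constructor
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

for joint_feats, label, video, clip, person_labels in loader:
    # joint_feats: (B, N, J, T, d), label: (B,), person_labels: (B, N)
    pass
```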

Model Processing Module

  • Collect all the dimension information we need:
```python
B = joint_feats_thisbatch.size(0)
N = joint_feats_thisbatch.size(1)
J = joint_feats_thisbatch.size(2)
T = joint_feats_thisbatch.size(3)

d = self.args.TNT_hidden_dim
```
  • First, the image-coordinate positional encoding:
```python
# image coords positional encoding
image_coords = joint_feats_thisbatch[:,:,:,:,-2:].to(torch.int64).cuda()  # (B, N, J, T, 2)
coords_h = np.linspace(0, 1, self.args.image_h, endpoint=False)
coords_w = np.linspace(0, 1, self.args.image_w, endpoint=False)
xy_grid = np.stack(np.meshgrid(coords_w, coords_h), -1)
xy_grid = torch.tensor(xy_grid).unsqueeze(0).permute(0, 3, 1, 2).float().contiguous().to(device)
image_coords_learned = self.image_embed_layer(xy_grid).squeeze(0).permute(1, 2, 0)  # (h, w, 2*embed_dim)
image_coords_embeded = image_coords_learned[image_coords[:,:,:,:,1], image_coords[:,:,:,:,0]]
# (B, N, J, T, d_0)
```
  • Next, the temporal positional encoding:
```python
# time positional encoding
time_ids = torch.arange(1, T+1, device=device).repeat(B, N, J, 1)
time_seq = self.time_embed_layer(time_ids)
# (B, N, J, T, d_0)
```
  • Then the joint-class (keypoint type) embedding:
```python
# joint classes embedding learning as tokens/nodes
joint_class_ids = joint_feats_thisbatch[:,:,:,:,-1]  # note that the last dim is the joint class id by default
joint_classes_embeded = self.joint_class_embed_layer(joint_class_ids.type(torch.LongTensor).cuda())  # (B, N, J, T, d_0)

x = joint_classes_embeded.transpose(2, 3).flatten(0, 1).flatten(0, 1)  # x: (B*N*T, J, d_0)
input = (x, self.adj.repeat(B*N*T, 1, 1).cuda())  # adj: (B*N*T, J, J)
joint_classes_encode = self.joint_class_gcn_layers(input)[0]  # output
joint_classes_encode = joint_classes_encode.view(B, N, T, J, -1).transpose(2, 3)  # (B, N, J, T, d_0)
```
  • Finally, concatenate everything to obtain the encoded joint features:
```python
joint_feats_composite_encoded = torch.cat(
    [joint_feats_thisbatch, time_seq, image_coords_embeded, joint_classes_encode],
    dim=-1)
```
  • Then come the projections. Note that the time dimension is flattened away here:
```python
# PROJECTIONS
# joint track projection
joint_track_feats_thisbatch_proj = self.joint_track_projection_layer(
    joint_feats_composite_encoded.flatten(3, 4).flatten(0, 1).flatten(0, 1)  # (B*N*J, T*d_0)
).view(B, N*J, -1)
# (B, N*J, d)

# person track projection
person_track_feats_thisbatch_proj = self.person_track_projection_layer(
    joint_feats_for_person_thisbatch.flatten(0, 1).contiguous().view(B*N, -1)
).view(B, N, -1)
# (B, N, d)
```
  • There are two more tracks, the interaction track and the group track; the idea is similar (a hedged sketch follows).
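Those two projections aren't quoted in the post. A rough sketch of what analogous projections could look like, mirroring the joint and person tracks above (the layer names, pairing scheme, and pooling are assumptions):

```python
import torch
import torch.nn as nn

B, N, d_in, d = 2, 12, 256, 128
person_feats = torch.randn(B, N, d_in)  # stand-in for per-person features

# interaction track: one token per ordered person pair
pair_feats = torch.cat(
    [person_feats.unsqueeze(2).expand(B, N, N, d_in),
     person_feats.unsqueeze(1).expand(B, N, N, d_in)], dim=-1)
interaction_proj = nn.Linear(2 * d_in, d)
interaction_track = interaction_proj(pair_feats).reshape(B, N * N, d)  # (B, N*N, d)

# group track: a single token pooled over all persons
group_proj = nn.Linear(d_in, d)
group_track = group_proj(person_feats.mean(dim=1, keepdim=True))       # (B, 1, d)
```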
  • Next, each track is fed into the TNT network for processing (screenshot of this code omitted).
  • The backbone of TNT is TNT_block (screenshot omitted; see the sketch below).
  • I also drew a flowchart of the whole pipeline (image omitted); the overall flow is fairly easy to follow.
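Since the screenshots are gone, here is a very rough, assumption-laden sketch of the general shape such a multi-scale block could take, inferred only from the output structure quoted below (four tracks, each summarized by a CLS token); none of this is the repo's actual TNT_block:

```python
import torch
import torch.nn as nn

class MultiScaleBlockSketch(nn.Module):
    # Each track (fine / middle / coarse / group) gets its own encoder layer
    # and a learnable CLS token that summarizes the track.
    def __init__(self, d, nhead=8):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(d, nhead) for _ in range(4))
        self.cls = nn.Parameter(torch.randn(4, 1, 1, d))

    def forward(self, tracks):  # tracks: list of 4 tensors, each (L_i, B, d)
        outs = []
        for i, (layer, track) in enumerate(zip(self.layers, tracks)):
            cls = self.cls[i].expand(-1, track.size(1), -1)  # (1, B, d)
            out = layer(torch.cat([cls, track], dim=0))
            outs.append((out[:1], out[1:]))  # (CLS summary, refined track)
        return outs
```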

Loss Computation

```python
# outputs is a list of lists
# len(outputs) is the number of TNT layers
# each inner list is [CLS_f, CLS_m, CLS_c, output_CLS, output_fine, output_middle, output_coarse, output_group]
```

Predicting the group logits:

```python
pred_logits = []
for l in range(self.args.TNT_n_layers):

    fine_cls = outputs[l][0].transpose(0, 1).squeeze(1)    # (B, d)
    middle_cls = outputs[l][1].transpose(0, 1).squeeze(1)  # (B, d)
    coarse_cls = outputs[l][2].transpose(0, 1).squeeze(1)  # (B, d)
    group_cls = outputs[l][3].transpose(0, 1).squeeze(1)   # (B, d)

    pred_logit_f = self.classifier(fine_cls)
    pred_logit_m = self.classifier(middle_cls)
    pred_logit_c = self.classifier(coarse_cls)
    pred_logit_g = self.classifier(group_cls)

    pred_logits.append([pred_logit_f, pred_logit_m, pred_logit_c, pred_logit_g])
```

Computing the scores:

```python
# fine_cls, middle_cls, coarse_cls, group_cls are from the last layer
fine_cls_normed = nn.functional.normalize(fine_cls, dim=1, p=2)
middle_cls_normed = nn.functional.normalize(middle_cls, dim=1, p=2)
coarse_cls_normed = nn.functional.normalize(coarse_cls, dim=1, p=2)
group_cls_normed = nn.functional.normalize(group_cls, dim=1, p=2)

scores_f = self.prototypes(fine_cls_normed)
scores_m = self.prototypes(middle_cls_normed)
scores_c = self.prototypes(coarse_cls_normed)
scores_g = self.prototypes(group_cls_normed)
scores = [scores_f, scores_m, scores_c, scores_g]
```

Predicting the person logits:

```python
pred_logits_person = []
for l in range(self.args.TNT_n_layers):
    person_feats = outputs[l][5].transpose(0, 1).flatten(0, 1)  # (B*N, d)
    pred_logit_person = self.person_classifier(person_feats)
    pred_logits_person.append(pred_logit_person)
```

Computing the loss:

```python
# model forward pass
pred_logits_thisbatch, pred_logits_person, scores = self.model(
    joint_feats_thisbatch, ball_feats)
# measure accuracy and record loss
targets_thisbatch = targets_thisbatch.to(pred_logits_thisbatch[0][0].device)
person_labels = person_labels.flatten(0, 1).to(pred_logits_thisbatch[0][0].device)

loss_thisbatch, prec1, prec3, prec1_person, prec3_person = self.loss_acc_compute(
    pred_logits_thisbatch, targets_thisbatch, pred_logits_person, person_labels)

loss_thisbatch += contrastive_clustering_loss
```

This part of the loss computation is quite hard to understand; I'll go back to the original paper later and try tweaking it. (A hedged sketch of the sinkhorn and swap_prediction helpers follows the code below.)

```python
# learning the cluster assignment and computing the loss
scores_f = scores[0]
scores_m = scores[1]
scores_c = scores[2]
scores_g = scores[3]

# compute assignments
with torch.no_grad():
    q_f = self.sinkhorn(scores_f, nmb_iters=self.args.sinkhorn_iterations)
    q_m = self.sinkhorn(scores_m, nmb_iters=self.args.sinkhorn_iterations)
    q_c = self.sinkhorn(scores_c, nmb_iters=self.args.sinkhorn_iterations)
    q_g = self.sinkhorn(scores_g, nmb_iters=self.args.sinkhorn_iterations)

# swap prediction problem
p_f = scores_f / self.args.temperature
p_m = scores_m / self.args.temperature
p_c = scores_c / self.args.temperature
p_g = scores_g / self.args.temperature

contrastive_clustering_loss = self.args.loss_coe_constrastive_clustering * (
    self.swap_prediction(p_f, p_m, q_f, q_m) +
    self.swap_prediction(p_f, p_c, q_f, q_c) +
    self.swap_prediction(p_f, p_g, q_f, q_g) +
    self.swap_prediction(p_m, p_c, q_m, q_c) +
    self.swap_prediction(p_m, p_g, q_m, q_g) +
    self.swap_prediction(p_c, p_g, q_c, q_g)
) / 6.0  # 6 pairs of views
```
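To make this easier to follow: the pattern is SwAV-style swapped prediction between pairs of "views" (the four scales). Below is a minimal sketch of what sinkhorn and swap_prediction typically look like in that formulation; this follows the standard SwAV recipe, not necessarily this repo's exact implementation:

```python
import torch

def sinkhorn(scores, nmb_iters=3, epsilon=0.05):
    # Sinkhorn-Knopp: turn prototype scores into soft cluster assignments
    # whose rows (samples) and columns (prototypes) are balanced.
    with torch.no_grad():
        Q = torch.exp(scores / epsilon).t()  # (K, B)
        Q /= Q.sum()
        K, B = Q.shape
        for _ in range(nmb_iters):
            Q /= Q.sum(dim=1, keepdim=True)  # balance prototypes
            Q /= K
            Q /= Q.sum(dim=0, keepdim=True)  # balance samples
            Q /= B
        return (Q * B).t()                   # (B, K), each row sums to 1

def swap_prediction(p_a, p_b, q_a, q_b):
    # Predict view b's assignment from view a's (temperature-scaled) scores
    # and vice versa: cross-entropy between q and softmax(p).
    loss_a = -torch.mean(torch.sum(q_b * torch.log_softmax(p_a, dim=1), dim=1))
    loss_b = -torch.mean(torch.sum(q_a * torch.log_softmax(p_b, dim=1), dim=1))
    return loss_a + loss_b
```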