diff --git a/main.py b/main.py
index 5f6c86da..ceed819e 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 import hydra
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
@@ -13,6 +14,7 @@
 from typing import Dict, Any
 import os
 import time
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
 
 
 class RepresentationType(Enum):
@@ -27,6 +29,21 @@ def set_seed(seed):
     torch.backends.cudnn.benchmark = False
     np.random.seed(seed)
 
+def compute_multiscale_epe_error(pred_flows, gt_flow):
+    total_loss = 0
+    weights = [0.32, 0.08, 0.02, 0.01]  # weight for each scale
+    for pred_flow, weight in zip(pred_flows, weights):
+        # Resize the ground-truth flow to the predicted flow's resolution
+        if pred_flow.shape != gt_flow.shape:
+            gt_flow_scaled = nn.functional.interpolate(gt_flow, size=pred_flow.shape[2:], mode='bilinear', align_corners=False)
+        else:
+            gt_flow_scaled = gt_flow
+
+        epe = torch.mean(torch.norm(pred_flow - gt_flow_scaled, p=2, dim=1))
+        total_loss += weight * epe
+    return total_loss
+
+
 def compute_epe_error(pred_flow: torch.Tensor, gt_flow: torch.Tensor):
     '''
     Compute the end-point error (the Euclidean distance between ground truth and predicted flow)
@@ -110,12 +127,14 @@ def main(args: DictConfig):
     # ------------------
     # Model
     # ------------------
-    model = EVFlowNet(args.train).to(device)
+    # Updated the model instantiation in main()
+    model = EVFlowNet(args.train, no_batch_norm=False).to(device)
 
     # ------------------
     # optimizer
     # ------------------
     optimizer = torch.optim.Adam(model.parameters(), lr=args.train.initial_learning_rate, weight_decay=args.train.weight_decay)
+    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
     # ------------------
     # Start training
     # ------------------
@@ -127,14 +146,20 @@ def main(args: DictConfig):
             batch: Dict[str, Any]
             event_image = batch["event_volume"].to(device) # [B, 4, 480, 640]
             ground_truth_flow = batch["flow_gt"].to(device) # [B, 2, 480, 640]
-            flow = model(event_image) # [B, 2, 480, 640]
-            loss: torch.Tensor = compute_epe_error(flow, ground_truth_flow)
+
+            # Get the model outputs
+            model_output = model({"event_volume": event_image})
+            flows = model_output["flow_predictions"]
+
+            loss: torch.Tensor = compute_multiscale_epe_error(flows, ground_truth_flow)  # added
             print(f"batch {i} loss: {loss.item()}")
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             total_loss += loss.item()
+
+        scheduler.step()
 
         print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_data)}')
 
         # Create the directory if it doesn't exist
@@ -152,14 +177,22 @@ def main(args: DictConfig):
     model.load_state_dict(torch.load(model_path, map_location=device))
     model.eval()
     flow: torch.Tensor = torch.tensor([]).to(device)
+    final_flows = []
     with torch.no_grad():
         print("start test")
        for batch in tqdm(test_data):
             batch: Dict[str, Any]
             event_image = batch["event_volume"].to(device)
-            batch_flow = model(event_image) # [1, 2, 480, 640]
-            flow = torch.cat((flow, batch_flow), dim=0) # [N, 2, 480, 640]
+            model_output = model({"event_volume": event_image})
+            flows = model_output["flow_predictions"]
+            # Use the final (highest-resolution) flow
+            final_flow = flows[-1]
+            final_flows.append(final_flow)
         print("test done")
+
+    # Concatenate the final flows from all batches
+    flow: torch.Tensor = torch.cat(final_flows, dim=0)  # [N, 2, 480, 640]
+
     # ------------------
     # save submission
     # ------------------
diff --git a/src/models/evflownet.py b/src/models/evflownet.py
index ddfab828..00ff68ae 100644
--- a/src/models/evflownet.py
+++ b/src/models/evflownet.py
@@ -5,17 +5,43 @@
 _BASE_CHANNELS = 64
 
+class SelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super(SelfAttention, self).__init__()
+        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
+        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
+        self.value = nn.Conv2d(in_channels, in_channels, 1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        batch_size, C, height, width = x.size()
+        proj_query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)
+        proj_key = self.key(x).view(batch_size, -1, height * width)
+        energy = torch.bmm(proj_query, proj_key)
+        attention = torch.softmax(energy, dim=-1)
+        proj_value = self.value(x).view(batch_size, -1, height * width)
+
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(batch_size, C, height, width)
+        out = self.gamma * out + x
+        return out
+
 
 class EVFlowNet(nn.Module):
-    def __init__(self, args):
+    def __init__(self, args, no_batch_norm=False):
         super(EVFlowNet,self).__init__()
         self._args = args
+        self.no_batch_norm = no_batch_norm
 
         self.encoder1 = general_conv2d(in_channels = 4, out_channels=_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)
         self.encoder2 = general_conv2d(in_channels = _BASE_CHANNELS, out_channels=2*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)
         self.encoder3 = general_conv2d(in_channels = 2*_BASE_CHANNELS, out_channels=4*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)
         self.encoder4 = general_conv2d(in_channels = 4*_BASE_CHANNELS, out_channels=8*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)
 
-        self.resnet_block = nn.Sequential(*[build_resnet_block(8*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) for i in range(2)])
+        # Increase the number of ResNet blocks
+        self.resnet_block = nn.Sequential(*[build_resnet_block(8*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) for i in range(4)])
+
+        # Add an attention mechanism
+        self.attention = SelfAttention(8*_BASE_CHANNELS)
 
         self.decoder1 = upsample_conv2d_and_predict_flow(in_channels=16*_BASE_CHANNELS,
                         out_channels=4*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)
@@ -30,40 +56,44 @@ def __init__(self, args):
                         out_channels=int(_BASE_CHANNELS/2), do_batch_norm=not self._args.no_batch_norm)
 
     def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        # Get the event volume
+        event_volume = inputs["event_volume"]
+
         # encoder
         skip_connections = {}
-        inputs = self.encoder1(inputs)
-        skip_connections['skip0'] = inputs.clone()
-        inputs = self.encoder2(inputs)
-        skip_connections['skip1'] = inputs.clone()
-        inputs = self.encoder3(inputs)
-        skip_connections['skip2'] = inputs.clone()
-        inputs = self.encoder4(inputs)
-        skip_connections['skip3'] = inputs.clone()
+        x = self.encoder1(event_volume)
+        skip_connections['skip0'] = x.clone()
+        x = self.encoder2(x)
+        skip_connections['skip1'] = x.clone()
+        x = self.encoder3(x)
+        skip_connections['skip2'] = x.clone()
+        x = self.encoder4(x)
+        skip_connections['skip3'] = x.clone()
 
         # transition
-        inputs = self.resnet_block(inputs)
+        x = self.resnet_block(x)
+        x = self.attention(x)  # apply the attention mechanism
 
         # decoder
-        flow_dict = {}
-        inputs = torch.cat([inputs, skip_connections['skip3']], dim=1)
-        inputs, flow = self.decoder1(inputs)
-        flow_dict['flow0'] = flow.clone()
+        flow_predictions = []
+        x = torch.cat([x, skip_connections['skip3']], dim=1)
+        x, flow = self.decoder1(x)
+        flow_predictions.append(flow)
 
-        inputs = torch.cat([inputs, skip_connections['skip2']], dim=1)
-        inputs, flow = self.decoder2(inputs)
-        flow_dict['flow1'] = flow.clone()
+        x = torch.cat([x, skip_connections['skip2']], dim=1)
+        x, flow = self.decoder2(x)
+        flow_predictions.append(flow)
 
-        inputs = torch.cat([inputs, skip_connections['skip1']], dim=1)
-        inputs, flow = self.decoder3(inputs)
-        flow_dict['flow2'] = flow.clone()
+        x = torch.cat([x, skip_connections['skip1']], dim=1)
+        x, flow = self.decoder3(x)
+        flow_predictions.append(flow)
 
-        inputs = torch.cat([inputs, skip_connections['skip0']], dim=1)
-        inputs, flow = self.decoder4(inputs)
-        flow_dict['flow3'] = flow.clone()
+        x = torch.cat([x, skip_connections['skip0']], dim=1)
+        x, flow = self.decoder4(x)
+        flow_predictions.append(flow)
 
-        return flow
-
+        return {"flow_predictions": flow_predictions}
+
 
 # if __name__ == "__main__":
 #     from config import configs
@@ -93,4 +123,4 @@ def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
 #     a = time.time()
 #     (model(input_))
 #     b = time.time()
-#     print(b-a)
\ No newline at end of file
+#     print(b-a)
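
The multiscale loss pairs each entry of flow_predictions with a fixed weight, resizing the ground truth to each prediction's resolution. Since the new forward appends flows from decoder1 through decoder4, the list runs coarse to fine, so the 0.32 weight lands on the coarsest flow and 0.01 on the full-resolution one. A minimal sketch of that behaviour on dummy tensors (the 1/8, 1/4, 1/2, 1/1 scales are assumed from the four-decoder layout, not stated in the diff):

    import torch
    from torch import nn

    gt_flow = torch.randn(2, 2, 480, 640)               # [B, 2, H, W] ground truth
    pred_flows = [torch.randn(2, 2, 480 // s, 640 // s)  # assumed decoder output scales,
                  for s in (8, 4, 2, 1)]                 # ordered coarse -> fine

    weights = [0.32, 0.08, 0.02, 0.01]
    total_loss = 0.0
    for pred_flow, weight in zip(pred_flows, weights):
        # resize the ground truth to the prediction's resolution, as in the diff
        gt_scaled = nn.functional.interpolate(gt_flow, size=pred_flow.shape[2:],
                                              mode='bilinear', align_corners=False)
        epe = torch.mean(torch.norm(pred_flow - gt_scaled, p=2, dim=1))  # mean per-pixel L2
        total_loss = total_loss + weight * epe
    print(total_loss)  # one scalar; the coarsest scale dominates via its 0.32 weight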
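CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2), stepped once per epoch, decays the learning rate along a cosine for 10 epochs, restarts, then decays over 20 epochs, then 40, and so on. A quick way to inspect the schedule (the 1e-3 rate is a placeholder; the real value comes from args.train.initial_learning_rate):

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

    param = torch.nn.Parameter(torch.zeros(1))        # dummy parameter
    optimizer = torch.optim.Adam([param], lr=1e-3)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    for epoch in range(30):
        # ... one epoch of training would run here ...
        scheduler.step()
        print(f"epoch {epoch + 1}: lr = {optimizer.param_groups[0]['lr']:.2e}")
    # lr falls toward 0 over epochs 1-10, jumps back to 1e-3, then anneals over 20 epochs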
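The added SelfAttention block materialises a full (H*W) x (H*W) attention matrix, which is why it is applied only at the bottleneck (1/16 resolution, about 30 x 40 for a 480 x 640 input, giving a 1200 x 1200 matrix). Because gamma starts at zero, the block is an identity mapping at initialisation and learns how much attention to blend in. A shape check, assuming the module is importable from the path in the diff header and that 8*_BASE_CHANNELS = 512 channels reach the bottleneck:

    import torch
    from src.models.evflownet import SelfAttention  # import path assumed from the diff

    attention = SelfAttention(512)
    x = torch.randn(1, 512, 30, 40)   # bottleneck feature map at 1/16 scale
    out = attention(x)

    print(out.shape)                  # torch.Size([1, 512, 30, 40]): shape preserved
    print(torch.equal(out, x))        # True: gamma == 0 makes it an identity at init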
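With these changes the model consumes a dict and returns a dict, so the training and test loops call it identically, and the test loop keeps only flows[-1], the full-resolution prediction. An end-to-end sketch of the new interface, assuming the repo's src package is importable (the SimpleNamespace stand-in replaces the hydra args.train config and is purely illustrative):

    import torch
    from types import SimpleNamespace
    from src.models.evflownet import EVFlowNet  # import path assumed from the diff

    args_train = SimpleNamespace(no_batch_norm=False)      # config stand-in
    model = EVFlowNet(args_train, no_batch_norm=False).eval()

    batch = {"event_volume": torch.randn(1, 4, 480, 640)}  # [B, 4, H, W]
    with torch.no_grad():
        flows = model(batch)["flow_predictions"]           # list of 4 flows, coarse -> fine

    print([tuple(f.shape) for f in flows])
    final_flow = flows[-1]                                 # [1, 2, 480, 640], used at test time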