Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Event camera competition #21

Open
wants to merge 2 commits into
base: event-camera-competition
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch import nn
import hydra
from omegaconf import DictConfig
from torch.utils.data import DataLoader
Expand All @@ -13,6 +14,7 @@
from typing import Dict, Any
import os
import time
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


class RepresentationType(Enum):
Expand All @@ -27,6 +29,21 @@ def set_seed(seed):
torch.backends.cudnn.benchmark = False
np.random.seed(seed)

def compute_multiscale_epe_error(pred_flows, gt_flow, weights=(0.32, 0.08, 0.02, 0.01)):
    """Return the weighted sum of end-point errors over several flow scales.

    Each predicted flow is compared with the ground-truth flow. When the two
    shapes differ, the ground truth is first resized bilinearly to the
    prediction's spatial size (``pred_flow.shape[2:]``). The per-scale error
    is the mean of the L2 norm taken along dim 1 of the difference.

    Args:
        pred_flows: Sequence of predicted flow tensors, one per scale. The
            bilinear resize requires 4-D tensors.
        gt_flow: Ground-truth flow tensor.
        weights: Per-scale loss weights, paired positionally with
            ``pred_flows``. Pairing stops at the shorter of the two
            sequences, so surplus predictions or weights are ignored.

    Returns:
        The weighted loss as a tensor, or the integer ``0`` when there is
        nothing to pair (for example, ``pred_flows`` is empty).
    """
    total_loss = 0
    for pred_flow, weight in zip(pred_flows, weights):
        target = gt_flow
        if pred_flow.shape != gt_flow.shape:
            # Resize the ground-truth flow to match the predicted flow's size.
            target = nn.functional.interpolate(
                gt_flow,
                size=pred_flow.shape[2:],
                mode='bilinear',
                align_corners=False,
            )

        epe = torch.mean(torch.norm(pred_flow - target, p=2, dim=1))
        total_loss = total_loss + weight * epe
    return total_loss


def compute_epe_error(pred_flow: torch.Tensor, gt_flow: torch.Tensor):
'''
end-point-error (ground truthと予測値の二乗誤差)を計算
Expand Down Expand Up @@ -110,12 +127,14 @@ def main(args: DictConfig):
# ------------------
# Model
# ------------------
model = EVFlowNet(args.train).to(device)
# Updated model instantiation inside main()
model = EVFlowNet(args.train, no_batch_norm=False).to(device)

# ------------------
# optimizer
# ------------------
optimizer = torch.optim.Adam(model.parameters(), lr=args.train.initial_learning_rate, weight_decay=args.train.weight_decay)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
# ------------------
# Start training
# ------------------
Expand All @@ -127,14 +146,20 @@ def main(args: DictConfig):
batch: Dict[str, Any]
event_image = batch["event_volume"].to(device) # [B, 4, 480, 640]
ground_truth_flow = batch["flow_gt"].to(device) # [B, 2, 480, 640]
flow = model(event_image) # [B, 2, 480, 640]
loss: torch.Tensor = compute_epe_error(flow, ground_truth_flow)

# Get the model output
model_output = model({"event_volume": event_image})
flows = model_output["flow_predictions"]

loss: torch.Tensor = compute_multiscale_epe_error(flows, ground_truth_flow) #追加
print(f"batch {i} loss: {loss.item()}")
optimizer.zero_grad()
loss.backward()
optimizer.step()

total_loss += loss.item()

scheduler.step()
print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_data)}')

# Create the directory if it doesn't exist
Expand All @@ -152,14 +177,22 @@ def main(args: DictConfig):
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
flow: torch.Tensor = torch.tensor([]).to(device)
final_flows = []
with torch.no_grad():
print("start test")
for batch in tqdm(test_data):
batch: Dict[str, Any]
event_image = batch["event_volume"].to(device)
batch_flow = model(event_image) # [1, 2, 480, 640]
flow = torch.cat((flow, batch_flow), dim=0) # [N, 2, 480, 640]
model_output = model({"event_volume": event_image})
flows = model_output["flow_predictions"]
# Use the final (highest-resolution) flow
final_flow = flows[-1]
final_flows.append(final_flow)
print("test done")

# Concatenate the final flows of all batches
flow: torch.Tensor = torch.cat(final_flows, dim=0) # [N, 2, 480, 640

# ------------------
# save submission
# ------------------
Expand Down
84 changes: 57 additions & 27 deletions src/models/evflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,43 @@

_BASE_CHANNELS = 64

class SelfAttention(nn.Module):
    """Self-attention across the positions of a 4-D feature map.

    Query and key are 1x1 convolutions down to ``in_channels // 8`` channels;
    value is a 1x1 convolution that keeps ``in_channels``. The attended
    features are scaled by the learnable scalar ``gamma`` (initialised to
    zero) and added back onto the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced_channels = in_channels // 8
        self.query = nn.Conv2d(in_channels, reduced_channels, 1)
        self.key = nn.Conv2d(in_channels, reduced_channels, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        n, channels, dim2, dim3 = x.size()
        num_positions = dim2 * dim3

        # Flatten the two trailing dims so each position becomes one token.
        queries = self.query(x).view(n, -1, num_positions).transpose(1, 2)
        keys = self.key(x).view(n, -1, num_positions)
        values = self.value(x).view(n, -1, num_positions)

        # Row i holds the softmax-normalised scores of position i against all positions.
        scores = torch.softmax(torch.bmm(queries, keys), dim=-1)

        attended = torch.bmm(values, scores.transpose(1, 2))
        attended = attended.view(n, channels, dim2, dim3)
        return self.gamma * attended + x

class EVFlowNet(nn.Module):
def __init__(self, args):
def __init__(self, args, no_batch_norm=False):
super(EVFlowNet,self).__init__()
self._args = args
self.no_batch_norm = no_batch_norm

self.encoder1 = general_conv2d(in_channels = 4, out_channels=_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)
self.encoder2 = general_conv2d(in_channels = _BASE_CHANNELS, out_channels=2*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)
self.encoder3 = general_conv2d(in_channels = 2*_BASE_CHANNELS, out_channels=4*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)
self.encoder4 = general_conv2d(in_channels = 4*_BASE_CHANNELS, out_channels=8*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)

self.resnet_block = nn.Sequential(*[build_resnet_block(8*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) for i in range(2)])
# Increase the number of ResNet blocks
self.resnet_block = nn.Sequential(*[build_resnet_block(8*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm) for i in range(4)])

# Add the attention mechanism
self.attention = SelfAttention(8*_BASE_CHANNELS)

self.decoder1 = upsample_conv2d_and_predict_flow(in_channels=16*_BASE_CHANNELS,
out_channels=4*_BASE_CHANNELS, do_batch_norm=not self._args.no_batch_norm)
Expand All @@ -30,40 +56,44 @@ def __init__(self, args):
out_channels=int(_BASE_CHANNELS/2), do_batch_norm=not self._args.no_batch_norm)

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
# Fetch the event volume
event_volume = inputs["event_volume"]

# encoder
skip_connections = {}
inputs = self.encoder1(inputs)
skip_connections['skip0'] = inputs.clone()
inputs = self.encoder2(inputs)
skip_connections['skip1'] = inputs.clone()
inputs = self.encoder3(inputs)
skip_connections['skip2'] = inputs.clone()
inputs = self.encoder4(inputs)
skip_connections['skip3'] = inputs.clone()
x = self.encoder1(event_volume)
skip_connections['skip0'] = x.clone()
x = self.encoder2(x)
skip_connections['skip1'] = x.clone()
x = self.encoder3(x)
skip_connections['skip2'] = x.clone()
x = self.encoder4(x)
skip_connections['skip3'] = x.clone()

# transition
inputs = self.resnet_block(inputs)
x = self.resnet_block(x)
x = self.attention(x) # Attention機構を適用

# decoder
flow_dict = {}
inputs = torch.cat([inputs, skip_connections['skip3']], dim=1)
inputs, flow = self.decoder1(inputs)
flow_dict['flow0'] = flow.clone()
flow_predictions = []
x = torch.cat([x, skip_connections['skip3']], dim=1)
x, flow = self.decoder1(x)
flow_predictions.append(flow)

inputs = torch.cat([inputs, skip_connections['skip2']], dim=1)
inputs, flow = self.decoder2(inputs)
flow_dict['flow1'] = flow.clone()
x = torch.cat([x, skip_connections['skip2']], dim=1)
x, flow = self.decoder2(x)
flow_predictions.append(flow)

inputs = torch.cat([inputs, skip_connections['skip1']], dim=1)
inputs, flow = self.decoder3(inputs)
flow_dict['flow2'] = flow.clone()
x = torch.cat([x, skip_connections['skip1']], dim=1)
x, flow = self.decoder3(x)
flow_predictions.append(flow)

inputs = torch.cat([inputs, skip_connections['skip0']], dim=1)
inputs, flow = self.decoder4(inputs)
flow_dict['flow3'] = flow.clone()
x = torch.cat([x, skip_connections['skip0']], dim=1)
x, flow = self.decoder4(x)
flow_predictions.append(flow)

return flow

return {"flow_predictions": flow_predictions}

# if __name__ == "__main__":
# from config import configs
Expand Down Expand Up @@ -93,4 +123,4 @@ def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
# a = time.time()
# (model(input_))
# b = time.time()
# print(b-a)
# print(b-a)