
Commit 081290f: networks for gpu
arnauqb committed Nov 16, 2023
1 parent 24a58c3 commit 081290f
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions blackbirds/networks.py
@@ -7,19 +7,19 @@ def count_pars(net):
 
 
 class MLP(nn.Module):
-    def __init__(self, input_dim=64, hidden_dims=[32, 32], output_dim=1, device="cpu"):
+    def __init__(self, input_dim=64, hidden_dims=[32, 32], output_dim=1):
         super().__init__()
 
         self.relu = nn.ReLU()
         self._layers = nn.Sequential(
-            nn.Linear(input_dim, hidden_dims[0], device=device), self.relu
+            nn.Linear(input_dim, hidden_dims[0]), self.relu
         )
         for i in range(len(hidden_dims) - 1):
             self._layers.append(
-                nn.Linear(hidden_dims[i], hidden_dims[i + 1], device=device)
+                nn.Linear(hidden_dims[i], hidden_dims[i + 1])
             )
             self._layers.append(self.relu)
-        self._layers.append(nn.Linear(hidden_dims[-1], output_dim, device=device))
+        self._layers.append(nn.Linear(hidden_dims[-1], output_dim))
 
     def forward(self, x):
         x = self._layers(x)
@@ -40,13 +40,12 @@ def __init__(
         final_ff=nn.Identity(),
         nonlinearity="tanh",
         flavour="gru",
-        device="cpu",
     ):
         super().__init__()
 
         if flavour == "gru":
             self._rnn = nn.GRU(
-                input_size, hidden_size, num_layers, batch_first=True, device=device
+                input_size, hidden_size, num_layers, batch_first=True
             )
         else:
             self._rnn = nn.RNN(
@@ -55,7 +54,6 @@ def __init__(
                 num_layers,
                 nonlinearity=nonlinearity,
                 batch_first=True,
-                device=device,
             )
         self._fff = final_ff
         self._rnn_n_pars = count_pars(self._rnn)
@@ -82,7 +80,6 @@ def __init__(
         conv_kernel_size=4,
         pool_kernel_size=2,
         final_ff=[32, 16],
-        device="cpu",
     ):
         assert len(final_ff) == 2
 
@@ -94,7 +91,7 @@ def __init__(
         # pool2_out_dim = conv2_out_dim - pool_kernel_size + 1
 
         self.conv1 = nn.Conv2d(
-            n_channels, hidden_layer_channels, conv_kernel_size, device=device
+            n_channels, hidden_layer_channels, conv_kernel_size
         )
         self.pool = nn.MaxPool2d(pool_kernel_size, pool_kernel_size)
         # self.conv2 = nn.Conv2d(hidden_layer_channels,
@@ -103,9 +100,8 @@ def __init__(
         self.fc1 = nn.Linear(
             hidden_layer_channels * pool1_out_dim**2,  # (pool2_out_dim)**2,
             final_ff[0],
-            device=device,
         )
-        self.fc2 = nn.Linear(final_ff[0], final_ff[1], device=device)
+        self.fc2 = nn.Linear(final_ff[0], final_ff[1])
         self.relu = nn.ReLU()
 
     def forward(self, x):
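
The commit strips the per-layer device= keyword from the MLP, the RNN wrapper, and the CNN constructors. The diff does not show any call sites, so the companion usage below is an assumption, not part of this commit: build the module on CPU and move it to an accelerator in one step with PyTorch's standard nn.Module.to, which relocates every registered parameter and buffer. A minimal sketch, reusing the MLP signature and module path visible in the diff:

    import torch

    from blackbirds.networks import MLP  # path taken from the changed file

    # Pick an accelerator if one is available; fall back to CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Constructor arguments match the post-commit signature; .to(device)
    # moves all of the module's parameters and buffers in one call.
    model = MLP(input_dim=64, hidden_dims=[32, 32], output_dim=1).to(device)

    x = torch.randn(8, 64, device=device)  # inputs must live on the same device
    y = model(x)                            # shape: (8, 1)

Deferring placement to .to() keeps the constructors device-agnostic, which is presumably what allowed the device arguments to be deleted wholesale here.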
