diff --git a/blackbirds/networks.py b/blackbirds/networks.py
index abbf593..e67356f 100644
--- a/blackbirds/networks.py
+++ b/blackbirds/networks.py
@@ -7,19 +7,19 @@ def count_pars(net):
 
 
 class MLP(nn.Module):
-    def __init__(self, input_dim=64, hidden_dims=[32, 32], output_dim=1, device="cpu"):
+    def __init__(self, input_dim=64, hidden_dims=[32, 32], output_dim=1):
         super().__init__()
         self.relu = nn.ReLU()
         self._layers = nn.Sequential(
-            nn.Linear(input_dim, hidden_dims[0], device=device), self.relu
+            nn.Linear(input_dim, hidden_dims[0]), self.relu
         )
         for i in range(len(hidden_dims) - 1):
             self._layers.append(
-                nn.Linear(hidden_dims[i], hidden_dims[i + 1], device=device)
+                nn.Linear(hidden_dims[i], hidden_dims[i + 1])
             )
             self._layers.append(self.relu)
-        self._layers.append(nn.Linear(hidden_dims[-1], output_dim, device=device))
+        self._layers.append(nn.Linear(hidden_dims[-1], output_dim))
 
     def forward(self, x):
         x = self._layers(x)
@@ -40,13 +40,12 @@ def __init__(
         final_ff=nn.Identity(),
         nonlinearity="tanh",
         flavour="gru",
-        device="cpu",
     ):
         super().__init__()
         if flavour == "gru":
             self._rnn = nn.GRU(
-                input_size, hidden_size, num_layers, batch_first=True, device=device
+                input_size, hidden_size, num_layers, batch_first=True
             )
         else:
             self._rnn = nn.RNN(
@@ -55,7 +54,6 @@ def __init__(
                 num_layers,
                 nonlinearity=nonlinearity,
                 batch_first=True,
-                device=device,
             )
         self._fff = final_ff
         self._rnn_n_pars = count_pars(self._rnn)
@@ -82,7 +80,6 @@ def __init__(
         conv_kernel_size=4,
         pool_kernel_size=2,
         final_ff=[32, 16],
-        device="cpu",
     ):
         assert len(final_ff) == 2
 
@@ -94,7 +91,7 @@ def __init__(
         # pool2_out_dim = conv2_out_dim - pool_kernel_size + 1
 
         self.conv1 = nn.Conv2d(
-            n_channels, hidden_layer_channels, conv_kernel_size, device=device
+            n_channels, hidden_layer_channels, conv_kernel_size
         )
         self.pool = nn.MaxPool2d(pool_kernel_size, pool_kernel_size)
         # self.conv2 = nn.Conv2d(hidden_layer_channels,
@@ -103,9 +100,8 @@ def __init__(
         self.fc1 = nn.Linear(
             hidden_layer_channels * pool1_out_dim**2,  # (pool2_out_dim)**2,
             final_ff[0],
-            device=device,
         )
-        self.fc2 = nn.Linear(final_ff[0], final_ff[1], device=device)
+        self.fc2 = nn.Linear(final_ff[0], final_ff[1])
         self.relu = nn.ReLU()
 
     def forward(self, x):
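
Note (not part of the patch): with the `device` keyword removed from the constructors, device placement moves to the call site via the standard PyTorch `nn.Module.to()` idiom. A minimal usage sketch, assuming the `MLP` shown above is importable from `blackbirds.networks`; the device-selection logic is illustrative, not from the patch:

    import torch
    from blackbirds.networks import MLP

    # Choose a device at the call site instead of threading it through __init__.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Construct on CPU, then move all parameters and buffers in one call.
    net = MLP(input_dim=64, hidden_dims=[32, 32], output_dim=1).to(device)

    x = torch.randn(8, 64, device=device)  # batch of 8 inputs of width input_dim
    y = net(x)                             # forward pass on the chosen device

Deferring placement to a single `.to(device)` is the conventional PyTorch pattern: it keeps device handling in one place rather than repeating `device=` on every `nn.Linear`, `nn.GRU`, and `nn.Conv2d` inside the module definitions.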