From 9b1a735ab84774c8c8714766f239cef06ec2d0bf Mon Sep 17 00:00:00 2001
From: kirill670 <51964569+kirill670@users.noreply.github.com>
Date: Thu, 3 Oct 2024 12:13:37 +0300
Subject: [PATCH] Update model.py

---
 lfm_torch/model.py | 226 +++++++++++++++++++++------------------------
 1 file changed, 103 insertions(+), 123 deletions(-)

diff --git a/lfm_torch/model.py b/lfm_torch/model.py
index fc16a2a..dacd2d3 100644
--- a/lfm_torch/model.py
+++ b/lfm_torch/model.py
@@ -1,168 +1,148 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from loguru import logger
+import math
 from typing import Optional, Tuple
+from collections import OrderedDict
 
 
 class AdaptiveLinear(nn.Module):
     """
     Adaptive Linear layer whose weight and bias adapt based on input.
+    Supports multiple adaptation methods.
     """
 
     def __init__(
-        self, in_features: int, out_features: int, adapt_dim: int
+        self,
+        in_features: int,
+        out_features: int,
+        adapt_dim: int,
+        adapt_method: str = "add",  # Adaptation method: 'add', 'multiply', or 'gate'
     ):
-        super(AdaptiveLinear, self).__init__()
+        super().__init__()
+        if in_features <= 0 or out_features <= 0 or adapt_dim <= 0:
+            raise ValueError("Dimensions must be positive integers.")
+        if adapt_method not in ["add", "multiply", "gate"]:
+            raise ValueError(f"Invalid adaptation method: {adapt_method}")
+
         self.in_features = in_features
         self.out_features = out_features
+        self.adapt_method = adapt_method
 
-        self.weight = nn.Parameter(
-            torch.randn(out_features, in_features)
-        )
-        self.bias = nn.Parameter(torch.randn(out_features))
+        self.weight = nn.Parameter(nn.init.kaiming_uniform_(torch.empty(out_features, in_features), a=math.sqrt(5)))  # Kaiming init via nn.init (not a Tensor method), as in nn.Linear
+        self.bias = nn.Parameter(torch.zeros(out_features))
 
-        # Linear transformation for adapting the weight based on input
         self.adapt = nn.Linear(adapt_dim, out_features * in_features)
 
-    def forward(
-        self, x: torch.Tensor, adapt_input: torch.Tensor
-    ) -> torch.Tensor:
-        adapt_weight = self.adapt(adapt_input).view(
-            self.out_features, self.in_features
-        )
-        weight = self.weight + adapt_weight
-        return F.linear(x, weight, self.bias)
+        if self.adapt_method == "gate":
+            self.gate = nn.Linear(adapt_dim, out_features * in_features)
 
+    def forward(self, x: torch.Tensor, adapt_input: torch.Tensor) -> torch.Tensor:
+        if x.shape[-1] != self.in_features:
+            raise ValueError(f"Input tensor has incorrect number of features. Expected {self.in_features}, got {x.shape[-1]}.")
+        adapt_weight = self.adapt(adapt_input).view(self.out_features, self.in_features)
 
-class TokenMixing(nn.Module):
-    """
-    Token mixing layer that performs token-wise interactions using adaptive linear layers.
-    Operates across the sequence dimension (sequence_length).
-    """
+        if self.adapt_method == "add":
+            weight = self.weight + adapt_weight
+        elif self.adapt_method == "multiply":
+            weight = self.weight * (adapt_weight + 1)
+        elif self.adapt_method == "gate":
+            gate = torch.sigmoid(self.gate(adapt_input)).view(self.out_features, self.in_features)
+            weight = self.weight * gate + adapt_weight * (1 - gate)
 
-    def __init__(self, token_dim: int, adapt_dim: int):
-        super(TokenMixing, self).__init__()
-        self.token_mixing = AdaptiveLinear(
-            token_dim, token_dim, adapt_dim
-        )
+        return F.linear(x, weight, self.bias)
 
-    def forward(
-        self, x: torch.Tensor, adapt_input: torch.Tensor
-    ) -> torch.Tensor:
-        # x: [batch_size, sequence_length, embedding_dim]
-        batch_size, seq_length, embed_dim = x.shape
-        x = x.view(
-            batch_size * seq_length, embed_dim
-        )  # Flatten sequence for linear transformation
-        x_mixed = self.token_mixing(x, adapt_input)
-        return x_mixed.view(batch_size, seq_length, embed_dim)
 
+class TokenMixing(nn.Module):
+    def __init__(self, token_dim: int, adapt_dim: int, dropout_rate: float = 0.1):
+        super().__init__()
+        self.norm = nn.LayerNorm(token_dim)
+        self.linear1 = nn.Linear(token_dim, token_dim)
+        self.linear2 = nn.Linear(token_dim, token_dim)
+        self.dropout = nn.Dropout(dropout_rate)
 
-class ChannelMixing(nn.Module):
-    """
-    Channel mixing layer that performs cross-channel interactions using adaptive linear layers.
-    Operates across the embedding dimension (embedding_dim).
-    """
+        self.adapt_linear = AdaptiveLinear(token_dim, token_dim, adapt_dim, adapt_method="add")
 
-    def __init__(self, channel_dim: int, adapt_dim: int):
-        super(ChannelMixing, self).__init__()
-        self.channel_mixing = AdaptiveLinear(
-            channel_dim, channel_dim, adapt_dim
-        )
 
-    def forward(
-        self, x: torch.Tensor, adapt_input: torch.Tensor
-    ) -> torch.Tensor:
-        # x: [batch_size, sequence_length, embedding_dim]
-        return self.channel_mixing(x, adapt_input)
+    def forward(self, x: torch.Tensor, adapt_input: torch.Tensor) -> torch.Tensor:
+        x = self.norm(x)
+        x = self.linear1(x)
+        x = self.adapt_linear(x, adapt_input)
+        x = F.gelu(x)
+        x = self.dropout(x)
+        x = self.linear2(x)
+        return x
 
 
-class MixtureOfExperts(nn.Module):
-    """
-    Mixture of Experts (MoE) module that dynamically selects experts based on input.
-    Operates after channel and token mixing.
-    """
+class ChannelMixing(nn.Module):
+    def __init__(self, channel_dim: int, adapt_dim: int, dropout_rate: float = 0.1):
+        super().__init__()
+        self.norm = nn.LayerNorm(channel_dim)
+        self.linear1 = nn.Linear(channel_dim, channel_dim)
+        self.linear2 = nn.Linear(channel_dim, channel_dim)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.adapt_linear = AdaptiveLinear(channel_dim, channel_dim, adapt_dim, adapt_method="add")
+
+    def forward(self, x: torch.Tensor, adapt_input: torch.Tensor) -> torch.Tensor:
+        x = self.norm(x)
+        x = self.linear1(x)
+        x = self.adapt_linear(x, adapt_input)
+        x = F.gelu(x)
+        x = self.dropout(x)
+        x = self.linear2(x)
+        return x
+
+
+class MoE(nn.Module):
+    def __init__(self, input_dim: int, expert_dim: int, num_experts: int, dropout_rate: float = 0.1):
+        super().__init__()
+        self.experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])
+        self.gate = nn.Linear(input_dim, num_experts)
+        self.dropout = nn.Dropout(dropout_rate)
 
-    def __init__(
-        self, expert_dim: int, num_experts: int, adapt_dim: int
-    ):
-        super(MixtureOfExperts, self).__init__()
-        self.experts = nn.ModuleList(
-            [
-                AdaptiveLinear(expert_dim, expert_dim, adapt_dim)
-                for _ in range(num_experts)
-            ]
-        )
-        self.gating = nn.Linear(adapt_dim, num_experts)
-
-    def forward(
-        self, x: torch.Tensor, adapt_input: torch.Tensor
-    ) -> torch.Tensor:
-        gate_scores = F.softmax(self.gating(adapt_input), dim=-1)
-        output = sum(
-            gate_scores[:, i].unsqueeze(1) * expert(x, adapt_input)
-            for i, expert in enumerate(self.experts)
-        )
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_weights = F.softmax(self.gate(x), dim=-1)
+        expert_outputs = [expert(x) for expert in self.experts]
+        output = torch.stack(expert_outputs, dim=-2)
+        output = torch.sum(output * gate_weights.unsqueeze(-1), dim=-2)
+        output = self.dropout(output)
         return output
 
 
 class LFModel(nn.Module):
-    """
-    Custom LF Model architecture combining token mixing, channel mixing, and MoE.
-    Accepts 3D input tensor: [batch_size, sequence_length, embedding_dim].
-    """
+    ACTIVATION_FUNCTIONS = OrderedDict([
+        ("relu", F.relu),
+        ("gelu", F.gelu),
+        ("swish", lambda x: x * torch.sigmoid(x))
+    ])
 
     def __init__(
         self,
+        input_dim: int,
         token_dim: int,
         channel_dim: int,
         expert_dim: int,
         adapt_dim: int,
         num_experts: int,
+        dropout_rate: float = 0.1,
+        adapt_method: str = "add",
+        activation_function: str = "relu"
     ):
-        super(LFModel, self).__init__()
-        self.featurizer = nn.Linear(token_dim, adapt_dim)
-        self.token_mixer = TokenMixing(token_dim, adapt_dim)
-        self.channel_mixer = ChannelMixing(channel_dim, adapt_dim)
-        self.moe = MixtureOfExperts(
-            expert_dim, num_experts, adapt_dim
-        )
-        self.output_layer = nn.Linear(expert_dim, token_dim)
+        super().__init__()
+        if activation_function not in self.ACTIVATION_FUNCTIONS:
+            raise ValueError(f"Invalid activation function: {activation_function}")
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        logger.info("Input shape: {}", x.shape)
-
-        # Featurization stage
-        batch_size, seq_length, embed_dim = x.shape
-        adapt_input = self.featurizer(
-            x.mean(dim=1)
-        )  # Aggregate across sequence for adaptation
-        logger.info(
-            "Featurization complete. Shape: {}", adapt_input.shape
-        )
-
-        # Token Mixing
-        token_mixed = self.token_mixer(x, adapt_input)
-        logger.info(
-            "Token mixing complete. Shape: {}", token_mixed.shape
-        )
-
-        # Channel Mixing
-        channel_mixed = self.channel_mixer(token_mixed, adapt_input)
-        logger.info(
-            "Channel mixing complete. Shape: {}", channel_mixed.shape
Shape: {}", channel_mixed.shape - ) - - # Mixture of Experts - expert_output = self.moe(channel_mixed, adapt_input) - logger.info( - "Mixture of Experts complete. Shape: {}", - expert_output.shape, - ) - - # Final Output - output = self.output_layer(expert_output) - logger.info("Output shape: {}", output.shape) - return output + self.activation_function = self.ACTIVATION_FUNCTIONS[activation_function] + self.token_mixing = TokenMixing(token_dim, adapt_dim, dropout_rate) + self.channel_mixing = ChannelMixing(channel_dim, adapt_dim, dropout_rate) + self.moe = MoE(input_dim, expert_dim, num_experts, dropout_rate) + self.output_layer = nn.Linear(expert_dim, input_dim) + + + def forward(self, x: torch.Tensor, adapt_input: torch.Tensor) -> torch.Tensor: + x = self.token_mixing(x, adapt_input) + x = self.channel_mixing(x, adapt_input) + x = self.moe(x) + output = self.output_layer(x) + return output