Update model.py #7

Open · wants to merge 1 commit into base: main
226 changes: 103 additions & 123 deletions lfm_torch/model.py
@@ -1,168 +1,148 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from loguru import logger
+import math
+from typing import Optional, Tuple
+from collections import OrderedDict


 class AdaptiveLinear(nn.Module):
     """
     Adaptive Linear layer whose weight and bias adapt based on input.
+    Supports multiple adaptation methods.
     """
 
     def __init__(
-        self, in_features: int, out_features: int, adapt_dim: int
+        self,
+        in_features: int,
+        out_features: int,
+        adapt_dim: int,
+        adapt_method: str = "add",  # Adaptation method: 'add', 'multiply', 'gate'
     ):
-        super(AdaptiveLinear, self).__init__()
+        super().__init__()
+        if in_features <= 0 or out_features <= 0 or adapt_dim <= 0:
+            raise ValueError("Dimensions must be positive integers.")
+        if adapt_method not in ["add", "multiply", "gate"]:
+            raise ValueError(f"Invalid adaptation method: {adapt_method}")
+
         self.in_features = in_features
         self.out_features = out_features
+        self.adapt_method = adapt_method
 
-        self.weight = nn.Parameter(
-            torch.randn(out_features, in_features)
-        )
-        self.bias = nn.Parameter(torch.randn(out_features))
+        self.weight = nn.Parameter(
+            nn.init.kaiming_uniform_(torch.empty(out_features, in_features), a=math.sqrt(5))
+        )  # Kaiming init
+        self.bias = nn.Parameter(torch.zeros(out_features))
 
         # Linear transformation for adapting the weight based on input
         self.adapt = nn.Linear(adapt_dim, out_features * in_features)
 
-    def forward(
-        self, x: torch.Tensor, adapt_input: torch.Tensor
-    ) -> torch.Tensor:
-        adapt_weight = self.adapt(adapt_input).view(
-            self.out_features, self.in_features
-        )
-        weight = self.weight + adapt_weight
-        return F.linear(x, weight, self.bias)
+        if self.adapt_method == "gate":
+            self.gate = nn.Linear(adapt_dim, out_features * in_features)
+
+    def forward(self, x: torch.Tensor, adapt_input: torch.Tensor) -> torch.Tensor:
+        if x.shape[-1] != self.in_features:
+            raise ValueError(
+                f"Input tensor has incorrect number of features. "
+                f"Expected {self.in_features}, got {x.shape[-1]}."
+            )
+        adapt_weight = self.adapt(adapt_input).view(self.out_features, self.in_features)
+
+        if self.adapt_method == "add":
+            weight = self.weight + adapt_weight
+        elif self.adapt_method == "multiply":
+            weight = self.weight * (adapt_weight + 1)
+        elif self.adapt_method == "gate":
+            gate = torch.sigmoid(self.gate(adapt_input)).view(self.out_features, self.in_features)
+            weight = self.weight * gate + adapt_weight * (1 - gate)
+
+        return F.linear(x, weight, self.bias)
 
 
-class TokenMixing(nn.Module):
-    """
-    Token mixing layer that performs token-wise interactions using adaptive linear layers.
-    Operates across the sequence dimension (sequence_length).
-    """
-
-    def __init__(self, token_dim: int, adapt_dim: int):
-        super(TokenMixing, self).__init__()
-        self.token_mixing = AdaptiveLinear(
-            token_dim, token_dim, adapt_dim
-        )
-
-    def forward(
-        self, x: torch.Tensor, adapt_input: torch.Tensor
-    ) -> torch.Tensor:
-        # x: [batch_size, sequence_length, embedding_dim]
-        batch_size, seq_length, embed_dim = x.shape
-        x = x.view(
-            batch_size * seq_length, embed_dim
-        )  # Flatten sequence for linear transformation
-        x_mixed = self.token_mixing(x, adapt_input)
-        return x_mixed.view(batch_size, seq_length, embed_dim)
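For illustration, a self-contained sketch of the arithmetic behind the three `adapt_method` options in the updated `AdaptiveLinear` above. The tensors are stand-ins for `self.weight`, the reshaped output of `self.adapt`, and the sigmoid gate; the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_features, in_features = 4, 3

base_weight = torch.randn(out_features, in_features)           # plays the role of self.weight
adapt_weight = torch.randn(out_features, in_features)          # self.adapt(adapt_input).view(out, in)
gate = torch.sigmoid(torch.randn(out_features, in_features))   # sigmoid(self.gate(adapt_input)).view(out, in)

# Effective weights produced by each adaptation method in forward():
w_add = base_weight + adapt_weight                        # adapt_method="add"
w_mul = base_weight * (adapt_weight + 1)                  # adapt_method="multiply"
w_gate = base_weight * gate + adapt_weight * (1 - gate)   # adapt_method="gate"

x = torch.randn(2, in_features)
for name, w in [("add", w_add), ("multiply", w_mul), ("gate", w_gate)]:
    print(name, F.linear(x, w, torch.zeros(out_features)).shape)  # each: torch.Size([2, 4])
```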

+class TokenMixing(nn.Module):
+    def __init__(self, token_dim: int, adapt_dim: int, dropout_rate: float = 0.1):
+        super().__init__()
+        self.norm = nn.LayerNorm(token_dim)
+        self.linear1 = nn.Linear(token_dim, token_dim)
+        self.linear2 = nn.Linear(token_dim, token_dim)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.adapt_linear = AdaptiveLinear(token_dim, token_dim, adapt_dim, adapt_method="add")
+
+    def forward(self, x: torch.Tensor, adapt_input: torch.Tensor) -> torch.Tensor:
+        x = self.norm(x)
+        x = self.linear1(x)
+        x = self.adapt_linear(x, adapt_input)
+        x = F.gelu(x)
+        x = self.dropout(x)
+        x = self.linear2(x)
+        return x
 
 
-class ChannelMixing(nn.Module):
-    """
-    Channel mixing layer that performs cross-channel interactions using adaptive linear layers.
-    Operates across the embedding dimension (embedding_dim).
-    """
-
-    def __init__(self, channel_dim: int, adapt_dim: int):
-        super(ChannelMixing, self).__init__()
-        self.channel_mixing = AdaptiveLinear(
-            channel_dim, channel_dim, adapt_dim
-        )
-
-    def forward(
-        self, x: torch.Tensor, adapt_input: torch.Tensor
-    ) -> torch.Tensor:
-        # x: [batch_size, sequence_length, embedding_dim]
-        return self.channel_mixing(x, adapt_input)
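A quick shape check of the block structure used by the rewritten `TokenMixing` above (the rewritten `ChannelMixing` below follows the same pattern). Plain `nn.Linear` layers stand in for the adaptive step; the point is that every stage acts on the last dimension, so the `[batch, seq, dim]` layout is preserved:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, batch, seq = 16, 2, 5
norm, linear1, linear2 = nn.LayerNorm(dim), nn.Linear(dim, dim), nn.Linear(dim, dim)
dropout = nn.Dropout(0.1)

x = torch.randn(batch, seq, dim)
h = linear2(dropout(F.gelu(linear1(norm(x)))))  # adaptive step omitted in this sketch
print(h.shape)  # torch.Size([2, 5, 16])
```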


+class ChannelMixing(nn.Module):
+    def __init__(self, channel_dim: int, adapt_dim: int, dropout_rate: float = 0.1):
+        super().__init__()
+        self.norm = nn.LayerNorm(channel_dim)
+        self.linear1 = nn.Linear(channel_dim, channel_dim)
+        self.linear2 = nn.Linear(channel_dim, channel_dim)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.adapt_linear = AdaptiveLinear(channel_dim, channel_dim, adapt_dim, adapt_method="add")
+
+    def forward(self, x: torch.Tensor, adapt_input: torch.Tensor) -> torch.Tensor:
+        x = self.norm(x)
+        x = self.linear1(x)
+        x = self.adapt_linear(x, adapt_input)
+        x = F.gelu(x)
+        x = self.dropout(x)
+        x = self.linear2(x)
+        return x
 
 
-class MixtureOfExperts(nn.Module):
-    """
-    Mixture of Experts (MoE) module that dynamically selects experts based on input.
-    Operates after channel and token mixing.
-    """
-
-    def __init__(
-        self, expert_dim: int, num_experts: int, adapt_dim: int
-    ):
-        super(MixtureOfExperts, self).__init__()
-        self.experts = nn.ModuleList(
-            [
-                AdaptiveLinear(expert_dim, expert_dim, adapt_dim)
-                for _ in range(num_experts)
-            ]
-        )
-        self.gating = nn.Linear(adapt_dim, num_experts)
-
-    def forward(
-        self, x: torch.Tensor, adapt_input: torch.Tensor
-    ) -> torch.Tensor:
-        gate_scores = F.softmax(self.gating(adapt_input), dim=-1)
-        output = sum(
-            gate_scores[:, i].unsqueeze(1) * expert(x, adapt_input)
-            for i, expert in enumerate(self.experts)
-        )
-        return output
+class MoE(nn.Module):
+    def __init__(self, input_dim: int, expert_dim: int, num_experts: int, dropout_rate: float = 0.1):
+        super().__init__()
+        self.experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])
+        self.gate = nn.Linear(input_dim, num_experts)
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_weights = F.softmax(self.gate(x), dim=-1)
+        expert_outputs = [expert(x) for expert in self.experts]
+        output = torch.stack(expert_outputs, dim=-2)
+        output = torch.sum(output * gate_weights.unsqueeze(-1), dim=-2)
+        output = self.dropout(output)
+        return output
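The dense gating in the new `MoE` above, traced with concrete shapes. This is a minimal stand-alone sketch that mirrors the forward pass (arbitrary sizes, dropout omitted):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
input_dim, expert_dim, num_experts = 6, 4, 3

experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])
gate = nn.Linear(input_dim, num_experts)

x = torch.randn(2, 5, input_dim)                              # [batch, seq, input_dim]
gate_weights = F.softmax(gate(x), dim=-1)                     # [batch, seq, num_experts]
stacked = torch.stack([e(x) for e in experts], dim=-2)        # [batch, seq, num_experts, expert_dim]
output = (stacked * gate_weights.unsqueeze(-1)).sum(dim=-2)   # [batch, seq, expert_dim]
print(output.shape)  # torch.Size([2, 5, 4])
```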


 class LFModel(nn.Module):
     """
     Custom LF Model architecture combining token mixing, channel mixing, and MoE.
     Accepts 3D input tensor: [batch_size, sequence_length, embedding_dim].
     """
+
+    ACTIVATION_FUNCTIONS = OrderedDict([
+        ("relu", F.relu),
+        ("gelu", F.gelu),
+        ("swish", lambda x: x * torch.sigmoid(x))
+    ])
 
     def __init__(
         self,
+        input_dim: int,
         token_dim: int,
         channel_dim: int,
         expert_dim: int,
         adapt_dim: int,
         num_experts: int,
+        dropout_rate: float = 0.1,
+        adapt_method: str = "add",
+        activation_function: str = "relu"
     ):
-        super(LFModel, self).__init__()
-        self.featurizer = nn.Linear(token_dim, adapt_dim)
-        self.token_mixer = TokenMixing(token_dim, adapt_dim)
-        self.channel_mixer = ChannelMixing(channel_dim, adapt_dim)
-        self.moe = MixtureOfExperts(
-            expert_dim, num_experts, adapt_dim
-        )
-        self.output_layer = nn.Linear(expert_dim, token_dim)
+        super().__init__()
+        if activation_function not in self.ACTIVATION_FUNCTIONS:
+            raise ValueError(f"Invalid activation function: {activation_function}")
+
+        self.activation_function = self.ACTIVATION_FUNCTIONS[activation_function]
+
+        self.token_mixing = TokenMixing(token_dim, adapt_dim, dropout_rate)
+        self.channel_mixing = ChannelMixing(channel_dim, adapt_dim, dropout_rate)
+        self.moe = MoE(input_dim, expert_dim, num_experts, dropout_rate)
+        self.output_layer = nn.Linear(expert_dim, input_dim)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        logger.info("Input shape: {}", x.shape)
-
-        # Featurization stage
-        batch_size, seq_length, embed_dim = x.shape
-        adapt_input = self.featurizer(
-            x.mean(dim=1)
-        )  # Aggregate across sequence for adaptation
-        logger.info(
-            "Featurization complete. Shape: {}", adapt_input.shape
-        )
-
-        # Token Mixing
-        token_mixed = self.token_mixer(x, adapt_input)
-        logger.info(
-            "Token mixing complete. Shape: {}", token_mixed.shape
-        )
-
-        # Channel Mixing
-        channel_mixed = self.channel_mixer(token_mixed, adapt_input)
-        logger.info(
-            "Channel mixing complete. Shape: {}", channel_mixed.shape
-        )
-
-        # Mixture of Experts
-        expert_output = self.moe(channel_mixed, adapt_input)
-        logger.info(
-            "Mixture of Experts complete. Shape: {}",
-            expert_output.shape,
-        )
-
-        # Final Output
-        output = self.output_layer(expert_output)
-        logger.info("Output shape: {}", output.shape)
-        return output
+    def forward(self, x: torch.Tensor, adapt_input: torch.Tensor) -> torch.Tensor:
+        x = self.token_mixing(x, adapt_input)
+        x = self.channel_mixing(x, adapt_input)
+        x = self.moe(x)
+        output = self.output_layer(x)
+        return output
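For orientation, one possible way to call the rewritten `LFModel` with its new `forward(x, adapt_input)` signature. The sizes are made up, and the snippet assumes the updated `lfm_torch/model.py` from this diff is on the import path:

```python
import torch
from lfm_torch.model import LFModel  # assumes the updated module from this PR is importable

# Illustrative sizes. The mixing blocks and MoE are applied to the same tensor in
# sequence, so input_dim, token_dim and channel_dim are kept equal here.
input_dim = token_dim = channel_dim = 32
expert_dim, adapt_dim, num_experts = 64, 16, 4

model = LFModel(
    input_dim=input_dim,
    token_dim=token_dim,
    channel_dim=channel_dim,
    expert_dim=expert_dim,
    adapt_dim=adapt_dim,
    num_experts=num_experts,
    adapt_method="add",
    activation_function="gelu",
)

x = torch.randn(2, 10, input_dim)     # [batch, seq, dim]
adapt_input = torch.randn(adapt_dim)  # one shared adaptation vector: AdaptiveLinear's
                                      # .view(out_features, in_features) expects exactly
                                      # out_features * in_features elements from self.adapt
out = model(x, adapt_input)
print(out.shape)                      # torch.Size([2, 10, 32])
```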