Merge pull request #93 from microsoft/1MoE
PEER MoE
oleksost committed Aug 23, 2024
2 parents 5b1a07a + 12ec534 commit 0be1a67
Showing 16 changed files with 571 additions and 92 deletions.
7 changes: 7 additions & 0 deletions mttl/config.py
@@ -539,6 +539,13 @@ class MoEExpertConfig(MultiExpertConfig):
moe_ent_free_bits: float = 0.0
moe_num_experts: int = 8
init_from_scratch: bool = True
pk_use_batchnorm: bool = True
down_proj_layer: str = (
"fc1" # this is for the PEER container, it signals the names of the down and up projecting layers
)
up_proj_layer: str = (
"fc2" # this is for the PEER container, it signals the names of the down and up projecting layers
)


@dataclass
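
As an aside (not part of this commit), a minimal sketch of how the new PEER-related fields might be set for a GPT-2-style MLP block, whose projections are named c_fc and c_proj rather than fc1/fc2; other MoEExpertConfig fields are left at their defaults and the values here are purely illustrative:

config = MoEExpertConfig(
    moe_num_experts=1024,     # PEER requires a perfect square (see the container further down)
    down_proj_layer="c_fc",   # name of the down-projecting layer in the target MLP block
    up_proj_layer="c_proj",   # name of the up-projecting layer in the target MLP block
    pk_use_batchnorm=True,
)
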
116 changes: 82 additions & 34 deletions mttl/models/containers/__init__.py
Expand Up @@ -8,6 +8,7 @@
CoalescedLoRAExpertContainer,
LoRAExpertContainer,
)
from mttl.models.containers.peer_container import PEERMLPContainer
from mttl.models.containers.selectors.base import (
Selector,
SelectorConfig,
@@ -52,6 +53,8 @@ def get_container_class(modifier: str):
"COALESCED_LORA_CONTAINER is not set to 1, but still using it for SkilledLoRA"
)
return CoalescedLoRAExpertContainer
elif modifier == "peer":
return PEERMLPContainer
elif modifier == "kv_adapter":
return KVExpertContainer
else:
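
For illustration only (not part of the diff): with the new branch above in place, the "peer" modifier name resolves to the new container class.

from mttl.models.containers import get_container_class
from mttl.models.containers.peer_container import PEERMLPContainer

assert get_container_class("peer") is PEERMLPContainer
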
@@ -138,6 +141,16 @@ def replace_selector_for_container(
"""
expert_containers = []
for _, module in dict(transformer.named_modules()).items():
        if isinstance(module, ExpertContainer):
            # check if the container holds the same modifier type, e.g. PEERConfig --> "peer"
            for supports_config in module.__supports_configs__:
                container_modifier = Modifier.get_name_by_config_class(supports_config)
                # skip configs whose modifier does not match; otherwise record the container
                if container_modifier != modifier_name:
                    continue
                expert_containers.append(module)

for _, layer in dict(module.named_children()).items():
if isinstance(layer, ExpertContainer):
# check if the container holds the same modifier type, e.g. LoRAConfig --> "lora"
@@ -255,6 +268,43 @@ def get_modules_to_modify_trie(transformer):
yield m_name, module


def create_modif_regex(modify_modules, modify_layers=None):
    """
    Combine modify_modules and modify_layers into a single regex, to keep add_expert_to_transformer slim.
    """
    is_set = lambda x: x is not None and x != ""

    if not is_set(modify_modules) and not is_set(modify_layers):
        raise ValueError(
            "Neither modify_modules nor modify_layers is set, will not modify anything"
        )

    if is_set(modify_modules) and not is_set(modify_layers):
        return modify_modules
    if not is_set(modify_modules) and is_set(modify_layers):
        return modify_layers

    # keep backward compatibility: combine every module pattern with every layer pattern
    modules = modify_modules.split("|")
    layers = modify_layers.split("|")
    parts = []
    for m in modules:
        for l in layers:
            if m == ".*":
                # drop a redundant wildcard from the layer pattern when the module pattern is already ".*"
                l = l.replace(".*", "")
            parts.append(f"{m}\\.{l}")
    return "|".join(parts)


def match_modules_to_modify(transformer, modify_modules):
"""
Match modules in the transformer model based on the modify_modules regex
"""
for m_name, module in dict(transformer.named_modules()).items():
if re.fullmatch(modify_modules, m_name):
yield m_name, module
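
An illustrative sketch (not part of the commit) of how the two helpers compose, assuming a hypothetical model whose MLP projections are registered as fc1/fc2 under modules named like "transformer.h.0.mlp":

pattern = create_modif_regex(".*mlp", "fc1|fc2")
# pattern == ".*mlp\.fc1|.*mlp\.fc2"

for name, module in match_modules_to_modify(model, pattern):
    print(name)  # e.g. "transformer.h.0.mlp.fc1", "transformer.h.0.mlp.fc2", ...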


def add_expert_to_transformer(
transformer,
expert: Expert,
@@ -298,41 +348,39 @@ def add_expert_to_transformer(
added_layers = []
added_containers = []

for m_name, module in get_modules_to_modify_trie(transformer):
if re.fullmatch(expert_config.modify_modules, m_name):
for c_name, layer in dict(module.named_children()).items():
if re.fullmatch(expert_config.modify_layers, c_name):
total_layers += 1
layer_name = f"{m_name}.{c_name}"

if not isinstance(layer, ExpertContainer):
CONTAINER_CLASS = get_container_class(model_modifier)
expert_container = CONTAINER_CLASS(
expert_config,
layer,
lora_merge_after=(
selector_config.lora_merge_after
if selector_config
else False
),
)
expert_container.__layer_name__ = layer_name
setattr(
module,
c_name,
expert_container,
)
added_containers.append(expert_container)
else:
expert_container = layer

added_layers.append(expert_container.__layer_name__)
expert_container.add_expert(
expert,
action=action,
is_default=is_default,
)
modify_modules = create_modif_regex(
expert_config.modify_modules, expert_config.modify_layers
)
for m_name, module in match_modules_to_modify(transformer, modify_modules):
        # the layer names are already folded into the regex, so modify the matched module directly
        total_layers += 1
        module_name = m_name

if not isinstance(module, ExpertContainer):
CONTAINER_CLASS = get_container_class(model_modifier)
expert_container = CONTAINER_CLASS(
expert_config,
module,
)
expert_container.__layer_name__ = module_name

parent_name, child_name = m_name.rsplit(".", 1)
parent_module = dict(transformer.named_modules())[parent_name]
setattr(
parent_module,
child_name,
expert_container,
)
added_containers.append(expert_container)
else:
expert_container = module

added_layers.append(expert_container.__layer_name__)
expert_container.add_expert(
expert,
action=action,
is_default=is_default,
)
### PARAM TYING ###
# Note: because experts are added into expert containers
# instead of parameter names being e.g. model.layers.4.self_attn.q_proj.lora_a,
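
The main structural change in add_expert_to_transformer above is that a matched module is now swapped in place through its parent (looked up by name), rather than via a nested module/child loop. A self-contained sketch of that replacement pattern, with nn.Identity standing in for the expert container:

import torch.nn as nn

model = nn.Sequential()
model.add_module("mlp", nn.Sequential())
model.mlp.add_module("fc1", nn.Linear(8, 32))

name = "mlp.fc1"                                    # full name, as yielded by named_modules()
parent_name, child_name = name.rsplit(".", 1)
parent = dict(model.named_modules())[parent_name]   # fetch the parent module by name
setattr(parent, child_name, nn.Identity())          # replace the child in place

print(model.mlp.fc1)  # Identity()
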
5 changes: 5 additions & 0 deletions mttl/models/containers/base.py
@@ -22,6 +22,11 @@
from mttl.models.modifiers.modify_model import get_modifier_name


class ContainerFullException(Exception):
def __init__(self):
super().__init__("Container is full. Cannot add more experts.")


class Container(abc.ABC):
@abc.abstractmethod
def __getitem__(self, key):
121 changes: 121 additions & 0 deletions mttl/models/containers/peer_container.py
@@ -0,0 +1,121 @@
from types import SimpleNamespace

import torch
from typing import Union
from torch import Tensor, nn

from mttl.logging import warn_once
from mttl.models.containers.base import ContainerFullException, ExpertContainer
from mttl.models.containers.selectors.product_key import PKSelectorConfig, PKSSelector
from mttl.models.containers.selectors.selector_output import (
BatchExpertsAndWeightsSelectorOutput,
BatchExpertsSelectorOutput,
BatchSequenceExpertsAndWeightsSelectorOutput,
ExpertsAndWeightsSelectorOutput,
MultiheadBatchSequenceExpertsAndWeightsSelectorOutput,
SelectorOutput,
)
from mttl.models.library.expert import Expert
from mttl.models.modifiers.base import (
MergeableModifierMixin,
Modifier,
ModifierConfig,
ModifyMixin,
)
from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
from mttl.models.modifiers.modify_model import get_modifier_name

# diff architectures name those layers differently
DOWN_NAMES = ["fc1", "c_fc"]
UP_NAMES = ["fc2", "c_proj"]


class PEERMLPContainer(ExpertContainer):
"""
    PEER layer from "Mixture of A Million Experts" (https://arxiv.org/pdf/2407.04153).
    Right now it assumes it receives an MLP block, i.e. a module with attributes such as fc1 and fc2.
    It upcycles the base model, but for now the experts are initialized randomly.
"""

__supports_configs__ = [PEERConfig]

def __init__(
self,
config: PEERConfig,
module,
selector_config: PKSelectorConfig = None,
**kwargs,
):
super().__init__(config, module)
self.num_experts = 0
        down_names = DOWN_NAMES + [
            config.down_proj_layer
        ]  # candidate names of the down-projection layer in the MLP block
        up_names = UP_NAMES + [
            config.up_proj_layer
        ]  # candidate names of the up-projection layer; needed to infer the MLP dimensions

        assert any(
            hasattr(module, name) for name in down_names + up_names
        ), "Module must expose down/up projection layers (e.g. fc1 and fc2); this container only applies to MLP blocks"
n_idx = [i for i, name in enumerate(down_names) if hasattr(module, name)][0]

self.activation = module.act
self.input_dim = getattr(module, down_names[n_idx]).in_features
self.output_dim = getattr(module, up_names[n_idx]).out_features
if selector_config:
self.selector = PKSSelector(selector_config, in_d=self.input_dim)
        # to enable selector instantiation without having to carry the original module's weights
self.dtype = next(self.layer.parameters()).dtype

self.layer = nn.Identity()
self.layer.in_features = self.input_dim
self.experts = PEERModifier(config)

def initialize_experts(self, expert_config: PEERConfig) -> None:
self.num_experts = expert_config.moe_num_experts
assert (
self.num_experts**0.5
).is_integer(), "Number of experts must be a square number"
self.peer_weight_down_embed = nn.Embedding(
num_embeddings=self.num_experts,
embedding_dim=self.input_dim,
dtype=self.dtype,
)
self.peer_weight_up_embed = nn.Embedding(
num_embeddings=self.num_experts,
embedding_dim=self.output_dim,
dtype=self.dtype,
)

def forward(self, input, **kwargs):
routing: MultiheadBatchSequenceExpertsAndWeightsSelectorOutput = self.selector(
input
)
indices, scores = (
routing.experts,
routing.weights,
) # both shape b, s, heads, experts

w_down = self.peer_weight_down_embed(indices) # b, s, heads, experts, input_dim
w_up = self.peer_weight_up_embed(indices) # b, s, heads, experts, output_dim

x = torch.einsum("bsd,bshed->bshe", input, w_down) # b, s, heads, experts
x = self.activation(x)
x *= scores
x = torch.einsum("bshe,bshed->bsd", x, w_up)
return x

def add_expert(self, expert: Expert, **kwargs) -> None:
return self.on_add_expert(expert, **kwargs)

def on_add_expert(self, expert: Expert, **kwargs) -> None:
expert_config: PEERConfig = expert.expert_config
if self.num_experts == expert_config.moe_num_experts:
raise ContainerFullException()
self.initialize_experts(expert_config)

def __getitem__(self, key):
pass
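
To make the retrieval math in forward concrete, here is a standalone shape check of the two einsums (not part of the commit; illustrative sizes, GELU assumed as the activation, and random indices standing in for the product-key selector, which is also why moe_num_experts must be a perfect square: product keys index experts by the Cartesian product of two sub-key sets):

import torch
from torch import nn

b, s, d_in, d_out = 2, 4, 16, 32        # batch, sequence, MLP input/output dims
heads, k, num_experts = 8, 4, 64        # 64 == 8**2, a perfect square

x = torch.randn(b, s, d_in)
indices = torch.randint(0, num_experts, (b, s, heads, k))    # top-k experts per head
scores = torch.softmax(torch.randn(b, s, heads, k), dim=-1)  # their routing weights

down = nn.Embedding(num_experts, d_in)   # plays the role of peer_weight_down_embed
up = nn.Embedding(num_experts, d_out)    # plays the role of peer_weight_up_embed

w_down = down(indices)                   # (b, s, heads, k, d_in)
w_up = up(indices)                       # (b, s, heads, k, d_out)

h = torch.einsum("bsd,bshed->bshe", x, w_down)  # per-expert scalar pre-activations
h = torch.nn.functional.gelu(h) * scores        # activate, then weight by the router scores
y = torch.einsum("bshe,bshed->bsd", h, w_up)    # sum over heads and retrieved experts

print(y.shape)  # torch.Size([2, 4, 32])

Each retrieved expert is effectively a single hidden neuron (one d_in-dimensional down vector and one d_out-dimensional up vector), which is what lets PEER scale to very large expert counts.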
1 change: 1 addition & 0 deletions mttl/models/containers/selectors/__init__.py
@@ -29,6 +29,7 @@
PolySelector,
PolySelectorConfig,
)
from mttl.models.containers.selectors.product_key import PKSelectorConfig, PKSSelector
from mttl.models.containers.selectors.selector_output import (
BatchExpertsAndWeightsSelectorOutput,
BatchExpertsSelectorOutput,