Parameter input rotary-freq #263

Open · wants to merge 6 commits into base: main
6 changes: 4 additions & 2 deletions open_lm/model.py
@@ -97,15 +97,16 @@ class Params:
moe_top_k: int = 2
moe_freq: int = 0
positional_embedding_type: str = "rotary"
rotary_freq: float = 10000
ffn_type: str = "swiglu"


def get_pos_embed(args: Params):
head_dim = args.dim // args.n_heads
if args.positional_embedding_type == "rotary":
return RotaryWithCast(head_dim, args.seq_len)
return RotaryWithCast(head_dim, args.seq_len, args.rotary_freq)
elif args.positional_embedding_type == "llama_rotary":
return LLaMARotaryWithCast(head_dim, args.n_heads, args.seq_len)
return LLaMARotaryWithCast(head_dim, args.n_heads, args.seq_len, args.rotary_freq)
elif args.positional_embedding_type == "head_rotary":
return HeadRotaryWithCast(head_dim, args.seq_len)
elif args.positional_embedding_type == "none":
@@ -461,6 +462,7 @@ def create_params(args):
),
apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
rotary_freq=cfg.get("rotary_freq", args.rotary_freq),
ffn_type=cfg.get("ffn_type", args.ffn_type),
moe_num_experts=cfg.get("moe_num_experts", args.moe_num_experts),
moe_loss_weight=cfg.get("moe_loss_weight", args.moe_loss_weight),
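For reference, a minimal sketch (not part of this diff) of how the new field reaches the positional embedding; it assumes the remaining `Params` fields keep their defaults, and the concrete numbers are illustrative only:

```python
from open_lm.model import Params, get_pos_embed

# Hypothetical values; only positional_embedding_type and rotary_freq matter here.
params = Params(positional_embedding_type="rotary", rotary_freq=500000.0)

# With this PR, get_pos_embed forwards params.rotary_freq, i.e. the call is
# equivalent to RotaryWithCast(params.dim // params.n_heads, params.seq_len, 500000.0).
pos_embed = get_pos_embed(params)
```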
6 changes: 6 additions & 0 deletions open_lm/params.py
@@ -58,6 +58,12 @@ def add_model_args(parser):
default="rotary",
help="Type of positional embedding to use. This might be overridden by the model config.",
)
parser.add_argument(
"--rotary-freq",
type=float,
default=10000,
help="Frequency for rotary positional embeddings. This might be overridden by the model config.",
)
parser.add_argument(
"--moe-freq",
type=int,
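A quick sanity check for the new flag (an illustrative snippet, assuming `add_model_args` can populate a bare parser and that none of the model arguments are required):

```python
import argparse

from open_lm.params import add_model_args

parser = argparse.ArgumentParser()
add_model_args(parser)

# argparse maps the dashed flag to the rotary_freq attribute and applies type=float.
args = parser.parse_args(["--rotary-freq", "500000"])
assert args.rotary_freq == 500000.0
```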
3 changes: 2 additions & 1 deletion open_lm/positional_embedding/llama_rotary.py
@@ -112,14 +112,15 @@ class LLaMARotaryEmbedding(torch.nn.Module):
(it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
"""

def __init__(self, head_dim: int, num_heads: int, seq_len: int, *_, **__):
def __init__(self, head_dim: int, num_heads: int, seq_len: int, frequency: float = 10000, *_, **__):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
self.freqs_cis = precompute_freqs_cis(
# Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
head_dim,
seq_len * 2,
theta=frequency,
)

def reset_parameters(self):
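For context, `theta` sets the base of the geometric frequency schedule used by rotary embeddings. Below is a self-contained sketch of the standard LLaMA-style computation that `precompute_freqs_cis` performs; it mirrors the usual reference implementation and is not a copy of open_lm's exact code:

```python
import torch

def precompute_freqs_cis_sketch(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Larger theta -> slower rotation per position, which is what long-context variants tune.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)                       # (end, dim/2) rotation angles
    return torch.polar(torch.ones_like(angles), angles)  # complex rotations e^{i*angle}

# Different base frequencies yield different rotation tables.
a = precompute_freqs_cis_sketch(64, 256, theta=10000.0)
b = precompute_freqs_cis_sketch(64, 256, theta=500000.0)
assert not torch.allclose(torch.view_as_real(a), torch.view_as_real(b))
```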
17 changes: 10 additions & 7 deletions open_lm/positional_embedding/rotary.py
@@ -44,22 +44,24 @@ class RotaryEmbedding(torch.nn.Module):
(it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
"""

def __init__(self, dim_model: int, seq_len: int, *_, **__):
def __init__(self, dim_model: int, seq_len: int, frequency: float = 10000, *_, **__):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
self.dim_model = dim_model
self.register_buffer("inv_freq", torch.zeros(self.dim_model // 2))

self._cos_cached = None
self._sin_cached = None
self._seq_len_cached = 0
self.seq_len = seq_len
self.reset_parameters()

def reset_parameters(self):
self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
self.frequency = frequency
self._update_cos_sin_tables(self.seq_len)

def load_state_dict(self, state_dict, strict=True):
# The state dict is not used, as the parameters are not trainable
# Previous versions had an inv_freq buffer, we don't need to load it
# This is kept for compatibility with the previous version
pass

def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = None, dtype: torch.dtype = None):
# If no seq_len is provided, use the cached one
# If the seq_len is smaller than the cached one it is included in the cached one so no need to update
@@ -70,8 +72,9 @@ def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = None, dtype: torch.dtype = None):
# or if we're on a new device (possibly due to tracing for instance)
if seq_len > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
self._seq_len_cached = seq_len
inv_freq = 1.0 / (self.frequency ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
t = torch.arange(seq_len, device=device, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype))
freqs = torch.einsum("i,j->ij", t, inv_freq.to(device=device, dtype=dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(device)

self._cos_cached = emb.cos()[None, :, None, :].to(dtype)
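To make the effect of this change concrete: the inverse-frequency vector is now derived from the configurable base rather than the hard-coded 10000. A standalone illustration of the formula used in `_update_cos_sin_tables` above:

```python
import torch

def inv_freq(dim_model: int, frequency: float) -> torch.Tensor:
    # Same formula as in _update_cos_sin_tables above.
    return 1.0 / (frequency ** (torch.arange(0, dim_model, 2).float() / dim_model))

print(inv_freq(8, 10000.0))   # ~[1.0, 0.1, 0.01, 0.001]
print(inv_freq(8, 500000.0))  # smaller values: slower rotation per position
```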
Binary file added tests/assets/rotary1_old.pt
Binary file added tests/assets/rotary2_old.pt
87 changes: 87 additions & 0 deletions tests/test_rotary_freq.py
@@ -0,0 +1,87 @@
import torch
import pytest
from open_lm.positional_embedding.rotary import RotaryEmbedding


@pytest.fixture
def create_rotary_embedding():
def _create_rotary_embedding(dim_model, seq_len, frequency):
return RotaryEmbedding(dim_model, seq_len, frequency)

return _create_rotary_embedding


def test_frequency_input(create_rotary_embedding):
dim_model = 32
seq_len = 64

# Create two rotary embeddings with different frequencies
freq1 = 10000
freq2 = 20000
rotary1 = create_rotary_embedding(dim_model, seq_len, freq1)
rotary2 = create_rotary_embedding(dim_model, seq_len, freq2)

# Generate some dummy data
q = torch.randn(1, seq_len, dim_model)
k = torch.randn(1, seq_len, dim_model)

# Forward pass with different frequencies
q1, k1 = rotary1(q, k)
q2, k2 = rotary2(q, k)

# Ensure the outputs are different
assert not torch.allclose(q1, q), "The outputs should not be close"
assert not torch.allclose(k1, k), "The outputs should not be close"
assert not torch.allclose(q1, q2), "The outputs for different frequencies should not be close"
assert not torch.allclose(k1, k2), "The outputs for different frequencies should not be close"

# load the state dicts
state_dict1 = torch.load("tests/assets/rotary1_old.pt")
state_dict2 = torch.load("tests/assets/rotary2_old.pt")

# Build new rotary embeddings with exchanged frequencies
rotary1_loaded = create_rotary_embedding(dim_model, seq_len, freq2)
rotary2_loaded = create_rotary_embedding(dim_model, seq_len, freq1)

# Forward pass with loaded models
q1_loaded, k1_loaded = rotary1_loaded(q, k)
q2_loaded, k2_loaded = rotary2_loaded(q, k)

# Ensure the outputs are the same
assert torch.allclose(
q1, q2_loaded
), "The outputs should be the same for the same frequencies before loading the state dict"
assert torch.allclose(
k2, k1_loaded
), "The outputs should be the same for the same frequencies before loading the state dict"

# Assert old state dict is in the old format
assert "inv_freq" in state_dict1, "The old state dict should contain the inv_freq buffer"

# Load the state dicts
rotary1_loaded.load_state_dict(state_dict1, strict=True)
rotary2_loaded.load_state_dict(state_dict2, strict=True)

# Ensure the frequencies are not overwritten
assert rotary1_loaded.frequency == freq2, "Frequency should not be overwritten by load_state_dict"
assert rotary2_loaded.frequency == freq1, "Frequency should not be overwritten by load_state_dict"

# Forward pass with loaded models
q1_loaded, k1_loaded = rotary1_loaded(q, k)
q2_loaded, k2_loaded = rotary2_loaded(q, k)

# Ensure the outputs are the same
assert torch.allclose(
q1, q2_loaded
), "The outputs should be the same for the same frequencies after loading the state dict"
assert torch.allclose(
k2, k1_loaded
), "The outputs should be the same for the same frequencies after loading the state dict"

# Ensure the outputs are still different
assert not torch.allclose(q1_loaded, q2_loaded), "The outputs for different frequencies should not be close"
assert not torch.allclose(k1_loaded, k2_loaded), "The outputs for different frequencies should not be close"


if __name__ == "__main__":
pytest.main([__file__])
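The binary assets above are old-format checkpoints whose state dicts still contain an `inv_freq` buffer, which `load_state_dict` now ignores. Purely for illustration (this is an assumption about how such fixtures could be produced, not the actual generation script), an old-style state dict can be mimicked like this:

```python
import torch

dim_model = 32
# Older RotaryEmbedding versions registered inv_freq as a buffer, so their
# state dicts contain this single entry.
old_state_dict = {"inv_freq": 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))}
torch.save(old_state_dict, "rotary_old_example.pt")  # hypothetical path
```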