TinyLLaMa data loaders and a training script for frontier with exact model specs #18

Open · wants to merge 25 commits into main
94924de
integrate tinyllama dataloaders and model args
siddharth9820 Jan 25, 2024
ac0debb
add the litgpt dataloader
siddharth9820 Jan 25, 2024
f99a7e5
test scripts on 128 GPUs
siddharth9820 Jan 27, 2024
4105129
remove harcoded output file from script
siddharth9820 Jan 27, 2024
a43c4d0
change default dataset location name
siddharth9820 Jan 29, 2024
522320a
make workers 0 and get rank from slurm
siddharth9820 Jan 29, 2024
90f4e9f
remove MPI dependency for real
siddharth9820 Jan 29, 2024
1ea6b8e
init torch directly and remove mpi4py import
siddharth9820 Jan 29, 2024
a2feee1
add venv baed setup for megatron axonn
siddharth9820 Jan 29, 2024
f0f89d0
add instructions for frontier
siddharth9820 Jan 29, 2024
8ba0bd5
add branch name in README
siddharth9820 Jan 29, 2024
7d256ed
add --lit-gpt-data-path as an argument
siddharth9820 Jan 29, 2024
1b0f952
update README
siddharth9820 Jan 29, 2024
2b6ddc4
update README
siddharth9820 Jan 29, 2024
0582227
minor
siddharth9820 Jan 29, 2024
bc55cfb
updated hparam (#21)
neelsjain Feb 9, 2024
76efc9a
sanity check on rewinding dataloaders
siddharth9820 Feb 9, 2024
88e859f
move venv to burst buffer
siddharth9820 Feb 9, 2024
2c58c30
switch to llama2 tokenizer
siddharth9820 Feb 9, 2024
36f4bdb
disable bias and untie weights
siddharth9820 Feb 9, 2024
7da8f85
add tinyllama init method and make pytorch optim default
siddharth9820 Feb 12, 2024
715aef8
add comment about seed
siddharth9820 Feb 13, 2024
e610c94
make dataloader exactly like neel's litgpt
siddharth9820 Feb 13, 2024
0881c29
option to disable activation checkpointing
siddharth9820 Feb 16, 2024
866e58e
latest changes
siddharth9820 Feb 17, 2024
96 changes: 96 additions & 0 deletions custom_litgpt_dataloader/data_util.py
@@ -0,0 +1,96 @@
from lit_gpt.packed_dataset import CombinedDataset, PackedDataset
from lit_gpt.utils import CycleIterator
from torch.utils.data import DataLoader
from pathlib import Path
from axonn import axonn as ax
import torch.distributed as dist
from typing import Optional, Tuple

data_config = [
    ("train_slimpajama", 69.3584),
    ("train_starcoder", 30.6),
    # normalized: 0.693584, 0.306416
    # ("c4", 15.0),
    # ("cc", 67.0),
    # ("github", 4.5),
    # ("stackexchange", 2.0),
    # ("wikipedia", 4.5),
]


def create_dataloader(
    batch_size: int, block_size: int, data_dir: Path, shuffle: bool = True, seed: int = 12345
) -> DataLoader:
    datasets = []
    for prefix, _ in data_config:
        filenames = list(data_dir.glob(f"{prefix}*"))
        if not filenames:
            raise FileNotFoundError(
                f"No files found at {str(data_dir)} with prefix {prefix}. "
                "Did you forget to run `prepare_redpajama.py`?"
            )
        dataset = PackedDataset(
            filenames,
            n_chunks=4,
            block_size=block_size,
            shuffle=shuffle,
            seed=seed,
            num_processes=ax.config.G_data,
            process_rank=ax.config.data_parallel_rank,
        )
        datasets.append(dataset)

    if not datasets:
        raise RuntimeError(
            f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
        )

    # Normalize the mixture weights so they sum to 1.
    weights = [weight for _, weight in data_config]
    sum_weights = sum(weights)
    weights = [el / sum_weights for el in weights]

    # Having different seeds here is important so that each batch has tokens
    # from all data mixtures.
    combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)

    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)


def create_dataloaders(
    batch_size: int,
    block_size: int,
    train_data_dir: str,
    val_data_dir: str,
    seed: int = 12345,  # this seed is independent of Megatron's seeds
) -> Tuple[CycleIterator, Optional[CycleIterator]]:
    train_data_dir = Path(train_data_dir)
    val_data_dir = Path(val_data_dir) if val_data_dir else None
    # Increase the block size by one because we need the next word as well.
    effective_block_size = block_size + 1
    train_dataloader = create_dataloader(
        batch_size=batch_size,
        block_size=effective_block_size,
        data_dir=train_data_dir,
        shuffle=True,
        seed=seed,
    )
    val_dataloader = (
        create_dataloader(
            batch_size=batch_size,
            block_size=effective_block_size,
            data_dir=val_data_dir,
            shuffle=False,
            seed=seed,
        )
        if val_data_dir
        else None
    )
    return (
        CycleIterator(train_dataloader),
        CycleIterator(val_dataloader) if val_dataloader is not None else None,
    )


if __name__ == "__main__":
    ax.init(G_inter=1, G_data=1, G_intra_r=8)
    # Placeholder paths; point these at your prepared lit-gpt dataset.
    train_loader, val_loader = create_dataloaders(
        batch_size=32,
        block_size=1024,  # sequence length
        train_data_dir="data/train",
        val_data_dir="data/val",
    )
    data = next(train_loader)
    print(dist.get_rank(), ":", data.view(-1)[:5])
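The `block_size + 1` in `create_dataloaders` exists because each packed sample must supply both the model inputs and the shifted next-token targets. A minimal standalone sketch with hypothetical token ids (no lit_gpt dependency, plain Python lists instead of tensors):

```python
# Why the loader fetches block_size + 1 tokens: the same window is split
# into inputs (all but the last token) and targets (all but the first).
block_size = 4
sample = [10, 11, 12, 13, 14]  # block_size + 1 hypothetical token ids

inputs = sample[:-1]   # what the model conditions on
targets = sample[1:]   # the next token at each position

assert len(inputs) == len(targets) == block_size
print("inputs: ", inputs)
print("targets:", targets)
```

The real dataloader yields tensors of shape `(batch_size, block_size + 1)`; the training loop performs this same split along the sequence dimension.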
40 changes: 40 additions & 0 deletions examples/README.md
@@ -0,0 +1,40 @@
# How to set up on Frontier

## Installing all dependencies
Note that this is a Python virtual-environment-based setup;
you may need to adapt it slightly for conda.

It also assumes that you are starting from scratch, with no venv or conda
environment active.

We are going to install everything on the scratch filesystem.

```
cd /lustre/orion/scratch/$(whoami)/csc569/
bash install_everything_on_frontier.sh
```

This should work; let Siddharth know if it doesn't.

## Training TinyLLaMA
First, check out the `tiny-llama` branch of Megatron-AxoNN.
Then open `examples/run_axonn_amd_tinyllama.sh` and change the following:

```
# These are the two things you need to change as per your setup
# 1. Make LD_LIBRARY_PATH point to wherever your plugin is installed
# this enables the slingshot-11 plugin for RCCL (crucial for inter-node bw)
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/path/to/aws-ofi-rccl/build/lib"
# 2. Make PYTHONPATH point to your local clone of litgpt
export PYTHONPATH="$PYTHONPATH:/path/to/lit-gpt-dev"
```
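Before launching, you can sanity-check that the plugin library is actually where `LD_LIBRARY_PATH` points. The directory and library name below are assumptions (aws-ofi-rccl typically builds `librccl-net.so`); adjust them to your build:

```shell
# Hypothetical check; set PLUGIN_DIR to your aws-ofi-rccl build location.
PLUGIN_DIR="/path/to/aws-ofi-rccl/build/lib"
if [ -f "${PLUGIN_DIR}/librccl-net.so" ]; then
    echo "RCCL plugin found"
else
    echo "RCCL plugin missing: check LD_LIBRARY_PATH"
fi
```

If the plugin is missing at run time, RCCL silently falls back to a slower transport, so this check is worth the few seconds.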

Now you are ready to train.
To launch on 16 nodes (128 GPUs) for 2 hours:
```
## checkout the tiny-llama branch
sbatch -N 16 -o /path/to/output/file -t 02:00:00 examples/run_axonn_amd_tinyllama.sh
```



4 changes: 4 additions & 0 deletions examples/get_rank_from_slurm.sh
@@ -0,0 +1,4 @@
#!/bin/bash
# wrapper script: derive the global rank from Slurm's per-task id
export RANK=${SLURM_PROCID}
exec "$@"
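The wrapper above only copies Slurm's per-process id into `RANK` before exec'ing the real command; its effect can be simulated without Slurm by setting `SLURM_PROCID` by hand:

```shell
# Simulate what srun provides: one SLURM_PROCID per launched task.
# Under a real job, srun sets this automatically for each process.
export SLURM_PROCID=3
export RANK=${SLURM_PROCID}   # this is all the wrapper does
echo "rank=${RANK}"
```

In a job script you would invoke it as, e.g., `srun ./get_rank_from_slurm.sh python train.py ...`, so every task sees its own `RANK` without needing mpi4py.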
68 changes: 68 additions & 0 deletions examples/install_everything_on_frontier.sh
@@ -0,0 +1,68 @@
#!/bin/bash

# Setup Virtual Environment
echo "Setting up Virtual Environment"
module load cray-python
python -m venv ./my-venv --system-site-packages
cd my-venv
. bin/activate

# PyTorch
echo "Installing PyTorch"
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6

module load amd-mixed/5.6.0
module load PrgEnv-cray

# mpi4py
echo "Installing mpi4py"
module load craype-accel-amd-gfx90a
export MPICH_GPU_SUPPORT_ENABLED=1
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CRAY_MPICH_ROOTDIR}/gtl/lib"
echo ${LD_LIBRARY_PATH}

MPICC=CC python -m pip install --ignore-installed --no-cache-dir mpi4py

# Flash Attention
echo "Installing Flash Attention"
git clone https://github.com/ROCmSoftwarePlatform/flash-attention
cd flash-attention
sed -i 's/c++20/c++17/g' setup.py  # replace c++20 with c++17 in the build flags
CC=cc CXX=CC PYTORCH_ROCM_ARCH='gfx90a' GPU_ARCHS='gfx90a' pip install -v .

# Apex
echo "Installing Apex"
cd ..
git clone https://github.com/ROCmSoftwarePlatform/apex
cd apex
git checkout release/1.1.0
CC=cc CXX=CC PYTORCH_ROCM_ARCH='gfx90a' GPU_ARCHS='gfx90a' python setup.py install --cpp_ext --cuda_ext

# RCCL Plugin
echo "Installing RCCL Plugin"
cd ..
git clone https://github.com/ROCmSoftwarePlatform/aws-ofi-rccl
cd aws-ofi-rccl
module load libtool
./autogen.sh
CC=cc CXX=CC ./configure --with-libfabric=/opt/cray/libfabric/1.15.0.0 --with-hip=/opt/rocm-5.6.0/ --with-rccl="$(dirname "$(pwd)")"/lib/python3.9/site-packages/torch/lib/ --prefix="$(dirname "$(pwd)")"/aws-ofi-rccl/build/
CC=cc CXX=CC make -j install

cd ..

# AxoNN
echo "Installing AxoNN"
git clone https://github.com/axonn-ai/axonn.git
cd axonn
pip install -e .

cd ..

# Megatron-AxoNN
echo "Installing Megatron-AxoNN"
git clone https://github.com/axonn-ai/Megatron-AxoNN.git

pip install regex

echo "Done!"

136 changes: 0 additions & 136 deletions examples/run_axonn_amd.sh

This file was deleted.
