
Commit

update
goliaro committed Oct 9, 2024
1 parent cf85d60 commit 50a1163
Showing 4 changed files with 102 additions and 59 deletions.
5 changes: 3 additions & 2 deletions tests/fine_grained_alignment_test.sh
@@ -29,7 +29,8 @@ mkdir -p ./inference/output

# Enable backtrace in case we run into a segfault or assertion failure
export LEGION_BACKTRACE=1
export FF_DEBG_NO_WEIGHTS=1
export FF_DEBG_NO_WEIGHTS=0
FUSION=false

PROMPT_LENGTH=$(python -c "
from transformers import AutoTokenizer
@@ -66,7 +67,7 @@ json_config=$(cat <<-END
"tensor_parallelism_degree": ${TP_DEGREE},
"pipeline_parallelism_degree": ${PP_DEGREE},
"inference_debugging": true,
"fusion": true,
"fusion": ${FUSION},
"refresh_cache": false,
"llm_model": "${MODEL_NAME}",
"cache_path": "${CACHE_PATH}",
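For a quick sanity check on the new FUSION toggle, the rendered config can be parsed back before launching the test. A minimal sketch, assuming the heredoc above is written to a hypothetical ./inference/config.json (the script's actual output path is outside this hunk):

import json

# Hypothetical path; the script's real destination for json_config is not shown in this hunk.
with open("./inference/config.json") as f:
    cfg = json.load(f)

# With FUSION=false, operators stay unfused, so per-operator tensors can be dumped
# and compared one-to-one against the HuggingFace reference.
assert cfg["fusion"] is False
assert cfg["inference_debugging"] is True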
2 changes: 1 addition & 1 deletion tests/inference/huggingface_inference.py
@@ -82,7 +82,7 @@ def main():
make_debug_dirs()
register_inference_hooks(model)
# Save weights
# save_model_weights(model, target_modules=["lora", "lm_head", "down_proj"])
save_model_weights(model, target_modules=["lora", "lm_head", "final_layer_norm", "self_attn_layer_norm", "out_proj", "fc1", "fc2"])

###############################################
# Generate output
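save_model_weights comes from the test's utility module and its implementation is not part of this diff. A minimal sketch, under the assumption that it simply keeps every parameter whose name contains one of the target_modules substrings; the output directory below is hypothetical:

import os
import torch

def save_model_weights_sketch(model, target_modules, out_dir="./hf_tensors"):
    """Hypothetical stand-in for save_model_weights: dump every matching parameter."""
    os.makedirs(out_dir, exist_ok=True)
    for name, param in model.named_parameters():
        if any(t in name for t in target_modules):
            # File names mirror parameter names so the alignment test can locate them later.
            torch.save(param.detach().cpu(), os.path.join(out_dir, name))

# Usage mirroring the call above:
# save_model_weights_sketch(model, ["lora", "lm_head", "final_layer_norm",
#                                   "self_attn_layer_norm", "out_proj", "fc1", "fc2"])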
146 changes: 94 additions & 52 deletions tests/inference/inference_alignment_test.py
@@ -95,14 +95,14 @@ def get_tp_partition_dim(ff_weight_name) -> int:

# 1. get shape of hf weight
hf_weight = torch.load(hf_w_path, map_location='cpu')
hf_weigth_shape = hf_weight.shape
hf_weight_shape = hf_weight.shape
ff_partition_dim = get_tp_partition_dim(ff_weight_name)
ff_weigth_shape = list(hf_weigth_shape)[::-1]
ff_weight_shape = list(hf_weight_shape)[::-1]
if ff_partition_dim >= 0:
ff_weigth_shape[ff_partition_dim] //= self.tp_degree
ff_weight_shape[ff_partition_dim] //= self.tp_degree

# 2. handle flexflow shards in case of tensor parallelism
ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weigth_shape) for tp_idx in range(self.tp_degree)]
ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weight_shape) for tp_idx in range(self.tp_degree)]
if self.tp_degree > 1:
if ff_partition_dim >= 0:
ff_weight = np.concatenate(ff_weights, axis=ff_partition_dim)
@@ -252,6 +252,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name)
compare(hf_tensor, ff_tensor, label=f"Attention {i} output")
assert False

# Post-attention layernorm
hf_tensor_name = f"layers.{i}.post_attention_layernorm"
@@ -327,16 +328,25 @@ def __init__(self, hf_config, tp_degree=1):

def check_weights_alignment(self):
def convert_hf_filename_to_ff(hf_filename):
if hf_filename == "lm_head.weight":
f_version = f"layers.{self.num_layers-1}.lm_head.weight_0"
elif hf_filename == "final_layer_norm.weight":
f_version = f"layers.{self.num_layers-1}.final_layer_norm.weight_0"
if hf_filename == "lm_head.weight" or hf_filename == "final_layer_norm.weight":
f_version = f"layers.{self.num_layers-1}.{hf_filename}_0"
elif hf_filename == "lm_head.bias" or hf_filename == "final_layer_norm.bias":
f_version = f"layers.{self.num_layers-1}.{hf_filename.replace('bias', 'weight')}_1"
elif hf_filename.startswith("layers.") and hf_filename.endswith("self_attn.out_proj.bias"):
layernum = hf_filename.split("layers.")[1].split(".")[0]
f_version = f"layers.{layernum}.layers.{layernum}.add_bias_residual_layer_norm.weight_0"
elif hf_filename.startswith("layers.") and hf_filename.endswith(".final_layer_norm.weight"):
layernum = hf_filename.split("layers.")[1].split(".")[0]
f_version = f"layers.{layernum}.layers.{layernum}.add_bias_residual_layer_norm.weight_1"
elif hf_filename.startswith("layers.") and hf_filename.endswith(".final_layer_norm.bias"):
layernum = hf_filename.split("layers.")[1].split(".")[0]
f_version = f"layers.{layernum}.layers.{layernum}.add_bias_residual_layer_norm.weight_2"
else:
f_version = ""
if hf_filename.startswith("layers."):
layernum = hf_filename.split("layers.")[1].split(".")[0]
f_version += f"layers.{layernum}."
f_version += hf_filename.replace(".base_layer", "").replace(".default", "")
f_version += hf_filename.replace(".base_layer", "").replace(".default", "").replace("out_proj", "o_proj")
# compute weight index, then rename lora if needed
weight_index="0"
if "lora_A" in f_version:
@@ -352,6 +362,8 @@ def convert_hf_filename_to_ff(hf_filename):
elif f_version.endswith(".gradient"):
prefix = f_version.split(".gradient")[0]
f_version = prefix + f".weight_{weight_index}.gradient"
elif f_version.endswith(".bias"):
f_version = f_version.replace(".bias", ".weight_1")
return f_version
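For reference, a few hypothetical mappings produced by the rules above, assuming self.num_layers == 12 so the last decoder layer index is 11; the ".weight" case relies on the elided branch appending "_0", mirroring the ".gradient" branch above:

# Hypothetical example mappings for convert_hf_filename_to_ff, assuming
# self.num_layers == 12 (so the last decoder layer index is 11):
example_mappings = {
    "lm_head.weight": "layers.11.lm_head.weight_0",
    "final_layer_norm.bias": "layers.11.final_layer_norm.weight_1",
    "layers.3.self_attn.out_proj.bias": "layers.3.layers.3.add_bias_residual_layer_norm.weight_0",
    "layers.3.fc1.bias": "layers.3.layers.3.fc1.weight_1",
    # assumes the elided ".weight" branch appends "_0", mirroring the ".gradient" branch above:
    "layers.3.self_attn.out_proj.weight": "layers.3.layers.3.self_attn.o_proj.weight_0",
}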
def get_tp_partition_dim(ff_weight_name) -> int:
# MLP layers split the intermediate size dimension
@@ -361,11 +373,16 @@ def get_tp_partition_dim(ff_weight_name) -> int:
return -1
if "lora.weight_B" in ff_weight_name:
return -1
if "lm_head" in ff_weight_name or "final_layer_norm" in ff_weight_name:
if "lm_head" in ff_weight_name or "fc1" in ff_weight_name:
return 1
if "fc1" in ff_weight_name:
return 1
elif "fc2" in ff_weight_name:
elif "fc2" in ff_weight_name or "o_proj.weight" in ff_weight_name:
return 0
else:
return -1
def get_bias_tp_partition_dim(ff_weight_name) -> int:
if self.tp_degree == 1:
return -1
elif "lm_head" in ff_weight_name or "fc1" in ff_weight_name:
return 0
else:
return -1
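Read together, the two helpers amount to the following split table for an OPT-style model; this is a hedged summary, and any name not listed resolves to -1 (replicated):

# FlexFlow stores the transpose of the HF weight, so dim 1 is the output-feature
# axis and dim 0 the input-feature axis.
weight_split_dim = {"lm_head": 1, "fc1": 1, "fc2": 0, "o_proj.weight": 0}
# Biases follow the output-feature split; unpartitioned biases live only on shard 0
# (handled by the loading code below).
bias_split_dim = {"lm_head": 0, "fc1": 0}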
@@ -374,7 +391,7 @@ def get_tp_partition_dim(ff_weight_name) -> int:
ff_weights_folder = os.path.join(ff_path, "weights", "step_0", "shard_0")
files_list = os.listdir(hf_weights_folder)
for hf_weight_name in tqdm(sorted(files_list)):
if hf_weight_name.endswith(".weight"):
if hf_weight_name.endswith(".weight") or hf_weight_name.endswith(".bias"):
ff_weight_name = convert_hf_filename_to_ff(hf_weight_name)
# print(hf_weight_name, ff_weight_name)
hf_w_path = os.path.join(hf_weights_folder, hf_weight_name)
@@ -388,24 +405,29 @@ def get_tp_partition_dim(ff_weight_name) -> int:

# 1. get shape of hf weight
hf_weight = torch.load(hf_w_path, map_location='cpu')
hf_weigth_shape = hf_weight.shape
ff_partition_dim = get_tp_partition_dim(ff_weight_name)
ff_weigth_shape = list(hf_weigth_shape)[::-1]
hf_weight_shape = hf_weight.shape
ff_partition_dim = get_tp_partition_dim(ff_weight_name) if hf_weight_name.endswith(".weight") else get_bias_tp_partition_dim(ff_weight_name)
ff_weight_shape = list(hf_weight_shape)[::-1]
# print(ff_partition_dim, ff_weight_name, hf_w_path, ff_weight_shape)
if ff_partition_dim >= 0:
ff_weigth_shape[ff_partition_dim] //= self.tp_degree
ff_weight_shape[ff_partition_dim] //= self.tp_degree

# 2. handle flexflow shards in case of tensor parallelism
ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weigth_shape) for tp_idx in range(self.tp_degree)]
if self.tp_degree > 1:
if ff_partition_dim >= 0:
ff_weight = np.concatenate(ff_weights, axis=ff_partition_dim)
if hf_weight_name.endswith(".bias") and ff_partition_dim == -1:
# unpartitioned bias (E.g. replicated bias) only lives on shard 0
ff_weight = load_ff_tensor(ff_w_path, ff_weight_shape)
else:
ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weight_shape) for tp_idx in range(self.tp_degree)]
if self.tp_degree > 1:
if ff_partition_dim >= 0:
ff_weight = np.concatenate(ff_weights, axis=ff_partition_dim)
else:
assert(are_np_arrays_identical(ff_weights))
ff_weight = ff_weights[0]
else:
assert(are_np_arrays_identical(ff_weights))
ff_weight = ff_weights[0]
else:
ff_weight = ff_weights[0]
ff_weight = torch.from_numpy(ff_weight).to(hf_weight.dtype)

# print("comparing weight tensor: ", hf_weight_name, " and ", ff_weight_name)
# check equivalence
try:
torch.testing.assert_close(ff_weight, hf_weight.T)
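Taken together, the weight check reverses the HF shape to get FlexFlow's layout, shrinks the split dimension by the TP degree, and then either concatenates the shards or falls back to shard 0. A standalone sketch of that reassembly, reusing the file's load_ff_tensor helper (the helper and the numpy/torch imports are assumed from the top of this test):

import numpy as np
import torch

def reassemble_ff_weight(ff_w_path, hf_weight, partition_dim, tp_degree):
    # load_ff_tensor is the helper already imported at the top of this test file.
    # FlexFlow stores the transpose of the HF tensor, so reverse the HF shape.
    ff_shape = list(hf_weight.shape)[::-1]
    if partition_dim >= 0:
        ff_shape[partition_dim] //= tp_degree  # each shard holds a 1/tp_degree slice
        shards = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{i}"), ff_shape)
                  for i in range(tp_degree)]
        ff_weight = np.concatenate(shards, axis=partition_dim)
    else:
        # Replicated weights (and shard-0-only biases): shard 0 already holds the full tensor.
        ff_weight = load_ff_tensor(ff_w_path, ff_shape)
    return torch.from_numpy(ff_weight).to(hf_weight.dtype)

# The result is expected to equal the transposed HF weight:
# torch.testing.assert_close(reassemble_ff_weight(ff_w_path, hf_weight, dim, tp), hf_weight.T)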
@@ -526,7 +548,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance

# Transformers blocks
for i in range(self.num_layers):
# Input laye norm
# Input layer norm
hf_tensor_name = f"layers.{i}.self_attn_layer_norm"
ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name)
input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
@@ -538,7 +560,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape)
compare(hf_tensor, ff_tensor, label=f"Self attention layernorm {i} output")

# Attention
# Attention QKV projections
hf_q_proj_tensor_name = f"layers.{i}.self_attn.q_proj"
hf_k_proj_tensor_name = f"layers.{i}.self_attn.k_proj"
hf_v_proj_tensor_name = f"layers.{i}.self_attn.v_proj"
@@ -581,34 +603,51 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
compare_loaded_tensors(hf_k_proj_out.T, ff_kproj_out)
compare_loaded_tensors(hf_v_proj_out.T, ff_vproj_out)

hf_tensor_name = f"layers.{i}.self_attn.out_proj"
ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name.replace(".out_proj", ".o_proj"))
# the raw attention result, w/o o_proj. This is the output of self_attn of FF and the input of o_proj in HF
output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0)
hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
# ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE)
# TP for self-attn partitions the attention heads across TP workers
ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name)
compare(hf_tensor, ff_tensor, label=f"Attention {i} output")
# hf_tensor_name = f"layers.{i}.final_layer_norm"
# ff_tensor_name = f"layers.{i}.layers.{i}.add_bias_residual_layer_norm"
# output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
# hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
# ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE)
# compare(hf_tensor, ff_tensor, label=f"Add Bias Residula LN {i} output 0")

# hf_tensor_name = f"layers.{i}.self_attn.out_proj"
# ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name.replace(".out_proj", ".o_proj"))
# # the raw attention result, w/o o_proj. This is the output of self_attn of FF and the input of o_proj in HF
# output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0)
# hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
# # ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE)
# # TP for self-attn partitions the attention heads across TP workers
# ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
# print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name)
# compare(hf_tensor, ff_tensor, label=f"Attention {i} output")

# hf_tensor_name = f"layers.{i}.self_attn.out_proj"
# ff_tensor_name = f"layers.{i}.layers.{i}.self_attn"
# output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
# hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
# ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
# print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name)
# compare(hf_tensor, ff_tensor, label=f"Attention {i} output")

# Post-attention layernorm
hf_tensor_name = f"layers.{i}.add_bias_residual_layer_norm"
ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name)
output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1)
hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape)
compare(hf_tensor, ff_tensor, label=f"Add bias residual layernorm {i} output")

# W1 (gate_proj)
hf_tensor_name = f"layers.{i}.fc1"
ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name)


# # Post-attention layernorm
# hf_tensor_name = f"layers.{i}.add_bias_residual_layer_norm"
# ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name)
# output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1)
# hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
# ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape)
# compare(hf_tensor, ff_tensor, label=f"Add bias residual layernorm {i} output")

# FC1 (+ ReLU)
hf_tensor_name = f"layers.{i}.activation_fn"
ff_tensor_name = convert_hf_filename_to_ff(f"layers.{i}.fc1")
output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
compare(hf_tensor, ff_tensor, label=f"FC1 {i} output")

# W2 (down_proj)
# FC2
hf_tensor_name = f"layers.{i}.fc2"
ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name)
input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0)
@@ -617,7 +656,10 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison)
ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
compare(hf_tensor, ff_tensor, label=f"FC2 {i} input")

hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE)
# compare(hf_tensor, ff_tensor, label=f"FC2 {i} output")

hf_down_proj_in = hf_tensor.clone()
hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
ff_down_proj_out = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE)
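The tp_type argument describes how per-shard activations recombine into the HF reference: PARTITION concatenates the shards (e.g. the FC2 input, which is the split intermediate activation), TO_REDUCE sums partial results that FlexFlow would all-reduce (e.g. the FC2 output), and REPLICATE expects identical copies. A hedged sketch of that recombination, with string tags standing in for the TPType enum and the file's are_np_arrays_identical helper assumed available:

def combine_shards_sketch(shards, tp_type, axis=-1):
    # shards: one numpy array per tensor-parallel worker. The real get_ff_tensor
    # knows the true layout, so the axis here is only an illustrative default.
    if tp_type == "PARTITION":
        return np.concatenate(shards, axis=axis)  # split activations, e.g. the FC2 input
    if tp_type == "TO_REDUCE":
        return sum(shards)                        # partial sums before the all-reduce, e.g. the FC2 output
    if tp_type == "REPLICATE":
        assert are_np_arrays_identical(shards)    # identical copy on every worker
        return shards[0]
    raise ValueError(f"unknown tp_type: {tp_type}")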
@@ -659,6 +701,6 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
elif hf_config.architectures[0] == "OPTForCausalLM":
alignment_class = OPTAlignmentTest(hf_config, tp_degree=args.tensor_parallelism_degree)

# alignment_class.check_weights_alignment()
alignment_class.check_weights_alignment()
for i in range(args.num_steps):
alignment_class.check_fwd_pass(i)
8 changes: 4 additions & 4 deletions tests/peft/peft_alignment_test.py
@@ -98,14 +98,14 @@ def get_tp_partition_dim(ff_weight_name) -> int:

# 1. get shape of hf weight
hf_weight = torch.load(hf_w_path, map_location='cpu')
hf_weigth_shape = hf_weight.shape
hf_weight_shape = hf_weight.shape
ff_partition_dim = get_tp_partition_dim(ff_weight_name)
ff_weigth_shape = list(hf_weigth_shape)[::-1]
ff_weight_shape = list(hf_weight_shape)[::-1]
if ff_partition_dim >= 0:
ff_weigth_shape[ff_partition_dim] //= self.tp_degree
ff_weight_shape[ff_partition_dim] //= self.tp_degree

# 2. handle flexflow shards in case of tensor parallelism
ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weigth_shape) for tp_idx in range(self.tp_degree)]
ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weight_shape) for tp_idx in range(self.tp_degree)]
if self.tp_degree > 1:
if ff_partition_dim >= 0:
ff_weight = np.concatenate(ff_weights, axis=ff_partition_dim)
