Skip to content

Commit

Permalink
minor fix on llama checkpoint conversion
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 572349696
Change-Id: Icda9d683e5b181b4d0bdd4e06eb39674729042be
  • Loading branch information
Sax Authors authored and copybara-github committed Oct 10, 2023
1 parent 291d708 commit 79aae7d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions saxml/tools/convert_llama_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def convert(base_model_path, pax_model_path, model_size):
}
},
'final_ln': {
'scale': pytorch_vars[0]['norm.weight'].numpy()
'scale': pytorch_vars[0]['norm.weight'].type(torch.float16).numpy()
},
'transformer': {}
}
Expand Down Expand Up @@ -165,7 +165,7 @@ def convert(base_model_path, pax_model_path, model_size):

print(f'Saving the pax model to {pax_model_path}')
jax_states = train_states.TrainState(
step=np.zeros(1),
step=0,
mdl_vars={'params': jax_weights},
opt_states={})

Expand Down

0 comments on commit 79aae7d

Please sign in to comment.