GPT-J model conversion failed from pytorch to paxml, throwing OOM error for TPUv3-8 #18

Open
confusedgoose627 opened this issue Jan 10, 2024 · 1 comment


confusedgoose627 commented Jan 10, 2024

Hi, I am trying to serve the GPT-J 6B model on a TPU v3-8 using the saxml framework.

The error occurs while converting the model from PyTorch to the Pax format that SAX expects. This is the conversion script:

https://github.com/mlcommons/inference_results_v3.1/blob/main/closed/Google/code/gptj-99/convert_gptj_ckpt.py

The admin server and model server are running correctly; I have even confirmed that they are communicating by running a sample test query.

The model pickle file is only 22.7 GB, so it should fit in the TPU cluster's memory. Any idea?

The environment:
pip3 install accelerate
pip3 install torch
pip3 install transformers
pip install paxml==1.1.0 (although I have also built it from its git repo)

2024-01-03 05:23:41.411871: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
Loading the base model from EleutherAI/gpt-j-6b
transformer.wte.weight (50400, 4096)
transformer.h.0.ln_1.weight (4096,)
transformer.h.0.ln_1.bias (4096,)
transformer.h.0.attn.k_proj.weight (4096, 4096)
transformer.h.0.attn.v_proj.weight (4096, 4096)
transformer.h.0.attn.q_proj.weight (4096, 4096)
transformer.h.0.attn.out_proj.weight (4096, 4096)
transformer.h.0.mlp.fc_in.weight (16384, 4096)
transformer.h.0.mlp.fc_in.bias (16384,)
transformer.h.0.mlp.fc_out.weight (4096, 16384)
transformer.h.0.mlp.fc_out.bias (4096,)
[... transformer.h.1 through transformer.h.27 print the same per-layer shapes ...]
transformer.ln_f.weight (4096,)
transformer.ln_f.bias (4096,)
lm_head.weight (50400, 4096)
lm_head.bias (50400,)
Saving the pax model to pax_6b
Traceback (most recent call last):
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 192, in
convert(args.base_model_path, args.pax_model_path)
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 176, in convert
jax_states_gda = pjitted_identity(jax_states)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 248, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 195, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 2591, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 362, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/core.py", line 816, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 1246, in _pjit_call_impl
compiled = _pjit_lower(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/arghyajoy627/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 22.54G of 15.48G hbm. Exceeded hbm capacity by 7.06G.

Total hbm usage >= 23.06G:
reserved 530.00M
program 4.0K
arguments 22.54G

Output size 22.54G; shares 0B with arguments.

Program hbm requirement 4.0K:
global 4.0K

Largest program allocations in hbm:

  1. Size: 4.0K
    Shape: u32[8,128]{1,0}
    Unpadded size: 4.0K
    XLA label: constant literal
    Allocation type: global

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 192, in
convert(args.base_model_path, args.pax_model_path)
File "/home/arghyajoy627/convert_gptj_ckpt.py", line 176, in convert
jax_states_gda = pjitted_identity(jax_states)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 22.54G of 15.48G hbm. Exceeded hbm capacity by 7.06G.

Total hbm usage >= 23.06G:
reserved 530.00M
program 4.0K
arguments 22.54G

Output size 22.54G; shares 0B with arguments.

Program hbm requirement 4.0K:
global 4.0K

Largest program allocations in hbm:

  1. Size: 4.0K
    Shape: u32[8,128]{1,0}
    Unpadded size: 4.0K
    XLA label: constant literal
    Allocation type: global

    @zhihaoshan-google

@NoahBPeterson
The problem you're having is that the entire model is being loaded onto a single TPU core, and one core does not have enough HBM to hold it: per your log, 22.54 GB of weights are being placed into 15.48 GB of usable per-core HBM.

The conversion script you linked was written to run the conversion on a single GPU (https://github.com/mlcommons/inference_results_v3.1/blob/951b4a7686692d1a0d9b9067a36a7fc26d72ada5/closed/Google/code/gptj-99/convert_gptj_ckpt.py#L154C1-L154C62), not on a TPU cluster, so it will not shard the weights without modification.

The offending line:

device_mesh = py_utils.create_device_mesh([1, 1, num_gpus])

This creates a 1x1xnum_gpus device mesh, and num_gpus is hardcoded to 1, so the pjitted identity places every tensor on a single device.

Try modifying it using the mesh sharding example from the JAX documentation: https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh
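For example, something along these lines (a minimal sketch, not the exact change the script needs; the axis names and the CPU-emulation flag are just for illustration):

```python
import os
# For demonstration only: emulate 8 devices on CPU. On an actual TPU v3-8,
# jax.devices() already returns 8 cores and this flag is unnecessary.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all 8 devices along a single "model" axis, instead of the
# script's effective 1x1x1 mesh that keeps everything on one device.
devices = np.array(jax.devices()).reshape(1, 1, 8)
mesh = Mesh(devices, axis_names=("replica", "data", "model"))

# Shard a weight matrix along its first dimension across the "model" axis,
# so each device holds only 1/8 of the rows instead of a full copy.
weights = np.ones((16384, 4096), dtype=np.float32)
sharded = jax.device_put(weights, NamedSharding(mesh, P("model", None)))

print(sharded.addressable_shards[0].data.shape)  # (2048, 4096): one eighth per device
```

With a layout like this, each of the 8 cores only needs to hold roughly 22.5 GB / 8 of the weights, which fits comfortably in the ~16 GB of HBM per v3 core.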
