First of all, great work! I was following your video on converting the GPT-J weights to PyTorch weights on Colab, but while running the convert_model_to_torch.py
script, the process dies with the error shown in the log below. I suspect the failure is due to Colab running out of memory or disk space, but I'm not sure.
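
(For reference, this is a minimal snippet I run in a Colab cell before launching the script to check RAM and disk headroom; psutil should already be available in the standard Colab image, and shutil is part of the standard library:)

```python
# Quick resource check in a Colab cell before running the converter.
import shutil

import psutil

mem = psutil.virtual_memory()
disk = shutil.disk_usage("/")
print(f"RAM:  {mem.available / 2**30:.1f} GiB free of {mem.total / 2**30:.1f} GiB")
print(f"Disk: {disk.free / 2**30:.1f} GiB free of {disk.total / 2**30:.1f} GiB")
```

The full console output follows:
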
loading shards for part 0
read from checkpoint
< (8, 4096) to (4096,)
> transformer.wte.bias torch.Size([4096])
< (8, 6300, 4096) to (1, 50400, 4096)
> transformer.wte.weight torch.Size([4096, 50400])
< (8, 4096, 512) to (1, 4096, 4096)
convert_model_to_torch.py:147: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)
params = torch.tensor(params.copy()).half()
> transformer.h.0.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.0.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.0.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.0.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.0.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.0.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.0.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.0.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.0.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.0.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.1.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.1.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.1.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.1.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.1.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.1.mlp.c_fc.weight torch.Size([16384, 4096])
loading shards for part 1
read from checkpoint
< (8, 4096) to (4096,)
> transformer.h.1.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.1.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.1.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.1.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.10.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.10.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.10.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.10.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.10.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.10.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.10.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.10.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.10.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.10.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.11.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.11.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.11.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.11.attn.attention.out_proj.weight torch.Size([4096, 4096])
loading shards for part 2
read from checkpoint
< (8, 2048) to (1, 16384)
> transformer.h.11.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.11.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.11.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.11.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.11.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.11.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.12.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.12.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.12.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.12.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.12.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.12.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.12.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.12.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.12.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.12.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.13.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.13.attn.attention.v_proj.weight torch.Size([4096, 4096])
loading shards for part 3
read from checkpoint
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.13.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.13.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.13.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.13.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.13.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.13.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.13.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.13.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.14.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.14.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.14.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.14.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.14.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.14.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.14.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.14.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.14.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.14.ln_1.weight torch.Size([4096])
loading shards for part 4
read from checkpoint
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.15.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.15.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.15.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.15.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.15.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.15.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.15.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.15.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.15.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.15.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.16.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.16.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.16.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.16.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.16.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.16.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.16.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.16.mlp.c_proj.weight torch.Size([4096, 16384])
loading shards for part 5
read from checkpoint
< (8, 4096) to (4096,)
> transformer.h.16.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.16.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.17.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.17.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.17.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.17.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.17.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.17.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.17.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.17.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.17.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.17.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.18.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.18.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.18.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.18.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.18.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.18.mlp.c_fc.weight torch.Size([16384, 4096])
loading shards for part 6
read from checkpoint
< (8, 4096) to (4096,)
> transformer.h.18.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.18.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.18.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.18.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.19.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.19.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.19.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.19.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.19.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.19.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.19.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.19.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.19.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.19.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.2.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.2.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.2.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.2.attn.attention.out_proj.weight torch.Size([4096, 4096])
loading shards for part 7
read from checkpoint
< (8, 2048) to (1, 16384)
> transformer.h.2.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.2.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.2.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.2.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.2.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.2.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.20.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.20.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.20.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.20.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.20.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.20.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.20.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.20.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.20.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.20.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.21.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.21.attn.attention.v_proj.weight torch.Size([4096, 4096])
loading shards for part 8
read from checkpoint
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.21.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.21.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.21.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.21.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.21.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.21.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.21.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.21.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.22.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.22.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.22.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.22.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.22.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.22.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.22.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.22.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.22.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.22.ln_1.weight torch.Size([4096])
loading shards for part 9
read from checkpoint
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.23.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.23.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.23.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.23.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.23.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.23.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.23.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.23.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.23.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.23.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.24.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.24.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.24.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.24.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.24.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.24.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.24.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.24.mlp.c_proj.weight torch.Size([4096, 16384])
loading shards for part 10
read from checkpoint
< (8, 4096) to (4096,)
> transformer.h.24.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.24.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.25.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.25.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.25.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.25.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.25.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.25.mlp.c_fc.weight torch.Size([16384, 4096])
< (8, 4096) to (4096,)
> transformer.h.25.mlp.c_proj.bias torch.Size([4096])
< (8, 2048, 4096) to (1, 16384, 4096)
> transformer.h.25.mlp.c_proj.weight torch.Size([4096, 16384])
< (8, 4096) to (4096,)
> transformer.h.25.ln_1.bias torch.Size([4096])
< (8, 4096) to (4096,)
> transformer.h.25.ln_1.weight torch.Size([4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.26.attn.attention.q_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.26.attn.attention.v_proj.weight torch.Size([4096, 4096])
< (8, 4096, 512) to (1, 4096, 4096)
> transformer.h.26.attn.attention.k_proj.weight torch.Size([4096, 4096])
< (8, 512, 4096) to (1, 4096, 4096)
> transformer.h.26.attn.attention.out_proj.weight torch.Size([4096, 4096])
< (8, 2048) to (1, 16384)
> transformer.h.26.mlp.c_fc.bias torch.Size([16384])
< (8, 4096, 2048) to (1, 4096, 16384)
> transformer.h.26.mlp.c_fc.weight torch.Size([16384, 4096])
loading shards for part 11
read from checkpoint
/bin/bash: line 1: 3502 Killed python convert_model_to_torch.py
real 6m31.465s
user 1m48.769s
sys 0m22.349s
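
Two observations: the UserWarning about the non-writeable NumPy array looks benign (it is just a warning and the conversion keeps going), and the actual failure is the bare "Killed" at the end. A "Killed" with no Python traceback usually means the Linux out-of-memory killer terminated the process, which fits the RAM theory. Something like the sketch below (not from the repo, just psutil sampling the converter from a second cell) could confirm it:

```python
# A minimal sketch (not part of the repo): launch the converter and sample
# its resident memory every few seconds, to see whether it climbs toward
# the Colab RAM limit right before the kill.
import subprocess
import time

import psutil

proc = subprocess.Popen(["python", "convert_model_to_torch.py"])
tracker = psutil.Process(proc.pid)
while proc.poll() is None:
    try:
        rss_gib = tracker.memory_info().rss / 2**30
    except psutil.NoSuchProcess:
        break  # converter exited between poll() and the sample
    print(f"converter RSS: {rss_gib:.1f} GiB")
    time.sleep(10)
print("exit code:", proc.returncode)
```

If the resident memory is near the roughly 12 GiB limit of a standard Colab runtime when the kill happens, that would confirm the memory theory rather than disk usage.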