
Comments (8)

oobabooga commented on May 18, 2024

I find that max_memory gets properly fed into accelerate, but the model loads entirely into VRAM nevertheless. For instance:

from auto_gptq import AutoGPTQForCausalLM

path_to_model = 'models/TheBloke_stable-vicuna-13B-GPTQ'
params = {
    'model_basename': 'stable-vicuna-13B-GPTQ-4bit.compat.no-act-order',
    'use_triton': False,
    'use_safetensors': True,
    'max_memory': {0: '2GiB', 'cpu': '99GiB'}
}

model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)

input()  # block here so the allocation can be inspected with nvidia-smi

nvidia-smi reports a 7533MiB allocation instead of something close to 2000MiB.
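
One way to confirm where the weights actually end up is to tally tensor bytes per device after loading. A quick sketch (it assumes the AutoGPTQ wrapper exposes the underlying transformers model as model.model):

import itertools
from collections import Counter

# Count both parameters and buffers, so packed quantized tensors are included too,
# and compare the per-device totals against the requested max_memory budget.
bytes_per_device = Counter()
tensors = itertools.chain(model.model.named_parameters(), model.model.named_buffers())
for name, tensor in tensors:
    bytes_per_device[str(tensor.device)] += tensor.numel() * tensor.element_size()

for device, n_bytes in sorted(bytes_per_device.items()):
    print(f"{device}: {n_bytes / 1024**2:.0f} MiB")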


PanQiWei commented on May 18, 2024

The offload_folder argument is not supported; to use CPU offload you should set something like max_memory={"cpu": "30GiB", 0: "3GiB"}. For more details you can also refer to this tutorial
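
For reference, a rough sketch at the accelerate level of how a max_memory budget turns into a per-module placement (the model and budget values are placeholders from this thread; this isn't necessarily the exact code path AutoGPTQ takes internally):

from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model on the meta device so no real memory is used, then ask accelerate
# how it would split it for the given budget.
config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6b")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

device_map = infer_auto_device_map(model, max_memory={0: "3GiB", "cpu": "30GiB"})
print(device_map)  # maps module names to 0 / "cpu"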


Dessix commented on May 18, 2024

I ended up having to massively reduce my example dataset in order to get it to load, because of the VRAM constraints involved. Disabling the cache-on-GPU flag didn't appear to affect this result. Also, it appears to be impossible to do a truly CPU-only run, as the lack of a GPU leads to a division by zero.
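
For what it's worth, a small calibration set in the format quantize() expects (a list of dicts whose only keys are "input_ids" and "attention_mask") can be built like this; the tokenizer name and sentences are just placeholders:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")

calibration_texts = [
    "auto-gptq is an easy-to-use model quantization library.",
    "The quick brown fox jumps over the lazy dog.",
]

# Each example is a dict with only "input_ids" and "attention_mask".
examples = [
    {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    for enc in (tokenizer(text, return_tensors="pt") for text in calibration_texts)
]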


abhinavkulkarni commented on May 18, 2024

Hi @PanQiWei,

So, I am still running into an error while trying to quantize large models (that don't fit in the 12GB of VRAM).

The script ran for 8 mins and then failed:

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "EleutherAI/gpt-j-6b"
quantized_model_dir = "EleutherAI/gpt-j-6b-4bit-128g"

quantize_config = BaseQuantizeConfig(
    bits=4,  # quantize model to 4-bit
    group_size=128,  # it is recommended to set the value to 128
)

max_memory={0: "6GiB", 'cpu': '80GiB'}

# load un-quantized model, by default, the model will always be loaded into CPU memory
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, max_memory=max_memory)

# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
model.quantize(examples, use_triton=False)

# save quantized model
model.save_quantized(quantized_model_dir)

I ran into the following problem:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[6], line 16
     13 model.quantize(examples, use_triton=False)
     15 # save quantized model
---> 16 model.save_quantized(quantized_model_dir)
     18 # save quantized model using safetensors
     19 model.save_quantized(quantized_model_dir, use_safetensors=True)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/auto_gptq/modeling/_base.py:392, in BaseGPTQForCausalLM.save_quantized(self, save_dir, use_safetensors)
    389 if not self.quantized:
    390     raise EnvironmentError("can only save quantized model, please execute .quantize first.")
--> 392 self.model.to(CPU)
    394 model_save_name = f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
    395 if use_safetensors:

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/transformers/modeling_utils.py:1878, in PreTrainedModel.to(self, *args, **kwargs)
   1873     raise ValueError(
   1874         "`.to` is not supported for `8-bit` models. Please use the model as it is, since the"
   1875         " model has already been set to the correct devices and casted to the correct `dtype`."
   1876     )
   1877 else:
-> 1878     return super().to(*args, **kwargs)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:1145, in Module.to(self, *args, **kwargs)
   1141         return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1142                     non_blocking, memory_format=convert_to_format)
   1143     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
-> 1145 return self._apply(convert)

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
    795 def _apply(self, fn):
    796     for module in self.children():
--> 797         module._apply(fn)
    799     def compute_should_use_set_data(tensor, tensor_applied):
    800         if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    801             # If the new tensor has compatible tensor type as the existing tensor,
    802             # the current behavior is to change the tensor in-place using `.data =`,
    (...)
    807             # global flag to let the user control whether they want the future
    808             # behavior of overwriting the existing tensor or not.

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
    795 def _apply(self, fn):
    796     for module in self.children():
--> 797         module._apply(fn)
    799     def compute_should_use_set_data(tensor, tensor_applied):
    800         if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    801             # If the new tensor has compatible tensor type as the existing tensor,
    802             # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    807             # global flag to let the user control whether they want the future
    808             # behavior of overwriting the existing tensor or not.

    [... skipping similar frames: Module._apply at line 797 (1 times)]

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
    795 def _apply(self, fn):
    796     for module in self.children():
--> 797         module._apply(fn)
    799     def compute_should_use_set_data(tensor, tensor_applied):
    800         if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    801             # If the new tensor has compatible tensor type as the existing tensor,
    802             # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    807             # global flag to let the user control whether they want the future
    808             # behavior of overwriting the existing tensor or not.

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:820, in Module._apply(self, fn)
    816 # Tensors stored in modules are graph leaves, and we don't want to
    817 # track autograd history of `param_applied`, so we have to use
    818 # `with torch.no_grad():`
    819 with torch.no_grad():
--> 820     param_applied = fn(param)
    821 should_use_set_data = compute_should_use_set_data(param, param_applied)
    822 if should_use_set_data:

File /opt/anaconda3/envs/autogptq/lib/python3.9/site-packages/torch/nn/modules/module.py:1143, in Module.to.<locals>.convert(t)
   1140 if convert_to_format is not None and t.dim() in (4, 5):
   1141     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1142                 non_blocking, memory_format=convert_to_format)
-> 1143 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

NotImplementedError: Cannot copy out of meta tensor; no data!

Also, I see that the VRAM usage goes up to 11GB even though I have specified 6GiB for GPU 0 in max_memory. The examples array has only 10 sentences.
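
For tracking this down, PyTorch's own allocator counters can be printed at each step (e.g. right after from_pretrained and again after quantize). Just a sketch; note that nvidia-smi also counts the CUDA context and cached blocks, so it normally reads higher than these numbers:

import torch

# How much the current process has actually allocated / reserved / peaked on GPU 0.
print(f"allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GiB")
print(f"reserved:  {torch.cuda.memory_reserved(0) / 1024**3:.2f} GiB")
print(f"peak:      {torch.cuda.max_memory_allocated(0) / 1024**3:.2f} GiB")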


z80maniac commented on May 18, 2024

but the model loads entirely into VRAM nevertheless

I may be mistaken, but there might be a bug in the accelerate library.

In #47 (comment) I've shown what device map is generated when you specify max_memory={0: "2GiB", "cpu": "30GiB"}:

{'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 'cpu', 'model.layers.9': 'cpu', 'model.layers.10': 'cpu', 'model.layers.11': 'cpu', 'model.layers.12': 'cpu', 'model.layers.13': 'cpu', 'model.layers.14': 'cpu', 'model.layers.15': 'cpu', 'model.layers.16': 'cpu', 'model.layers.17': 'cpu', 'model.layers.18': 'cpu', 'model.layers.19': 'cpu', 'model.layers.20': 'cpu', 'model.layers.21': 'cpu', 'model.layers.22': 'cpu', 'model.layers.23': 'cpu', 'model.layers.24': 'cpu', 'model.layers.25': 'cpu', 'model.layers.26': 'cpu', 'model.layers.27': 'cpu', 'model.layers.28': 'cpu', 'model.layers.29': 'cpu', 'model.layers.30': 'cpu', 'model.layers.31': 'cpu', 'model.layers.32': 'cpu', 'model.layers.33': 'cpu', 'model.layers.34': 'cpu', 'model.layers.35': 'cpu', 'model.layers.36': 'cpu', 'model.layers.37': 'cpu', 'model.layers.38': 'cpu', 'model.layers.39': 'cpu', 'model.norm': 'cpu', 'lm_head': 'cpu'}

Let's take this part:

'model.layers.1': 0

The accelerate library, in modeling.py (the load_state_dict function), has the following code:

# For each device, get the weights that go there
device_weights = {device: [] for device in devices}
for module_name, device in device_map.items():
    if device in devices:
        device_weights[device].extend([k for k in weight_names if k.startswith(module_name)])

This code is supposed to distribute all modules to their respective devices, so that 'model.layers.1': 0 would mean that all modules whose names start with model.layers.1 (e.g. model.layers.1.input_layernorm.weight, etc.) will go to device 0.

And here's the problem: model.layers.11 also starts with model.layers.1, as does model.layers.12, and so on. So in the original example:

'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0

not only layers 1-3 but also layers 10-39 would go to the GPU. In fact, it seems like layers 10-39 go to both CPU and GPU.
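
Here is a tiny self-contained sketch of that prefix collision (not accelerate's actual code, just an illustration), together with a stricter match that requires a "." boundary after the module name:

# Two modules: layer 1 budgeted for the GPU, layer 11 for the CPU.
device_map = {"model.layers.1": 0, "model.layers.11": "cpu"}
weight_names = [
    "model.layers.1.self_attn.q_proj.weight",
    "model.layers.11.self_attn.q_proj.weight",
]

# Loose prefix match, as in the snippet above: the layer-11 weight matches both entries.
loose = {device: [w for name, d in device_map.items() if d == device
                  for w in weight_names if w.startswith(name)]
         for device in (0, "cpu")}

# Stricter match: exact name or a "." right after the prefix, so layer 11 is no
# longer swallowed by the "model.layers.1" entry.
strict = {device: [w for name, d in device_map.items() if d == device
                   for w in weight_names if w == name or w.startswith(name + ".")]
          for device in (0, "cpu")}

print(loose)   # the layer-11 weight shows up under both 0 and "cpu"
print(strict)  # each weight shows up exactly once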

Disclaimer: I have no idea how the accelerate library works; I just tried to debug the code and stumbled upon this weird logic, so I may be digging in the wrong direction.

Also, as I've shown in #47 (comment), even if you put only model.layers.0 on the GPU (so the bug above is not triggered), the model is still fully loaded into VRAM. So there may be an additional problem somewhere.


Ph0rk0z commented on May 18, 2024

For accelerate I always have to lower the memory budget I actually feed into it to get usable results. I will tell it to use 16GB and it will load 18GB.
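
In other words, budget below the card's physical capacity. A purely illustrative example (values are not from this thread):

# Leave a few GB of headroom below the physical VRAM so activations, the CUDA
# context, and any over-allocation still fit.
max_memory = {0: "20GiB", "cpu": "64GiB"}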


PanQiWei commented on May 18, 2024

Hi! PR #100 fixed the bug where the quantized model couldn't be saved after loading the pretrained model with CPU offload.


PanQiWei commented on May 18, 2024

Closing this issue, as the problem mentioned here has been fixed.

