
Comments (8)

cdpierse commented on May 17, 2024

Hi @roma-glushko,

Thanks for using the package and for such a detailed overview of the issue you are running into. I think I have a fix for it.

It's not well documented yet (which I must fix), but in one of the more recent releases I added a new feature to control the internal batch size that is used when calculating the attributions. By default, the gradients for the example text at every one of the n_steps used to approximate the attributions are calculated all at once. This option is available for every current explainer and will be for future explainers too.

However, if you pass internal_batch_size into the explainer object's callable method, it should give you some fine-grained control over this. So for your example, it might look something like:

# find a batch size that works for you; a smaller batch will result in slower calculation times
self.explainer(agent_notes, internal_batch_size=2)[1:-1] 

I actually ran into this same issue myself with a demo app on Streamlit that had very limited RAM. I used the internal batch size trick there too and it worked really well.

You are totally correct, by the way, as to why this is happening. Gradients are calculated for n_steps, which by default is 50. So if the internal batch size is not specified, 50 sets of gradients will be stored in RAM all at once, which for larger models like RoBERTa starts to become a problem. Using the internal batch size lets us do this piece by piece in a more manageable way.
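For anyone landing on this later, here is a minimal end-to-end sketch of the idea. The checkpoint name and input text are just placeholders, and I'm assuming the sequence classification explainer here:

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

# Placeholder checkpoint; any sequence classification model works the same way.
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

explainer = SequenceClassificationExplainer(model, tokenizer)

# internal_batch_size caps how many of the n_steps (default 50) are pushed
# through the model at once, trading calculation speed for lower peak RAM.
word_attributions = explainer(
    "This is a placeholder sentence to attribute.",
    internal_batch_size=2,
)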

I hope this helps, do let me know how you get on.


roma-glushko commented on May 17, 2024

Hi @cdpierse 👋 do you know what could be the workaround here?


roma-glushko commented on May 17, 2024

Hi Charles (@cdpierse), thank you for the swift response 👍

The internal_batch_size does help a bit, but I believe we are talking about two different issues.
The internal_batch_size argument seems to constrain the peak RAM usage during the gradient calculations. This is great for environments with limited resources. Unfortunately, however, it doesn't fix the RAM leak: the memory keeps leaking, just at a slower pace. It was sufficient for the Streamlit app because Streamlit reloads the whole script each time any UI control changes (AFAIK). That seemed to clean up the leaked memory naturally, so the application stayed resilient.

In the case of serving an application via conventional web frameworks (like FastAPI, Flask, etc.), the whole application is not restarted over and over again, so the leaked memory is not cleared up as it was in the Streamlit case. As a result, the application is going to run out of memory no matter how much memory we are willing to give it.

Could it be that we are missing a detach() call somewhere in the process of the attribution calculation, so the full gradient graph accumulates like a snowball over time?
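For reference, this is roughly how I observe the growth in a long-running process. The explainer and agent_notes are the ones from my example, and psutil is only used to read the resident set size:

import gc
import os

import psutil

process = psutil.Process(os.getpid())

for i in range(500):
    # Same call as in the request handler; internal_batch_size only slows the growth down.
    explainer(agent_notes, internal_batch_size=2)
    gc.collect()
    if i % 50 == 0:
        rss_mib = process.memory_info().rss / 1024 ** 2
        print(f"iteration {i}: RSS = {rss_mib:.1f} MiB")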


roma-glushko commented on May 17, 2024

Alright, I think it also makes sense to create an issue in the Captum repository: pytorch/captum#866


cdpierse commented on May 17, 2024

Hi @roma-glushko,

Thanks for looking into this further. I think you're right about the cause: I ran a long-running process locally and noticed a gradual memory increase regardless of batch size. I also tried experimenting with zeroing the gradients on the model, but it doesn't seem to help. Most likely it's as you say: the n_steps gradients accumulated throughout the process are to blame.
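For clarity, this is the kind of thing I tried between explainer calls (names are from the snippets above); it did not stop the growth for me:

import gc

word_attributions = explainer(agent_notes, internal_batch_size=2)

# Attempted mitigation: drop any gradients held on the model's parameters
# and nudge the garbage collector; memory still crept up in my tests.
model.zero_grad()
gc.collect()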

I'll keep trying to debug a bit on my side to see if there is a way for me to force the accumulation to be dropped. It might be on Captum's side too, so I'll take a look there as well and see if I can come up with a PR.

Will keep you posted!


roma-glushko commented on May 17, 2024

Hi @cdpierse, thank you for continuing to investigate this issue 🙏

"experimenting with zeroing the gradients on the model but it doesn't seem to work"

Yeah, I can confirm it did not work for me either.

Please let me know if I can help somehow with this issue. I would love to, but I have already run out of ideas on how to track it down.


cdpierse commented on May 17, 2024

Hi @roma-glushko, I was just thinking about this issue again and was wondering whether, in your tests, you had tried running a loop of inferences on the model itself without the explainer, to isolate where the leak is. IIRC PyTorch models should be stateless across forward passes, so the leak most likely lies with the explainer.
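As a rough sketch of what I mean (model, tokenizer and sample_text are placeholders):

import torch

# Inference-only loop, no explainer involved, to check whether the
# model's forward pass alone leaks memory over time.
with torch.no_grad():
    for _ in range(500):
        inputs = tokenizer(sample_text, return_tensors="pt")
        _ = model(**inputs)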


roma-glushko commented on May 17, 2024

Hi @cdpierse,

"you had tried doing a loop of inferences on the model itself without the explainer, trying to isolate where the leak is"

Yes, I tried commenting out the explainer part, and the PyTorch inference itself seemed stable in terms of RAM usage. So yeah, I can confirm that the results of my tests point to the explainer code. However, it was not clear whether the issue is in the explainer itself or in the Captum library.

