Comments (8)
Hi @roma-glushko,
Thanks for using the package and for such a detailed overview of the issue you are running into. I think I have a fix for it.
It's not well documented yet (which I must fix), but in one of the more recent releases I added a new feature to control the internal batch size used when calculating the attributions. By default, the example text and all n_steps used to approximate the attributions are calculated at once. This is available for every explainer, current and future.
However, if you pass internal_batch_size into the explainer object's callable method, it should give you fine-grained control over this. For your example, it might look something like:
```python
# find a batch size that works for you; a smaller batch means slower calculation
self.explainer(agent_notes, internal_batch_size=2)[1:-1]
```
I actually ran into this same issue myself with a demo app on Streamlit that had very limited RAM; I used the internal batch size trick there too and it worked really well.
You are totally correct btw as to why this is happening. Gradients are calculated for n_steps, which by default is 50. So if the internal batch size is not specified, 50 sets of gradients will be stored in RAM all at once, which for larger models like RoBERTa starts to become a problem. Using the internal batch size lets us do this piece by piece in a more manageable way.
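The batching behaviour described above can be sketched as follows. This is an illustrative model of the idea, not Captum's actual implementation: with n_steps=50, gradients are needed at 50 interpolated inputs, and internal_batch_size controls how many of them are held in memory at once.

```python
# Illustrative sketch (not Captum's real code): split the n_steps gradient
# computations into chunks so peak memory stays bounded.
def chunked_steps(n_steps, internal_batch_size=None):
    size = internal_batch_size or n_steps   # default: everything at once
    steps = list(range(n_steps))
    return [steps[i:i + size] for i in range(0, n_steps, size)]

print(len(chunked_steps(50)))                         # 1 chunk of 50 gradient sets
print(len(chunked_steps(50, internal_batch_size=2)))  # 25 chunks of 2
```

Smaller chunks trade speed for a lower memory ceiling, which is why the batch size needs tuning per environment.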
I hope this helps, do let me know how you get on.
from transformers-interpret.
Hi @cdpierse 👋 do you know what the workaround could be here?
Hi @cdpierse Charles, thank you for the swift response 👍
The internal_batch_size does help a bit, but I believe we are talking about two different issues.
The internal_batch_size argument seems to constrain the peak RAM usage during the gradient calculations. This is great for environments with limited resources. Unfortunately, however, it doesn't fix the RAM leak: the memory keeps leaking, just at a slower pace. It was sufficient for the Streamlit app because Streamlit reloads the whole script each time any UI control changes (AFAIK). That seemed to clean up the leaked memory naturally, and the application stayed resilient.
When serving an application via conventional web frameworks (like FastAPI, Flask, etc.), the whole application is not restarted over and over again, so the leaked memory is not cleared up as it was in the Streamlit case. As a result, the application will run out of memory no matter how much we are ready to give it.
Could it be that we are missing a detach() call somewhere in the attribution calculation, so that the full gradient graph accumulates like a snowball over time?
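To make the suspicion concrete, here is a hypothetical illustration (not the actual transformers-interpret or Captum code) of how a missing detach() leaks: a tensor stored without detach() keeps its whole autograd graph alive, so repeated attribution calls accumulate graphs in RAM, while detach() severs that reference so each graph can be freed.

```python
import torch

# Hypothetical sketch of the suspected leak pattern.
x = torch.randn(4, requires_grad=True)

attached, detached = [], []
for _ in range(3):
    out = (x * 2).sum()            # each call builds a fresh autograd graph
    attached.append(out)           # graph retained -> memory grows per call
    detached.append(out.detach())  # value kept, graph dropped

print(all(t.grad_fn is not None for t in attached))  # True: graphs still alive
print(all(t.grad_fn is None for t in detached))      # True: graphs released
```

If some intermediate attribution tensor is stored like the `attached` list here, each call would pin its entire backward graph in memory.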
Alright, I think it also makes sense to create an issue in the Captum repository: pytorch/captum#866
Hi @roma-glushko,
Thanks for looking into this further. I think you're right about the issue: I ran a long-running process locally and noticed a gradual memory increase regardless of batch size. I also tried experimenting with zeroing the gradients on the model, but it doesn't seem to work. Most likely it's as you say: the n_steps gradients accumulated throughout the process are to blame.
I'll keep trying to debug on my side to see if there is a way to force the accumulation to be dropped. It might be on Captum's side too, so I'll take a look there as well and see if I can come up with a PR.
Will keep you posted!
Hi @cdpierse, thank you for continuing to investigate this issue 🙏
> experimenting with zeroing the gradients on the model but it doesn't seem to work
Yeah, I can confirm it did not work for me either.
Please let me know if I can help somehow with this issue. I would love to, but I have already run out of ideas for how to track this issue down.
Hi @roma-glushko, I was just thinking about this issue again and was wondering if, in your tests, you had tried running a loop of inferences on the model itself without the explainer, to isolate where the leak is. IIRC PyTorch models should be stateless with each new forward pass, so the leak most likely lies with the explainer.
Hi @cdpierse,
> you had tried doing a loop of inferences on the model itself without the explainer, trying to isolate where the leak is
Yes, I tried commenting out the explainer part, and the PyTorch inference itself seemed stable in terms of RAM usage. So yes, I can confirm that the results of my tests pointed to the explainer code. However, it was not clear whether the issue is in the explainer itself or in the Captum lib.
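The shape of that isolation test can be sketched with stand-in workloads (the "inference" and "explainer" steps below are placeholders, not the real model or explainer): run the same loop for each step and compare peak memory growth.

```python
import gc
import tracemalloc

# Stand-in workloads mimicking the observed behaviour: plain inference frees
# everything it allocates, while the leaking step accumulates state.
history = []

def inference_step():
    _ = [0.0] * 10_000              # allocated and freed each iteration

def explainer_step():
    history.append([0.0] * 10_000)  # retained across iterations -> leak

def peak_memory(step, iterations=20):
    gc.collect()
    tracemalloc.start()
    for _ in range(iterations):
        step()
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak

# The leaking step shows a much larger memory peak than plain inference.
print(peak_memory(explainer_step) > peak_memory(inference_step))  # True
```

Running the real model and explainer through a loop like this is a cheap way to attribute the growth to one component before digging into its internals.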