Comments (4)
It seems to be caused by inference mode. I added `torch.is_inference_mode_enabled()` inside `with torch.enable_grad():` and found that it was `False` in `validation_step` but `True` in `test_step`.
```python
import torch

def compute_jacobian():  # hypothetical wrapper; in the original, this body lives inside a step method
    with torch.enable_grad():
        N = 5
        f = lambda x: x ** 2
        x = torch.randn(N, requires_grad=True)
        # print(x.requires_grad)
        y = f(x)
        I_N = torch.eye(N)
        print(torch.is_inference_mode_enabled())  # False in train and validation, but True in test.
        # Sequential approach: one vector-Jacobian product per row of the identity matrix
        jacobian_rows = [
            torch.autograd.grad(y, x, v, retain_graph=True, allow_unused=True)[0]
            for v in I_N.unbind()
        ]
        jacobian = torch.stack(jacobian_rows)
    return jacobian
```
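As a sanity check outside Lightning, the row-by-row loop above can be compared against `torch.autograd.functional.jacobian`, which computes the same matrix in a single call. For the element-wise `f(x) = x ** 2`, the Jacobian should be the diagonal matrix `diag(2 * x)`. This is a minimal sketch with a fixed seed; `sequential_jacobian` is a hypothetical helper name, not part of the original code.

```python
import torch
from torch.autograd.functional import jacobian

def sequential_jacobian(f, x):
    """Hypothetical helper: build the Jacobian one row at a time via vector-Jacobian products."""
    with torch.enable_grad():
        x = x.detach().requires_grad_(True)  # fresh leaf so the caller's tensor is untouched
        y = f(x)
        rows = [
            torch.autograd.grad(y, x, v, retain_graph=True)[0]  # row i = e_i^T J
            for v in torch.eye(y.numel()).unbind()
        ]
    return torch.stack(rows)

torch.manual_seed(0)
f = lambda x: x ** 2
x = torch.randn(5)

J_seq = sequential_jacobian(f, x)
J_ref = jacobian(f, x)  # reference result from torch.autograd.functional

print(torch.allclose(J_seq, J_ref))              # True
print(torch.allclose(J_seq, torch.diag(2 * x)))  # True
```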
from pytorch-lightning.
```python
import torch

with torch.enable_grad():
    x = torch.randn(3, 10).requires_grad_(True)
    y = x + 1
    print(f"{torch.is_inference_mode_enabled()=}")
    print(f"{x.requires_grad=}")
    print(f"{y.requires_grad=}")
```

Output:

```
torch.is_inference_mode_enabled()=False
x.requires_grad=True
y.requires_grad=True
```

```python
with torch.inference_mode():
    with torch.enable_grad():
        x = torch.randn(3, 1).requires_grad_(True)
        y = x + 1
        print(f"{torch.is_inference_mode_enabled()=}")
        print(f"{x.requires_grad=}")
        print(f"{y.requires_grad=}")
```

Output:

```
torch.is_inference_mode_enabled()=True
x.requires_grad=True
y.requires_grad=False
```
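The last output is the root of the problem: inside `torch.inference_mode()`, `torch.enable_grad()` does not bring back autograd tracking, so `y` has no `grad_fn` and `torch.autograd.grad` has no graph to differentiate. A minimal sketch reproducing this outside Lightning; since the exact error message may vary between PyTorch versions, it only checks that a `RuntimeError` is raised. The function name is illustrative.

```python
import torch

def try_grad_under_inference_mode():
    """Return (y.requires_grad, raised) for enable_grad() nested inside inference_mode()."""
    with torch.inference_mode():       # stands in for the default test loop
        with torch.enable_grad():      # does NOT override inference mode
            x = torch.randn(3).requires_grad_(True)
            y = (x ** 2).sum()
            try:
                torch.autograd.grad(y, x)
                raised = False
            except RuntimeError:       # y was never recorded in an autograd graph
                raised = True
    return y.requires_grad, raised

print(try_grad_under_inference_mode())  # (False, True)
```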
I want gradients to be tracked for all subsequent operations in `test_step` as well. How do I do that?
I think I have found a solution: use `with torch.inference_mode(mode=False):` instead of `with torch.enable_grad():`.
```python
import torch

with torch.inference_mode():
    x = torch.randn(3, 10).requires_grad_(True)
    y = x + 1
    print(f"{torch.is_inference_mode_enabled()=}")
    print(f"{x.requires_grad=}")
    print(f"{y.requires_grad=}")

    with torch.inference_mode(mode=False):
        # previously: with torch.enable_grad():
        x = torch.randn(3, 10).requires_grad_(True)
        y = x + 1
        print(f"{torch.is_inference_mode_enabled()=}")
        print(f"{x.requires_grad=}")
        print(f"{y.requires_grad=}")
```

Output:

```
torch.is_inference_mode_enabled()=True
x.requires_grad=True
y.requires_grad=False
torch.is_inference_mode_enabled()=False
x.requires_grad=True
y.requires_grad=True
```
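Applied to the original Jacobian computation, the fix could look like the sketch below. It assumes the outer `torch.inference_mode()` stands in for what `test_step` runs under by default, and `jacobian_in_test_step` is a hypothetical name, not a Lightning hook. Tensors created under inference mode (here `batch`) are inference tensors, which PyTorch refuses to save for backward, so the sketch clones the input inside the `inference_mode(mode=False)` block to get a normal tensor before differentiating.

```python
import torch

def jacobian_in_test_step(batch):
    """Hypothetical helper: the original Jacobian loop, fixed with inference_mode(mode=False)."""
    with torch.inference_mode(mode=False):  # actually leaves inference mode and re-enables grad
        # batch may be an inference tensor; clone it to get a normal tensor autograd can save.
        x = batch.clone().requires_grad_(True)
        y = x ** 2
        rows = [
            torch.autograd.grad(y, x, v, retain_graph=True)[0]
            for v in torch.eye(x.numel()).unbind()
        ]
        return torch.stack(rows)

with torch.inference_mode():  # stands in for Lightning's default test loop
    batch = torch.randn(5)    # created under inference mode, so it is an inference tensor
    J = jacobian_in_test_step(batch)

print(torch.allclose(J, torch.diag(2 * batch)))  # True
```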
Inference mode is the default for validation/testing in the Trainer:
https://lightning.ai/docs/pytorch/stable/common/trainer.html#inference-mode
This means you can't take gradients in the `validation_step`/`test_step` methods by default. You can turn it off by setting `Trainer(inference_mode=False)` and then wrapping the gradient computation in `torch.enable_grad()`, or by using `torch.inference_mode(mode=False)` as you've done above.
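Why the `Trainer(inference_mode=False)` route works: with that flag, Lightning runs evaluation under `torch.no_grad()` instead of `torch.inference_mode()`, and unlike inference mode, `no_grad()` can be locally overridden by `torch.enable_grad()`. A minimal plain-PyTorch sketch of that difference (no Lightning required; the helper name is illustrative):

```python
import torch

def grad_tracked_under(outer_ctx):
    """Return y.requires_grad when enable_grad() is nested inside the given outer context."""
    with outer_ctx:
        with torch.enable_grad():
            x = torch.randn(3).requires_grad_(True)
            y = x + 1
    return y.requires_grad

print(grad_tracked_under(torch.no_grad()))         # True: enable_grad() overrides no_grad()
print(grad_tracked_under(torch.inference_mode()))  # False: inference mode is not overridden
```

In a `LightningModule`, the first case corresponds to constructing the Trainer with `inference_mode=False` and keeping a `with torch.enable_grad():` block inside `test_step`.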