Comments (6)
Hello Team! I am Kaeun, one of the new contributors at Thunder, and I am having a lot of fun with these tasks. I am wondering if it would be OK for me to handle this issue. It seems like I have to update torch/__init__.py with something like:

@torchsymbol(torch.nn.functional.mse_loss, id="torch.nn.functional.mse_loss", is_method=False)
def mse_loss(a: TensorLike, b: TensorLike, /) -> TensorLike:

Would it be OK if I take care of this? Appreciate your help and support!
Best,
Kaeun
from lightning-thunder.
Absolutely! Anyone is welcome to submit a PR to address this issue. Yes, this PR would start by updating torch/__init__.py. A few additional notes:
- Make sure to capture the additional arguments to the function, too
- The decomposition of this function may be a little tricky
- The torch executor can be updated to run the operation directly when it's called, without having to execute the decomposition
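To make the notes above concrete, here is a pure-Python sketch (not Thunder code; the function name and the list-based "tensors" are illustrative assumptions) of what mse_loss decomposes into, including PyTorch's reduction semantics:

```python
# Pure-Python sketch of an mse_loss decomposition: a subtraction, a square,
# and an optional reduction. Each step is an op that already has a grad
# formula, which matters later for the implicit backward pass.
def mse_loss_ref(a, b, reduction="mean"):
    losses = [(x - y) ** 2 for x, y in zip(a, b)]  # elementwise squared error
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    return losses  # reduction == "none"
```

The reduction argument is one of the "additional arguments" the first bullet refers to, and the branching on it is part of what can make the decomposition a little tricky.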
Hi @mruberry! I am almost done with the implementation of mse_loss. However, I have some questions and would appreciate your help! (this will help me a lot :) )
Happy to help!
- I am using the following script to quickly confirm the forward and backward pass of mse_loss:

import torch
import thunder

reduction = "none"

def mse(input, target):
    output = torch.nn.functional.mse_loss(input, target, reduction=reduction)
    return output

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)

cfn = thunder.jit(mse)
actual_loss = cfn(input, target)

grad_jfn = thunder.core.transforms.grad(cfn)
actual_grad, = grad_jfn(input, target)

expected_loss = torch.nn.functional.mse_loss(input, target, reduction=reduction)
go = torch.ones_like(expected_loss)
expected_grad, = torch.autograd.grad(torch.nn.functional.mse_loss(input, target, reduction=reduction), input, go)

print("Max error in loss:", (actual_loss - expected_loss).abs().max().item())
print("Max error in logits grad:", (actual_grad - expected_grad).abs().max().item())

I am not sure if this is the proper way to calculate the gradient. Would you be able to check this for me? Also, just to check, does the above script call mse_loss with the decomposition (the one in torch/__init__.py, not in torchex.py)?
You should be able to do the following, which might be simpler:
actual_loss = cfn(inp, target)
actual_loss.sum().backward()
thunder_grad = inp.grad
inp.grad = None
expected_loss = fn(inp, target)
expected_loss.sum().backward()
pytorch_grad = inp.grad
assert_close(thunder_grad, pytorch_grad)
@IvanYashchuk can correct me if I'm mistaken about this. Take a look at the assert_close utility for comparing tensors.
Your question about the decomposition is great. I'm not sure, because it depends on the details of the mse_loss implementation and any updates to the torch executor. You can see how the program is actually being executed by printing the execution trace of the program, which will show whether torch.nn.functional.mse_loss is being called directly or a decomposition is being called instead.
Note for @jjsjann123 @t-vi @IvanYashchuk, we should probably reintroduce a developer option for the jit to force executors like torch to execute only primitives. This would be straightforward to do.
One way to test the decomposition locally is to not register a direct binding of mse_loss with the torch executor and check that the execution trace decomposes as expected. Then you can add the direct binding for the loss and verify it's working as expected, too. In the future this should be easier (once the above developer option is available).
- Regarding the backward pass that will be added in torch/__init__.py, I am trying to mimic what has been done with _cross_entropy_grad. However, I am having a hard time understanding the main purpose of _cross_entropy_grad and cross_entropy_backward. cross_entropy_backward simply returns a TensorProxy of the gradient that has been computed through get_grad(fwd). Why would we want to separate the two?
When taking a look at the implementation of log_softmax_backward, it explicitly calculates the derivative of log_softmax and does not do any get_grad or put_grad. Why is there a difference between cross_entropy's backward pass and log_softmax's backward pass?
This is a great question! First, if you have a decomposition for mse_loss, and a grad formula is defined for every element of that decomposition, then you will also define an (implicit) grad formula for mse_loss, and shouldn't have to add a custom grad formula at all.
Second, the current state of grad formulas in thunder can be confusing! @IvanYashchuk can help direct you here. It would be helpful if you submitted a draft PR, so we can look at the code in more detail before making a recommendation.
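The point that per-element grad formulas compose into an implicit grad formula for the whole op can be checked numerically. Here is a plain-Python sketch (no Thunder or Torch; the list-based inputs are an illustrative assumption) comparing the analytic gradient of a mean-reduced MSE, d/da_i mean((a - b)^2) = 2 (a_i - b_i) / N, against central finite differences:

```python
# Analytic gradient of mean((a - b)^2) vs. a finite-difference check.
def mse_mean(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def mse_grad(a, b):
    n = len(a)
    return [2.0 * (x - y) / n for x, y in zip(a, b)]

def numerical_grad(f, a, b, eps=1e-6):
    # Central differences: perturb one input element at a time.
    grads = []
    for i in range(len(a)):
        ap, am = a[:], a[:]
        ap[i] += eps
        am[i] -= eps
        grads.append((f(ap, b) - f(am, b)) / (2 * eps))
    return grads

a = [0.5, -1.0, 2.0]
b = [0.0, 1.0, 1.5]
analytic = mse_grad(a, b)
numeric = numerical_grad(mse_mean, a, b)
assert all(abs(x - y) < 1e-5 for x, y in zip(analytic, numeric))
```

If each primitive in the decomposition (subtract, square, mean) carries its own grad rule, chaining them produces exactly this 2 (a - b) / N gradient without a hand-written mse_loss backward.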
For the mse_loss, I am assuming I should be doing something like the cross_entropy's implementation. Would that be a proper approach? (having mse_loss_backward and _mse_loss_grad)
Let's take a look at a draft PR, maybe even one where grad support isn't considered to start, and then we can discuss!
- Referring to the third bullet point in your comment, I am also currently working on implementing mse_loss in torchex.py. How would I call mse_loss without having to execute the decomposition (as you have mentioned)? I am wondering if there is a way to check if I have implemented the non-decomposed mse_loss in torchex.py properly.
The easiest way to check is to inspect the execution trace and verify that torch.nn.functional.mse_loss is called directly. An example of a "direct" lowering to torch is the dropout operation. See here:
lightning-thunder/thunder/executors/torchex.py
Line 1208 in 649c3d7
and here:
lightning-thunder/thunder/executors/torchex.py
Line 1548 in 649c3d7
The torch executor has some helper functions that make this straightforward. The _register_torch_operation function tells thunder how to call torch operations (like dropout), and the _register_implementation function tells thunder that it can call PyTorch's dropout to execute thunder.torch.dropout.
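The registration idea can be illustrated with a toy registry. This is a hypothetical sketch only; Thunder's real _register_torch_operation and _register_implementation have different signatures and also handle checkers and symbol mapping:

```python
# Toy executor registry (illustrative, not Thunder's actual code): execution
# uses a direct binding when one is registered, else falls back to the
# decomposition.
_direct_impls = {}

def register_implementation(symbol_id, fn):
    _direct_impls[symbol_id] = fn

def execute(symbol_id, decomposition, *args):
    impl = _direct_impls.get(symbol_id)
    if impl is not None:
        return impl(*args)       # "direct" lowering: call the registered op
    return decomposition(*args)  # fall back to decomposed ops

def mse_decomposition(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# With no direct binding registered, the decomposition runs:
assert execute("mse_loss", mse_decomposition, [1.0, 2.0], [0.0, 0.0]) == 2.5
```

This mirrors what the execution trace shows: when the direct binding is absent, you see the decomposed ops; once registered, the direct call takes over.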
I hope you understand that I am in the process of learning how things work in thunder and would appreciate your help :)
These are great questions, and it's great you're asking them. I hope these responses are helpful. Let us know if you have additional questions!
Hi @mruberry!
Thanks so much for the detailed explanation! 😄 It definitely helped me to further understand how trace execution can help my testing during the implementation and how gradients are calculated in general. I have submitted a draft PR that at least passes tests/test_ops.py.
Awesome! I'm glad that was helpful.
Also, using the script that you have provided above, it seems to calculate the gradient and forward pass correctly. (+ I have performed traces = thunder.last_traces(cfn) and confirmed that it is calling the decomposition of mse_loss.) It'll be awesome if you can take a look so we can further discuss the implementation.
I look forward to reviewing the PR in more detail.
Currently, I am facing issues with tests/test_grad.py. Specifically, there is some discrepancy when running test_vjp_correctness_mse_loss_torch_cpu_float64, which I suspect is due to some implementation in torchex.py (strange how the current implementation's grad output is different from torch.ops.aten.mse_loss_backward's output).
Interesting! Let's discuss more with @IvanYashchuk on the PR itself.
I do have another question regarding your comment:
One way to test the decomposition locally is to not register a direct binding of mse_loss with the torch executor and check that the execution trace decomposes as expected.
Does this mean that I should not have something like:

mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional)

in torchex.py? And when I just run:

actual_loss = cfn(input, target)
actual_loss.sum().backward()
thunder_grad = input.grad
traces = thunder.last_traces(cfn)

I should be able to see the decomposition of mse_loss?
Yes, I think that's correct. Without the direct binding of mse_loss registered in the torch executor, it should have to execute it by decomposing it into other operations.
Then you can add the direct binding for the loss and verify it's working as expected, too. In the future this should be easier (once the above developer option is available).
This sounds like, once I add _register_torch_operation for mse_loss, I should be able to see torch.nn.functional.mse_loss in traces = thunder.last_traces(cfn). However, I am only seeing the decomposed version of mse_loss. Did I misunderstand something? Please let me know!
If you both register the operation with _register_torch_operation and then bind it with _register_implementation, you should see the execution trace call torch.nn.functional.mse_loss directly. The decomposition will still appear below it, but it will be commented out.
I am learning a lot from your comments and having a lot of fun! Thank you so much :)
You're very welcome; I'm glad you're having fun.