Comments (6)

k223kim avatar k223kim commented on May 13, 2024 2

Hello Team! I am Kaeun, one of the new contributors at Thunder who is having a lot of fun with these tasks. I am wondering if it would be ok for me to handle this issue. It seems like I have to update torch/__init__.py with something like:

@torchsymbol(torch.nn.functional.mse_loss, id="torch.nn.functional.mse_loss", is_method=False)
def mse_loss(a: TensorLike, b: TensorLike, /) -> TensorLike:

Would it be ok if I take care of this? Appreciate your help and support!
Best,
Kaeun

from lightning-thunder.

mruberry avatar mruberry commented on May 13, 2024 1

Hello Team! I am Kaeun, one of the new contributors at Thunder who is having a lot of fun with these tasks. I am wondering if it would be ok for me to handle this issue. It seems like I have to update torch/__init__.py with something like:

@torchsymbol(torch.nn.functional.mse_loss, id="torch.nn.functional.mse_loss", is_method=False)
def mse_loss(a: TensorLike, b: TensorLike, /) -> TensorLike:

Would it be ok if I take care of this? Appreciate your help and support! Best, Kaeun

Absolutely! Anyone is welcome to submit a PR to address this issue. Yes, this PR would start by updating torch/__init__.py. A few additional notes:

  • Make sure to capture the additional arguments to the function, too
  • The decomposition of this function may be a little tricky
  • The torch executor can be updated to run the operation directly when it's called, without having to execute the decomposition
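As a sketch of what that decomposition might compute (written in plain PyTorch rather than thunder's internal ops, so it can be sanity-checked directly against the reference), the reduction argument is what makes it slightly tricky:

```python
import torch

def mse_loss_decomposition(a, b, reduction="mean"):
    # Elementwise squared difference, then apply the reduction.
    # (A real thunder decomposition would use thunder's ops instead.)
    diff = a - b
    loss = diff * diff
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unknown reduction: {reduction}")

# Sanity check against PyTorch's reference implementation.
a = torch.randn(3, 5)
b = torch.randn(3, 5)
for reduction in ("none", "mean", "sum"):
    torch.testing.assert_close(
        mse_loss_decomposition(a, b, reduction),
        torch.nn.functional.mse_loss(a, b, reduction=reduction),
    )
```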


mruberry avatar mruberry commented on May 13, 2024 1

Hi @mruberry! I am almost done with the implementation regarding mse_loss. However, I have some questions and would appreciate your help! (this will help me a lot :) )

Happy to help!

  1. I am using the following script to quickly confirm the forward and backward pass of mse_loss.
import torch
import thunder

reduction = "none"

def mse(input, target):
    output = torch.nn.functional.mse_loss(input, target, reduction=reduction)
    return output

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)

cfn = thunder.jit(mse)
actual_loss = cfn(input, target)

grad_jfn = thunder.core.transforms.grad(cfn)
actual_grad, = grad_jfn(input, target)

expected_loss = torch.nn.functional.mse_loss(input, target, reduction=reduction)
go = torch.ones_like(expected_loss)
expected_grad, = torch.autograd.grad(torch.nn.functional.mse_loss(input, target, reduction=reduction), input, go)

print("Max error in loss:", (actual_loss - expected_loss).abs().max().item())
print("Max error in logits grad:", (actual_grad - expected_grad).abs().max().item())

I am not sure if this is the proper way to calculate the gradient. Would you be able to check this for me? Also, just to check, does the above script call mse_loss with decomposition (the one in torch/__init__.py not in torchex.py)?

You should be able to do the following, which might be simpler:

actual_loss = cfn(inp, target)
actual_loss.sum().backward()
thunder_grad = inp.grad
inp.grad = None

expected_loss = fn(inp, target)
expected_loss.sum().backward()
pytorch_grad = inp.grad

assert_close(thunder_grad, pytorch_grad)

@IvanYashchuk can correct me if I'm mistaken about this. Take a look at the assert_close utility for comparing tensors.
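For reference, a self-contained version of this pattern might look like the following; thunder.jit is swapped here for a hand-written decomposition (a stand-in, so the sketch runs with PyTorch alone):

```python
import torch
from torch.testing import assert_close

def mse_decomposed(inp, target):
    # Stand-in for the thunder-compiled function.
    diff = inp - target
    return diff * diff  # reduction="none"

inp = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)

# Backward through the candidate implementation.
mse_decomposed(inp, target).sum().backward()
candidate_grad = inp.grad
inp.grad = None

# Backward through the PyTorch reference, using the same pattern.
torch.nn.functional.mse_loss(inp, target, reduction="none").sum().backward()
assert_close(candidate_grad, inp.grad)
```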

Your question about the decomposition is great. I'm not sure, because it depends on the details of the mse_loss implementation and any updates to the torch executor. You can see how the program is actually being executed by printing its execution trace, which will show whether torch.nn.functional.mse_loss is being called directly or a decomposition is being called instead.

Note for @jjsjann123 @t-vi @IvanYashchuk, we should probably reintroduce a developer option for the jit to force executors like torch to execute only primitives. This would be straightforward to do.

One way to test the decomposition locally is to not register a direct binding of mse_loss with the torch executor and check that the execution trace decomposes as expected. Then you can add the direct binding for the loss and verify it's working as expected, too. In the future this should be easier (once the above developer option is available).

  2. Regarding the backward pass that will be added in torch/__init__.py, I am trying to mimic what has been done with _cross_entropy_grad. However, I am having a hard time understanding the main purpose of _cross_entropy_grad and cross_entropy_backward. cross_entropy_backward simply returns a TensorProxy of the gradient that has been computed through get_grad(fwd). Why would we want to separate the two?
    When looking at the implementation of log_softmax_backward, it explicitly calculates the derivative of log_softmax and does not do any get_grad or put_grad. Why is there a difference between cross_entropy's backward pass and log_softmax's backward pass?

This is a great question! First, if you have a decomposition for mse_loss, and a grad formula is defined for every element of that decomposition, then you will also define an (implicit) grad formula for mse_loss, and shouldn't have to add a custom grad formula at all.
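A quick PyTorch illustration of this point: autograd composes the gradient of a decomposition from the per-op grad rules, and the composed result matches the analytic gradient of mean((a - b)^2) with no custom formula for the loss itself:

```python
import torch
from torch.testing import assert_close

a = torch.randn(3, 5, requires_grad=True)
b = torch.randn(3, 5)

# The grad of the decomposition (sub, mul, mean) is composed automatically
# from the grad rule of each element; no custom mse grad formula is needed.
loss = ((a - b) * (a - b)).mean()
loss.backward()

# Analytic gradient of mean((a - b)**2) with respect to a: 2*(a - b)/N.
assert_close(a.grad, 2 * (a - b).detach() / a.numel())
```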

Second, the current state of grad formulas in thunder can be confusing! @IvanYashchuk can help direct you here. It would be helpful if you submitted a draft PR, so we can look at the code in more detail before making a recommendation.

For mse_loss, I am assuming I should do something like cross_entropy's implementation (having mse_loss_backward and _mse_loss_grad). Would that be a proper approach?

Let's take a look at a draft PR, maybe even one where grad support isn't considered to start, and then we can discuss!

  3. Referring to the third bullet point in your comment, I am also currently working on implementing mse_loss in torchex.py. How would I call mse_loss without executing the decomposition (as you mentioned)? I am wondering if there is a way to check whether I have implemented the non-decomposed mse_loss in torchex.py properly.

The easiest way to check is to inspect the execution trace and verify that torch.nn.functional.mse_loss is called directly. An example of a "direct" lowering to torch is the dropout operation. See here:

dropout = _register_torch_operation("dropout", module=torch.nn.functional)

and here:

_register_implementation(ltorch.dropout, dropout, checker=_always_executable)

The torch executor has some helper functions that make this straightforward. The _register_torch_operation function tells thunder how to call torch operations (like dropout), and the _register_implementation function tells thunder that it can call PyTorch's dropout to execute thunder.torch.dropout.
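The registration pattern itself can be sketched as a toy model in plain Python (this is not thunder's actual API, and the real helpers have different signatures): an executor keeps a table from symbol ids to concrete callables and falls back to the decomposition when nothing direct is registered:

```python
import math

_implementations = {}  # toy table: symbol id -> (callable, checker)

def _register_torch_operation(name, module):
    # Look up the concrete callable the executor will invoke.
    return getattr(module, name)

def _register_implementation(symbol_id, op, checker):
    _implementations[symbol_id] = (op, checker)

def _always_executable(*args):
    return True

def execute(symbol_id, decomposition, *args):
    entry = _implementations.get(symbol_id)
    if entry is not None:
        op, checker = entry
        if checker(*args):
            return op(*args)       # direct lowering
    return decomposition(*args)    # fall back to the decomposition

# Register math.sqrt as the "direct" implementation of a toy sqrt symbol.
sqrt = _register_torch_operation("sqrt", module=math)
_register_implementation("toy.sqrt", sqrt, checker=_always_executable)

print(execute("toy.sqrt", lambda x: x ** 0.5, 9.0))  # direct path -> 3.0
```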

I hope you understand that I am in the process of learning how things work in thunder and would appreciate your help :)

These are great questions, and it's great you're asking them. I hope these responses are helpful. Let us know if you have additional questions!


mruberry avatar mruberry commented on May 13, 2024 1

Hi @mruberry!

Thanks so much for the detailed explanation!😄 It definitely helped me to further understand how trace execution can help my testing during the implementation and how gradients are calculated in general. I have submitted a draft PR that at least passes tests/test_ops.py.

Awesome! I'm glad that was helpful.

Also, using the script that you provided above, it seems to calculate the forward pass and the gradient correctly. (Plus, I have run traces = thunder.last_traces(cfn) and confirmed that it is calling the decomposition of mse_loss.) It would be awesome if you could take a look so we can discuss the implementation further.

I look forward to reviewing the PR in more detail.

Currently, I am facing issues with tests/test_grad.py. Specifically, there is some discrepancy when running test_vjp_correctness_mse_loss_torch_cpu_float64, which I suspect is due to some implementation detail in torchex.py (strangely, the current implementation's grad output is different from torch.ops.aten.mse_loss_backward's output).

Interesting! Let's discuss more with @IvanYashchuk on the PR itself.

I do have another question regarding your comment:

One way to test the decomposition locally is to not register a direct binding of mse_loss with the torch executor and check that the execution trace decomposes as expected.

Does this mean that I should not have something like:

mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional)

in torchex.py? And that, after running

actual_loss = cfn(input, target)
actual_loss.sum().backward()
thunder_grad = input.grad
traces = thunder.last_traces(cfn)

I should be able to see the decomposition of mse_loss?

Yes, I think that's correct. Without the direct binding of mse_loss registered with the torch executor, thunder has to execute it by decomposing it into other operations.

Then you can add the direct binding for the loss and verify it's working as expected, too. In the future this should be easier (once the above developer option is available).

This sounds like, once I add _register_torch_operation for mse_loss, I should be able to see torch.nn.functional.mse_loss in traces = thunder.last_traces(cfn). However, I am only seeing the decomposed version of mse_loss. Did I misunderstand something? Please let me know!

If you both register the operation with _register_torch_operation and then bind it with _register_implementation you should see the execution trace call torch.nn.functional.mse_loss directly. The decomposition will still appear below it, but it will be commented out.

I am learning a lot from your comments and having a lot of fun! Thank you so much :)

You're very welcome; I'm glad you're having fun.


