kiya00 commented on July 17, 2024

Run this script:

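import os

# Disable TF32 and make cuBLAS deterministic; set these before any CUDA work happens.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import random

import torch
import torchvision

torch.manual_seed(42)
random.seed(42)
torch.use_deterministic_algorithms(True)

# Randomly initialized ResNet-18 and a random input, both float32 on the GPU.
model = torchvision.models.resnet18(weights=None).to(device="cuda", dtype=torch.float32)
x = torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda", requires_grad=True)
# Compare the finite-difference (numerical) Jacobian with the autograd (analytical) one.
print(torch.autograd.gradcheck(model, (x,)))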

It raises a GradcheckError:

root@9340b8cf8485:/wayan/lightning-thunder# python thunder/tests/testtrace.py
/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py:920: UserWarning: Input #0 requires gradient and is not a double precision floating point or complex. This check will likely fail if all the inputs are not of double precision floating point or complex.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:768: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:135.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/wayan/lightning-thunder/thunder/tests/testtrace.py", line 15, in <module>
    print(torch.autograd.gradcheck(model, (x,)))
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py", line 2053, in gradcheck
    return _gradcheck_helper(**args)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py", line 2082, in _gradcheck_helper
    _gradcheck_real_imag(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py", line 1492, in _gradcheck_real_imag
    gradcheck_fn(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py", line 1633, in _slow_gradcheck
    raise GradcheckError(
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[ 0.1043,  0.0522,  0.0298,  ..., -0.0149, -0.0447, -0.0596],
        [ 0.1043, -0.0447,  0.0298,  ...,  0.1341,  0.0298, -0.0894],
        [ 0.1192, -0.1043,  0.0000,  ..., -0.0596, -0.0149,  0.0596],
        ...,
        [ 0.1788, -0.0820, -0.0149,  ...,  0.0224,  0.1341, -0.0596],
        [ 0.0000, -0.2459,  0.1639,  ...,  0.0894, -0.1267, -0.0596],
        [ 0.1043, -0.0075, -0.1043,  ...,  0.0894, -0.0969,  0.0000]],
       device='cuda:0')
analytical:tensor([[ 2.5345e-04, -2.1945e-04,  7.5599e-05,  ..., -1.5271e-04,
         -1.6242e-04,  4.9330e-04],
        [-2.0753e-04,  5.1979e-04,  8.1766e-05,  ..., -2.5569e-04,
         -2.4477e-04,  1.9414e-04],
        [-1.2130e-04,  1.2330e-04, -2.3220e-04,  ...,  2.7823e-04,
          2.9276e-04, -1.9633e-04],
        ...,
        [-7.2192e-05, -9.3861e-05, -4.2660e-05,  ..., -7.6299e-05,
         -6.6284e-05,  1.2527e-05],
        [ 4.0978e-05,  2.3847e-05,  2.6876e-05,  ..., -2.3141e-05,
         -1.0444e-06,  1.5903e-05],
        [-4.5947e-05, -1.3556e-05, -8.9267e-05,  ...,  6.1379e-05,
          2.5143e-05,  2.6964e-05]], device='cuda:0')

With float64, the check passes.

from lightning-thunder.

t-vi commented on July 17, 2024

https://pytorch.org/docs/stable/generated/torch.autograd.gradcheck.gradcheck.html says:

Note
The default values are designed for input of double precision. This check will likely fail if input is of less precision, e.g., FloatTensor.

However, the values above seem very far off, so I'm wondering whether the operators we call have a bug or an input assumption that isn't satisfied.
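The warning in the note can be illustrated without PyTorch. The sketch below is a hedged, pure-Python illustration, not gradcheck's actual code: the function `f`, the point `x = 0.5` and the helper names are assumptions chosen to resemble the scales above (outputs of a few tenths, true derivatives around 1e-4). It takes a central difference with gradcheck's default `eps=1e-6`, once in double precision and once with the perturbed input and the outputs rounded to float32 via `struct`. In float32 the estimate can only change in steps of one output ulp divided by `2*eps` (about 0.0149 for outputs between 0.25 and 0.5), so a derivative of 1e-4 comes out as exactly 0. The numerical entries in the error above are all multiples of roughly 0.0075, which is consistent with this kind of rounding noise in the numerical Jacobian; that is an inference, not a confirmed diagnosis.

```python
import struct


def to_f32(v: float) -> float:
    """Round a Python float (IEEE binary64) to the nearest IEEE binary32 value."""
    return struct.unpack("f", struct.pack("f", v))[0]


def no_round(v: float) -> float:
    """Keep full double precision (stand-in for a float64 run)."""
    return v


def central_diff(f, x: float, eps: float, rnd) -> float:
    """Central difference (f(x + eps) - f(x - eps)) / (2 * eps).

    `rnd` models the storage dtype: it is applied to the perturbed inputs and
    to the outputs (a stand-in for computing the output in that dtype).
    """
    x_plus = rnd(x + eps)
    x_minus = rnd(x - eps)
    return (rnd(f(x_plus)) - rnd(f(x_minus))) / (2 * eps)


# Illustrative function (an assumption, not the ResNet): output ~0.3,
# true derivative 1e-4, similar in scale to the "analytical" entries above.
def f(x: float) -> float:
    return 0.3 + 1e-4 * x


eps = 1e-6  # gradcheck's default perturbation
x = 0.5

d64 = central_diff(f, x, eps, no_round)
d32 = central_diff(f, x, eps, to_f32)
# Spacing of float32 values in [0.25, 0.5) is 2**-25; divide by the step 2*eps.
quantum = 2.0 ** -25 / (2 * eps)

print(f"float64 estimate: {d64:.6e}")
print(f"float32 estimate: {d32:.6e}")
print(f"float32 quantum : {quantum:.4f}")
```

Raising `eps` or running in float64 shrinks this quantization error relative to the true derivative, which is why the note recommends double-precision inputs.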

kiya00 commented on July 17, 2024

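Comparison of the fp64 and fp32 parameter gradients (the mismatching tensor has 9408 = 64*3*7*7 elements, i.e. the gradient of conv1.weight):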

Torch eager fp64 vs thunder torchex fp32:
Mismatched elements: 9248 / 9408 (98.3%)
Greatest absolute difference: 0.00013152781110292722 at index (33, 2, 5, 6) (up to 1e-07 allowed)
Greatest relative difference: 0.04311691189015545 at index (25, 2, 4, 4) (up to 1e-07 allowed)

Torch eager fp64 vs torch eager fp32:
Mismatched elements: 9242 / 9408 (98.2%)
Greatest absolute difference: 9.72576008653192e-05 at index (44, 2, 4, 1) (up to 1e-07 allowed)
Greatest relative difference: 0.049338367753292256 at index (25, 2, 4, 4) (up to 1e-07 allowed)

Torch eager fp32 vs thunder torchex fp32:
Mismatched elements: 5468 / 9408 (58.1%)
Greatest absolute difference: 0.00011849403381347656 at index (2, 2, 5, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.027340635657310486 at index (41, 2, 6, 6) (up to 1.3e-06 allowed)
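The "Mismatched elements" and "Greatest ... difference" lines are `torch.testing.assert_close` output. Its documented rule treats an element as close when |actual - expected| <= atol + rtol * |expected|, with default tolerances rtol=1.3e-6, atol=1e-5 for float32 and rtol=1e-7, atol=1e-7 for float64, which is where the "(up to ... allowed)" figures come from. The hedged pure-Python sketch below reproduces the three reported quantities on flat lists; `compare_like_assert_close` is a hypothetical helper, not a PyTorch API, and it ignores NaN handling and skips zero reference values when computing the relative difference.

```python
def compare_like_assert_close(actual, expected, rtol, atol):
    """Return (mismatched, greatest_abs_diff, greatest_rel_diff) for two flat lists,
    using the closeness rule |actual - expected| <= atol + rtol * |expected|."""
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    mismatched = 0
    greatest_abs = 0.0
    greatest_rel = 0.0
    for a, e in zip(actual, expected):
        diff = abs(a - e)
        if diff > atol + rtol * abs(e):
            mismatched += 1
        greatest_abs = max(greatest_abs, diff)
        if e != 0:  # this sketch skips zero reference values for the relative difference
            greatest_rel = max(greatest_rel, diff / abs(e))
    return mismatched, greatest_abs, greatest_rel


# Made-up values, compared with the float64 defaults (rtol=1e-7, atol=1e-7).
actual = [1.0, 2.0000001, 3.1]
expected = [1.0, 2.0, 3.0]
mismatched, greatest_abs, greatest_rel = compare_like_assert_close(
    actual, expected, rtol=1e-7, atol=1e-7
)
print(f"Mismatched elements: {mismatched} / {len(actual)}")
print(f"Greatest absolute difference: {greatest_abs}")
print(f"Greatest relative difference: {greatest_rel}")
```

Because the greatest absolute and greatest relative differences are tracked independently, they can point at different indices, as in the results above.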

Script used for the comparison:
import os

# Disable TF32 before any CUDA work happens.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import random

import torch
import torchvision
import thunder

torch.manual_seed(42)
random.seed(42)
# torch.use_deterministic_algorithms(True)

# Three ResNet-18 copies: thunder fp32, eager fp32, and eager fp64 as the reference.
model_fp32 = torchvision.models.resnet18(weights=None).to(device="cuda", dtype=torch.float32)
torch_model_fp32 = torchvision.models.resnet18(weights=None).to(device="cuda", dtype=torch.float32)
torch_model_fp64 = torchvision.models.resnet18(weights=None).to(device="cuda", dtype=torch.float64)

model_fp32 = model_fp32.train()
torch_model_fp32 = torch_model_fp32.train()
torch_model_fp64 = torch_model_fp64.train()

# Give all three models the same weights (load_state_dict casts fp64 -> fp32).
model_fp32.load_state_dict(torch_model_fp64.state_dict())
torch_model_fp32.load_state_dict(torch_model_fp64.state_dict())

x = torch.randn((1, 3, 224, 224), dtype=torch.float64, device="cuda")
jitted_fp32 = thunder.jit(model_fp32, executors=[thunder.pytorch_executor])

# Forward passes: thunder fp32, eager fp32, eager fp64.
out1 = jitted_fp32(x.to(torch.float32))
out2 = torch_model_fp32(x.to(torch.float32))
out3 = torch_model_fp64(x)
torch.testing.assert_close(out1, out2)

# Parameter gradients for a backward pass with an all-ones upstream gradient.
thunder_grads_fp32 = torch.autograd.grad(out1, jitted_fp32.parameters(), torch.ones_like(out1))
torch_grads_fp32 = torch.autograd.grad(out2, torch_model_fp32.parameters(), torch.ones_like(out2))
torch_grads_fp64 = torch.autograd.grad(out3, torch_model_fp64.parameters(), torch.ones_like(out3))
# Upcast the fp32 gradients so they can be compared against the fp64 reference.
tmp1 = [g.to(torch.float64) for g in thunder_grads_fp32]
tmp2 = [g.to(torch.float64) for g in torch_grads_fp32]

torch.testing.assert_close(thunder_grads_fp32, torch_grads_fp32)
# print("thunder:")
# torch.testing.assert_close(tmp1, torch_grads_fp64)
# print("torch:")
# torch.testing.assert_close(tmp2, torch_grads_fp64)

tfogal commented on July 17, 2024

triage review:

  • seems to be something in the backward pass
  • some discrepancy w.r.t. tf32 in forward and backward?
  • but it also happens in fp64, so it seems likely that it's not tf32-specific, or even type-specific
  • maybe try replacing modules with identity to narrow down which module is failing
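The last triage suggestion could look roughly like the sketch below. It is hedged and hypothetical, not thunder or torchvision code: `gradcheck_without`, `Broken` and `_WrongBackward` are made-up names, the toy model stands in for ResNet-18, and everything runs in float64 on CPU so gradcheck's default tolerances apply. Only modules whose input and output shapes match (ReLU, BatchNorm, shape-preserving residual blocks) can be swapped for `nn.Identity`; a channel-changing convolution cannot.

```python
import copy

import torch
import torch.nn as nn
from torch.autograd import gradcheck


class _WrongBackward(torch.autograd.Function):
    """Hypothetical op whose backward is deliberately wrong (3 instead of 2)."""

    @staticmethod
    def forward(ctx, inp):
        return inp * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 3


class Broken(nn.Module):
    """Hypothetical stand-in for a module with a buggy backward."""

    def forward(self, inp):
        return _WrongBackward.apply(inp)


def gradcheck_without(model: nn.Module, name: str, x: torch.Tensor) -> bool:
    """Gradcheck a copy of `model` whose submodule `name` (a dotted path such as
    "layer1.0.bn1") is replaced by nn.Identity; returns True if the check passes."""
    patched = copy.deepcopy(model)
    parent_name, _, child_name = name.rpartition(".")
    parent = patched.get_submodule(parent_name) if parent_name else patched
    setattr(parent, child_name, nn.Identity())
    return gradcheck(patched, (x,), raise_exception=False)


torch.manual_seed(0)
# Toy float64 model on CPU; gradcheck's default tolerances assume double precision.
model = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), Broken(), nn.Linear(4, 2)).double()
x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

print("full model passes:", gradcheck(model, (x,), raise_exception=False))
results = {name: gradcheck_without(model, name, x) for name, _ in model.named_children()}
print(results)
```

Only removing the module with the buggy backward makes the check pass; on a real network the same loop could be run over `named_modules()` and restricted to shape-preserving entries.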
