Comments (5)

justinchuby commented on May 25, 2024

Could you paste the difference too?

from onnx.

dannyfriar commented on May 25, 2024

@justinchuby Sure, I've added that now.


BowenBao commented on May 25, 2024

The issue is more suited to be filed under onnxruntime repo.

It looks like an additional NaN check is needed at this line: https://github.com/microsoft/onnxruntime/blob/33578cc76efc19b50c9fc011215b2777de193cd1/onnxruntime/core/providers/cuda/math/clip_impl.cu#L14

cc @yuslepukhin to check if my understanding is correct.
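Below is a pure-Python sketch of the scalar behaviour such a check would aim for. It assumes the intended semantics are that NaN passes through unchanged, which is what the CPU execution provider produces in the outputs further down this thread. `clip_scalar` is a hypothetical reference, not the onnxruntime kernel, and it makes no claim about what the linked line actually contains.

```python
import math


def clip_scalar(value: float, lo: float, hi: float) -> float:
    """Hypothetical NaN-propagating scalar clip (illustration only).

    NaN is tested first, so a NaN input is returned as-is instead of
    being replaced by either bound.
    """
    if math.isnan(value):
        return value
    if value < lo:
        return lo
    if value > hi:
        return hi
    return value


print(clip_scalar(float("nan"), -10.0, 10.0))  # nan
print(clip_scalar(42.0, -10.0, 10.0))          # 10.0
print(clip_scalar(-3.5, -10.0, 10.0))          # -3.5
```

Testing for NaN first keeps the result independent of how the bound comparisons are written. Helpers in the style of C's `fminf`/`fmaxf`, for example, return the non-NaN operand and would silently replace NaN with a bound.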


adamreeve commented on May 25, 2024

@BowenBao I think you're correct that this is an onnxruntime issue rather than an onnx one, but the problem appears to be in the CUDA implementations of the Min and Max operators rather than in Clip.
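When the clip bounds are tensors, the exported graph computes `Min(Max(x, lower), upper)` element-wise. The NumPy sketch below models the two possible NaN behaviours with two function pairs. `np.maximum`/`np.minimum` propagate NaN, while `np.fmax`/`np.fmin` ignore it and return the other operand. The helper name `clip_via_max_min` is hypothetical. Treating the two branches as stand-ins for the CPU and CUDA providers is an assumption based only on the outputs in this thread.

```python
import numpy as np


def clip_via_max_min(x, lower, upper, nan_propagating=True):
    """Element-wise clip written as Max followed by Min, like the exported graph.

    nan_propagating=True  -> np.maximum / np.minimum (NaN operand gives NaN)
    nan_propagating=False -> np.fmax / np.fmin (NaN operand is ignored)
    """
    vmax, vmin = (np.maximum, np.minimum) if nan_propagating else (np.fmax, np.fmin)
    return vmin(vmax(x, lower), upper)


x = np.array([[np.nan, np.nan, np.nan], [-20.0, 0.5, 20.0]], dtype=np.float32)
lower = np.full(3, -10.0, dtype=np.float32)
upper = np.full(3, 10.0, dtype=np.float32)

# NaN row stays NaN; second row becomes [-10.0, 0.5, 10.0]
print(clip_via_max_min(x, lower, upper, nan_propagating=True))
# NaN row becomes -10.0: fmax(NaN, -10) = -10, then fmin(-10, 10) = -10
print(clip_via_max_min(x, lower, upper, nan_propagating=False))
```

With the NaN-ignoring pair, a NaN input ends up at the lower bound, because the Max step runs first and discards the NaN before Min ever sees it.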

When the clip bounds are arrays, torch exports the clip to ONNX as a Max followed by a Min. I can reproduce the problem with a simpler example that builds the model directly with the onnx helpers instead of exporting it from torch, using only the Min operator:

import torch  # Not used but initializing the CUDA execution provider fails without it...

import numpy as np
import onnx
from onnx.onnx_pb import TensorProto
import onnxruntime

input = onnx.helper.make_tensor_value_info("input", TensorProto.FLOAT, ["N", 10])
output = onnx.helper.make_tensor_value_info("output", TensorProto.FLOAT, ["N", 10])

min_const = onnx.helper.make_node(
        "Constant",
        inputs=[],
        outputs=["min_const"],
        value=onnx.numpy_helper.from_array(np.array([10.0] * 10, dtype=np.float32)))

min_node = onnx.helper.make_node(
        "Min",
        inputs=["input", "min_const"],
        outputs=["output"],
        )

graph_def = onnx.helper.make_graph(
        nodes=[min_const, min_node],
        name="test-model",
        inputs=[input],
        outputs=[output])

opset_import = onnx.helper.make_opsetid("", 17)

model_def = onnx.helper.make_model(
        graph_def,
        opset_imports=[opset_import],
        producer_name="test")

onnx.checker.check_model(model_def, full_check=True)

model_path = 'test_min.onnx'
onnx.save(model_def, model_path)

input = np.random.randn(3, 10).astype(np.float32)
input[0, :] = np.nan

cpu_session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
output = cpu_session.run(["output"], {"input": input})
print("CPU session output:")
print(output)

gpu_session = onnxruntime.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
output = gpu_session.run(["output"], {"input": input})
print("GPU session output:")
print(output)

This outputs something like:

CPU session output:
[array([[        nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan],
       [-0.73628867, -1.0645038 , -0.29687342, -0.06496124,  0.40141365,
        -0.36313328, -0.17520589,  0.08746424,  0.30066383, -1.3963577 ],
       [ 0.8791592 ,  0.08518761, -1.1299503 ,  0.12336332, -0.02993149,
         0.1656782 , -1.5760034 ,  0.14083968, -0.37705085,  2.0208693 ]],
      dtype=float32)]
GPU session output:
[array([[10.        , 10.        , 10.        , 10.        , 10.        ,
        10.        , 10.        , 10.        , 10.        , 10.        ],
       [-0.73628867, -1.0645038 , -0.29687342, -0.06496124,  0.40141365,
        -0.36313328, -0.17520589,  0.08746424,  0.30066383, -1.3963577 ],
       [ 0.8791592 ,  0.08518761, -1.1299503 ,  0.12336332, -0.02993149,
         0.1656782 , -1.5760034 ,  0.14083968, -0.37705085,  2.0208693 ]],
      dtype=float32)]

An equivalent example that uses the Max operator instead also shows the same problem.
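To compare provider outputs without eyeballing them, you can check whether each output keeps NaN exactly where the input had it. `check_nan_propagation` is a hypothetical helper. It is shown here on small stand-in arrays, not on real onnxruntime output.

```python
import numpy as np


def check_nan_propagation(inp: np.ndarray, out: np.ndarray) -> bool:
    """Return True if `out` is NaN exactly where `inp` is NaN.

    For Min or Max against a finite constant, a NaN-propagating implementation
    should produce NaN at every input NaN position and nowhere else.
    """
    return bool(np.array_equal(np.isnan(inp), np.isnan(out)))


inp = np.array([[np.nan, np.nan], [0.3, -1.2]], dtype=np.float32)
cpu_like = np.minimum(inp, np.float32(10.0))  # keeps NaN, like the CPU output above
gpu_like = np.fmin(inp, np.float32(10.0))     # NaN replaced by 10, like the GPU output above

print(check_nan_propagation(inp, cpu_like))  # True
print(check_nan_propagation(inp, gpu_like))  # False
```

With the reproduction script, this would be called as `check_nan_propagation(input, output[0])`, since `InferenceSession.run` returns a list with one array per requested output. The same check applies unchanged to the Max variant.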

If I change the original reproduction script to use scalar bounds, the issue no longer reproduces. In that case the exporter emits a Clip operator instead, and the CUDA implementation of Clip does seem to handle NaN values correctly:

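import torch
from torch import nn


class ONNXModel(nn.Module):
    def forward(self, x):
        # Scalar bounds: the exporter emits a single Clip node
        lower = torch.scalar_tensor(-10.0)
        upper = torch.scalar_tensor(10.0)
        x = x.clip(lower, upper)
        # Replace any NaN that survived the clip with 0
        x[torch.isnan(x)] = 0.0
        return x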
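This forward pass clips first and only afterwards replaces NaN with 0, so its result depends on whether NaN survives the clip. The NumPy sketch below shows that dependence. `forward_reference` is a hypothetical stand-in. Its `nan_propagating=False` branch (`np.fmax`/`np.fmin`) is an assumed model of the array-bounds Max/Min path on the GPU, not of the scalar Clip path.

```python
import numpy as np


def forward_reference(x, lower, upper, nan_propagating=True):
    """NumPy stand-in for the forward pass: clip, then replace NaN with 0.

    The NaN mask is computed *after* clipping, as in x[torch.isnan(x)] = 0.0,
    so it only catches NaNs that survived the clip step.
    """
    vmax, vmin = (np.maximum, np.minimum) if nan_propagating else (np.fmax, np.fmin)
    y = vmin(vmax(x, lower), upper)  # new array, safe to modify in place
    y[np.isnan(y)] = 0.0
    return y


x = np.array([np.nan, -20.0, 3.0, 20.0], dtype=np.float32)
print(forward_reference(x, -10.0, 10.0, nan_propagating=True))   # [0. -10. 3. 10.]
print(forward_reference(x, -10.0, 10.0, nan_propagating=False))  # [-10. -10. 3. 10.]
```

Under the NaN-ignoring branch, the NaN input comes out as -10 rather than 0, because the zeroing step never sees it.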

The original model exported to ONNX looks like this in Netron:
[Netron graph: clip_with_array_bounds]

And with scalar bounds it looks like this:
[Netron graph: clip_with_scalar_bounds]


tolleybot commented on May 25, 2024

As a workaround, and to handle NaN values explicitly, I compute a NaN mask before applying the clip and use it afterwards to set those positions to 0. Because the mask is taken from the input rather than from the clipped result, the output no longer depends on how the backend's Min and Max operators treat NaN. The changes are all in the PyTorch model's forward method:


import torch
from torch import nn


class ONNXModel(nn.Module):
    def forward(self, x):
        # Create a mask for NaN values in the input tensor, before any clipping
        nan_mask = torch.isnan(x)

        # Per-column lower and upper bounds, on the same device and dtype as x
        lower = torch.tensor([-10.0] * x.shape[1], dtype=x.dtype, device=x.device)
        upper = torch.tensor([10.0] * x.shape[1], dtype=x.dtype, device=x.device)

        # Clip to the bounds (tensor bounds export to ONNX as Max followed by Min)
        x = x.clip(lower, upper)

        # Set positions that were NaN in the input to 0, whatever the clip did with them
        x = torch.where(nan_mask, torch.tensor(0, dtype=x.dtype, device=x.device), x)

        return x
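The NumPy sketch below checks the key property of this workaround: because the mask comes from the input, the result is identical whether the Max/Min step propagates or ignores NaN. `workaround_reference` is a hypothetical stand-in. The `np.fmax`/`np.fmin` branch is an assumed model of the GPU behaviour reported earlier in the thread.

```python
import numpy as np


def workaround_reference(x, lower, upper, nan_propagating=True):
    """NumPy stand-in for the masked workaround: record NaNs, clip, then fill with 0.

    The mask is taken from the input before clipping, so the final result
    does not depend on how the Max/Min step treats NaN.
    """
    nan_mask = np.isnan(x)
    vmax, vmin = (np.maximum, np.minimum) if nan_propagating else (np.fmax, np.fmin)
    clipped = vmin(vmax(x, lower), upper)
    return np.where(nan_mask, np.zeros_like(clipped), clipped)


x = np.array([[np.nan, -20.0, 3.0], [20.0, np.nan, 0.5]], dtype=np.float32)
lower = np.full(3, -10.0, dtype=np.float32)
upper = np.full(3, 10.0, dtype=np.float32)

a = workaround_reference(x, lower, upper, nan_propagating=True)
b = workaround_reference(x, lower, upper, nan_propagating=False)
print(np.array_equal(a, b))  # True
print(a)                     # [[0. -10. 3.] [10. 0. 0.5]]
```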
