Comments (6)

ntcmp2u commented on September 24, 2024

Thank you so much for your quick response.

This is the bug-report directory:

bug_example.zip

I use Render like this:

from nnsmith.materialize import Render, BugReport, Model
from nnsmith.backends import BackendFactory

# Load the saved bug report from ./bug_example/ for the PyTorch frontend
model_init = Model.init("torch", "cpu")
bug_report = BugReport.load(model_init, "./bug_example/", allow_partial=True)

# Emit the model, its inputs and the pt2 backend, then render a standalone script
render = Render()
render.emit_model(bug_report.testcase.model)
render.emit_input(bug_report.testcase.model)
render.emit_backend(BackendFactory.init("pt2"))
output = render.render()
with open("./output.py", "w") as f:
    f.write(output)

The resulting output.py looks like this:

import numpy as np
import torch
import pickle

# Model definition
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.v5_0 = torch.nn.Parameter(torch.empty([1], dtype=torch.int16), requires_grad=False)
    def forward(self, *args):
        _args = args
        getitem = _args[0];  _args = None
        _tensor_constant0 = self._tensor_constant0
        mul = torch.mul(_tensor_constant0, getitem);  _tensor_constant0 = None
        expand = mul.expand(1)
        expand_1 = mul.expand(1, 1, 1, 1);  mul = None
        max_1 = torch.max(expand_1, getitem);  expand_1 = getitem = None
        return (expand, max_1)

m = M()


# Initialize weight
# None

# Initialize input
inp = [np.zeros([], dtype='int16')]

# Compile the model
opt = torch.compile(m, fullgraph=True, backend='inductor', mode=None)

# Eager run
m_out = m(*[torch.from_numpy(v).to('cpu') for v in inp])
m_out = [v.cpu().detach() for v in m_out] # torch2numpy
m_out = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in m_out] # torch2numpy

# Compiled run
opt_out = opt(*[torch.from_numpy(v).to('cpu') for v in inp])
opt_out = [v.cpu().detach() for v in opt_out] # torch2numpy
opt_out = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in opt_out] # torch2numpy

# Differential testing
for i, (l, r) in enumerate(zip(m_out, opt_out)):
    np.testing.assert_allclose(l, r, rtol=1e-2, atol=1e-3, err_msg=f"Result mismatch @ index {i}")

But running it raises an exception: AttributeError: 'M' object has no attribute '_tensor_constant0'. The generated forward reads self._tensor_constant0, while __init__ only defines self.v5_0.
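A minimal sketch of why this particular AttributeError appears, independent of nnsmith: torch.nn.Module.__getattr__ only falls back to the registered _parameters, _buffers and _modules, so an attribute that forward reads but __init__ never assigns cannot be resolved. The class below is a hypothetical, hand-reduced copy of the rendered M, not actual nnsmith output.

```python
import torch


class M(torch.nn.Module):
    """Hypothetical reduction of the rendered model (not actual nnsmith output)."""

    def __init__(self):
        super().__init__()
        # Registered parameter: lives in self._parameters, so self.v5_0 resolves.
        self.v5_0 = torch.nn.Parameter(torch.zeros([1], dtype=torch.int16), requires_grad=False)

    def forward(self, x):
        # Never assigned in __init__: nn.Module.__getattr__ finds it in neither
        # _parameters, _buffers nor _modules and raises AttributeError.
        c = self._tensor_constant0
        return torch.mul(c, x)


m = M()
registered = "v5_0" in dict(m.named_parameters())
try:
    m(torch.zeros([], dtype=torch.int16))
    err_msg = None
except AttributeError as e:
    err_msg = str(e)

print(registered)  # True
print(err_msg)     # 'M' object has no attribute '_tensor_constant0'
```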

ganler commented on September 24, 2024

Hi, I implemented the renderer in #107 a few months ago. It has not been rigorously tested, but I have not encountered any major issues so far.

You are welcome to check out the examples in the unit tests: https://github.com/ise-uiuc/nnsmith/blob/main/tests/torch/test_render.py

Otherwise, please file a concrete bug report so that I can help you diagnose it. Thanks.

ganler commented on September 24, 2024

Could you let me know your PyTorch version? Thanks.

ntcmp2u commented on September 24, 2024

The version is 2.0.1+cu117.

ganler commented on September 24, 2024

Sorry for the late reply; I am looking into the bug now. I have partially fixed your issue in #122, but the most critical problem here, the undefined attribute _tensor_constant0, is introduced by PyTorch's model-to-code translation when it refers to a parameter through a user-provided map.

I will either find another way to implement the symbolic tracing or report the bug to PyTorch so it can be fixed there.
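The mechanism described above can be reproduced with torch.fx alone. This is a sketch under the assumption that the model-to-code step behaves like torch.fx.symbolic_trace; the class ViaMap and its weights dict are hypothetical stand-ins for a parameter held in a user-provided map. Because the tensor is neither a registered parameter nor a buffer, the tracer stores it on the traced module under a generated name, and the emitted code reads self._tensor_constant0. A class rebuilt from that code plus only the original weights therefore lacks the attribute.

```python
import torch
import torch.fx


class ViaMap(torch.nn.Module):
    """Hypothetical module that keeps its weight in a plain dict (not registered)."""

    def __init__(self):
        super().__init__()
        # A dict is an ordinary attribute: the tensor inside is neither a
        # parameter nor a buffer, so the tracer cannot find it by name.
        self.weights = {"v5_0": torch.ones([1])}

    def forward(self, x):
        return torch.mul(self.weights["v5_0"], x)


gm = torch.fx.symbolic_trace(ViaMap())
# The generated forward fetches the weight as self._tensor_constant0,
# an attribute the tracer attached to the traced module, not to ViaMap's weights.
print(gm.code)
```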

ganler commented on September 24, 2024

OK, I found a workaround: referencing parameters as object attributes makes your example work. Feel free to retry after the PR is merged.
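For comparison, a sketch of why referencing the weight as a module attribute avoids the problem, again assuming fx-style tracing and using a hypothetical class ViaAttr. This illustrates the idea, not the exact change made in the PR: when the weight is registered on the module, the tracer resolves it by its qualified name, so the generated code reads self.v5_0, which a re-rendered class that defines v5_0 can satisfy.

```python
import torch
import torch.fx


class ViaAttr(torch.nn.Module):
    """Hypothetical module that registers its weight as an attribute."""

    def __init__(self):
        super().__init__()
        # Registered parameter: the tracer can refer to it by its name, v5_0.
        self.v5_0 = torch.nn.Parameter(torch.ones([1]), requires_grad=False)

    def forward(self, x):
        return torch.mul(self.v5_0, x)


gm = torch.fx.symbolic_trace(ViaAttr())
# The generated forward reads self.v5_0; no _tensor_constant* attribute is created.
print(gm.code)
```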
