

Shay-Yo commented on August 27, 2024

I believe I have solved this issue; here is the fix for anyone experiencing the same problem.
The main problem is with the array shapes, so you need to make several small changes to get it working.

First, in train.py, change the y_dim variable from 1 to the number of classes you have. Note that if the task is set to regression, the code forces y_dim to 1, so you need to change that line as well.
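As a sketch of how y_dim could be derived from the labels rather than typed in by hand, here is a small helper. The name infer_y_dim is hypothetical and is not part of train.py; it assumes y is either a 1-D array of class labels or regression targets, or a 2-D array with one column per class (one-hot or multi-label) or per regression target.

```python
import numpy as np


def infer_y_dim(task: str, y: np.ndarray) -> int:
    """Hypothetical helper: choose the output width of the model head from the labels."""
    y = np.asarray(y)
    if y.ndim == 2:
        # One column per class or per regression target: the head needs that many outputs.
        return y.shape[1]
    if task == "regression":
        # A single, 1-D regression target.
        return 1
    # 1-D class labels: one output per distinct class.
    return len(np.unique(y))


print(infer_y_dim("regression", np.zeros((4, 3))))          # 3
print(infer_y_dim("multiclass", np.array([0, 1, 2, 1])))    # 3
```

Deriving the value from the data means the same script keeps working if you later switch tasks or change the number of classes.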

Second, in the data_split function in data_openml.py, change the line:

'data': y[indices].reshape(-1, 1)

to

'data': y[indices].reshape(-1, y_dim)

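where y_dim is the number of classes you have. This assumes each row of y already holds y_dim values (for example, one-hot or multi-label targets); if y is a flat array of labels, reshaping it to (-1, y_dim) regroups unrelated rows instead.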
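A small NumPy illustration of why the reshape matters, using made-up data: 6 data points with one-hot labels over 3 classes, and an assumed `indices` array selecting one split. With `reshape(-1, 1)` every label entry becomes its own row, so the label array no longer lines up with the data points.

```python
import numpy as np

# Made-up example: 6 data points, one-hot labels over 3 classes -> y has shape (6, 3).
y_dim = 3
y = np.eye(y_dim, dtype=int)[[0, 1, 2, 1, 0, 2]]
indices = np.array([0, 2, 4])  # rows belonging to one split

# Original line: flattens every label entry into its own row.
old = y[indices].reshape(-1, 1)
# Modified line: keeps one row per data point with y_dim columns.
new = y[indices].reshape(-1, y_dim)

print(old.shape)  # (9, 1): 9 rows for only 3 data points
print(new.shape)  # (3, 3)
```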

Third, in the __init__ method of the DataSetCatCon class in data_openml.py, change the lines:

self.cls = np.zeros_like(self.y,dtype=int)
self.cls_mask = np.ones_like(self.y,dtype=int)

to

self.cls = np.zeros(shape=(len(self.y), 1), dtype=int)
self.cls_mask = np.ones(shape=(len(self.y), 1), dtype=int)

This is needed because we changed the shape of y. The code adds a CLS token column to the features, so these two lines must create a single column with one entry per data point. Left as they are, np.zeros_like(self.y) and np.ones_like(self.y) produce a matrix of shape (number of data points, number of classes) instead of a single column.
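The shape difference can be reproduced with plain NumPy on made-up data (4 data points, 3 categorical features, 3 label columns). The per-row concatenation below is an assumption modelled on the description above of prepending a CLS token to each row's features; it is not copied from DataSetCatCon.

```python
import numpy as np

# Made-up example: 4 data points, 3 categorical features, one-hot labels over 3 classes.
X_categ = np.array([[1, 0, 2], [2, 1, 0], [0, 1, 1], [1, 1, 2]])
y = np.eye(3, dtype=int)[[0, 2, 1, 0]]  # shape (4, 3)

# Original lines: the shape follows y, giving a (4, 3) matrix.
cls_old = np.zeros_like(y, dtype=int)
# Fixed lines: a single (4, 1) column, one CLS entry per data point.
cls = np.zeros(shape=(len(y), 1), dtype=int)
cls_mask = np.ones(shape=(len(y), 1), dtype=int)

# Prepend the CLS value to one row's categorical features.
row_old = np.concatenate((cls_old[0], X_categ[0]))  # 3 CLS tokens + 3 features
row_new = np.concatenate((cls[0], X_categ[0]))      # 1 CLS token + 3 features

print(row_old.shape)  # (6,)
print(row_new.shape)  # (4,)
```

With the original lines each row would receive one CLS token per class, which no longer matches the single CLS position the model expects.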

Note that hard-coding these numbers may cause problems later if you train the model on other tasks or with a different number of classes, so you may need to revert these changes in that case.

from saint.
