somepago / saint
The official PyTorch implementation of recent paper - SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training
License: Apache License 2.0
Is there any progress on the release of the plotting code?
Thanks in advance!
Hi Fabien, I will release the attention plotting code in the next version. I am busy with another project right now, so I am targeting the end of November for this. If you need it urgently, let me know.
Originally posted by @somepago in #8 (comment)
It may help to add a small constant such as 1e-8 or 1e-9 at this line (link) to make your code more robust.
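The linked line is not reproduced here, so the following is only a sketch of the general idea, assuming the line divides by a quantity that can be zero (for example, the standard deviation of a constant column). The function name `standardize` and the default `eps` are illustrative and are not taken from the repository.

```python
import math

def standardize(values, eps=1e-8):
    """Scale values to zero mean and unit variance; eps guards the division."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    std = math.sqrt(var)
    # Without eps, a constant column (std == 0) raises ZeroDivisionError in
    # pure Python and yields nan/inf with NumPy arrays or PyTorch tensors.
    return [(v - mean) / (std + eps) for v in values]

constant_column = [3.0, 3.0, 3.0]
print(standardize(constant_column))  # [0.0, 0.0, 0.0]
```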
Hello, I'm a beginner interested in Tabular Learning. Your superb paper, SAINT, impresses me a lot. But I've had some problems learning your code.
For Line 233 in e288e84 and Line 89 in e288e84, I think the correct expression should be train[col].fillna(train.loc[:, col].mean(), inplace=True).
I'm not sure whether I am correct. I would appreciate it if you can reply. Thank you very much!
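As a side note on the expression above: calling `fillna(..., inplace=True)` on a column selected with `train[col]` may act on a temporary copy in recent pandas versions, so assigning the result back is usually the safer pattern. A minimal sketch with made-up data (the frame and its column names are illustrative only, not the repository's):

```python
import numpy as np
import pandas as pd

train = pd.DataFrame({"age": [20.0, np.nan, 40.0], "income": [1.0, 2.0, np.nan]})

for col in ["age", "income"]:
    # Fill missing entries with the column mean of the training split and
    # assign the result back instead of relying on inplace=True.
    train[col] = train[col].fillna(train[col].mean())

print(train)
```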
Thank you so much for sharing this impressive work.
I failed to create the environment. Is there anything I can do to fix this error? The error details are listed below:
K:\library\saint>conda env create -f saint_environment.yml
Collecting package metadata (repodata.json): done
Solving environment: failed
ResolvePackageNotFound:
- gmp==6.2.1=h58526e2_0
- certifi==2021.5.30=py38h578d9bd_0
- lame==3.100=h7f98852_1001
- promise==2.3=py38h578d9bd_3
- jupyter_core==4.7.1=py38h578d9bd_0
- libglib==2.68.3=h3e27bee_0
- setuptools==49.6.0=py38h578d9bd_3
- ffmpeg==4.3=hf484d3e_0
- markupsafe==2.0.1=py38h497a2fe_0
- libprotobuf==3.17.2=h780b84a_0
- libgomp==9.3.0=h2828fa1_19
- protobuf==3.17.2=py38h709712a_0
- yaml==0.2.5=h516909a_0
- gst-plugins-base==1.14.0=hbbd80ab_1
- freetype==2.10.4=h0708190_1
- pcre==8.45=h9c3ff4c_0
- tornado==6.1=py38h497a2fe_1
- _openmp_mutex==4.5=1_gnu
- debugpy==1.3.0=py38h709712a_0
- xgboost==1.4.0=py38h578d9bd_0
- expat==2.4.1=h9c3ff4c_0
- kiwisolver==1.3.1=py38h1fd1430_1
- pyzmq==22.1.0=py38h2035c66_0
- glib==2.68.3=h9c3ff4c_0
- tk==8.6.10=h21135ba_1
- pysocks==1.7.1=py38h578d9bd_3
- websocket-client==0.57.0=py38h578d9bd_4
- ipython==7.25.0=py38hd0cf306_1
- numpy-base==1.20.2=py38hfae3a4d_0
- libffi==3.3=h58526e2_2
- nbconvert==6.1.0=py38h578d9bd_0
- libuuid==2.32.1=h7f98852_1000
- numpy==1.20.2=py38h2d18471_0
- mkl_random==1.2.2=py38h1abd341_0
- pthread-stubs==0.4=h36c2ea0_1001
- libpng==1.6.37=h21135ba_2
- mkl_fft==1.3.0=py38h42c9631_2
- chardet==4.0.0=py38h578d9bd_1
- readline==8.1=h46c0cb4_0
- psutil==5.8.0=py38h497a2fe_1
- shortuuid==1.0.1=py38h578d9bd_4
- gstreamer==1.14.0=h28cd5cc_2
- ld_impl_linux-64==2.35.1=hea4e1c9_2
- libgcc-ng==9.3.0=h2828fa1_19
- xorg-libxau==1.0.9=h7f98852_0
- mistune==0.8.4=py38h497a2fe_1004
- pytorch==1.8.1=py3.8_cuda11.1_cudnn8.0.5_0
- libunistring==0.9.10=h14c3975_0
- fontconfig==2.13.1=hba837de_1005
- importlib-metadata==4.6.1=py38h578d9bd_0
- glib-tools==2.68.3=h9c3ff4c_0
- libuv==1.41.0=h7f98852_0
- click==8.0.1=py38h578d9bd_0
- xorg-libxdmcp==1.1.3=h7f98852_0
- mkl-service==2.4.0=py38h497a2fe_0
- watchdog==0.10.4=py38h578d9bd_0
- pillow==8.2.0=py38he98fc37_0
- py-xgboost==1.4.0=py38h578d9bd_0
- qt==5.9.7=h5867ecd_1
- libidn2==2.3.1=h7f98852_0
- brotlipy==0.7.0=py38h497a2fe_1001
- libwebp-base==1.2.0=h7f98852_2
- cryptography==3.4.7=py38ha5dfef3_0
- gettext==0.19.8.1=h0b5b191_1005
- scikit-learn==0.23.2=py38h0573a6f_0
- libxcb==1.13=h7f98852_1003
- argon2-cffi==20.1.0=py38h497a2fe_2
- sqlite==3.35.5=h74cdb3f_0
- nettle==3.6=he412f7d_0
- openssl==1.1.1k=h7f98852_0
- matplotlib==3.4.2=py38h578d9bd_0
- anyio==3.2.1=py38h578d9bd_0
- jedi==0.18.0=py38h578d9bd_2
- libxml2==2.9.12=h03d6c58_0
- sniffio==1.2.0=py38h578d9bd_1
- xz==5.2.5=h516909a_1
- wget==1.20.1=h22169c7_0
- mkl==2021.2.0=h06a4308_296
- libiconv==1.16=h516909a_0
- jpeg==9b=h024ee3a_2
- ca-certificates==2021.5.30=ha878542_0
- gnutls==3.6.13=h85f3911_1
- matplotlib-base==3.4.2=py38hcc49a3a_0
- libgfortran-ng==7.3.0=hdf63c60_0
- lcms2==2.12=h3be6417_0
- icu==58.2=hf484d3e_1000
- libxgboost==1.4.0=h9c3ff4c_0
- pandoc==2.14.0.3=h7f98852_0
- libsodium==1.0.18=h36c2ea0_1
- dbus==1.13.18=hb2f20db_0
- pandas==1.2.4=py38h1abd341_0
- pyyaml==5.4.1=py38h497a2fe_0
- zstd==1.4.9=ha95c52a_0
- cudatoolkit==11.1.1=h6406543_8
- python==3.8.10=h49503c6_1_cpython
- _libgcc_mutex==0.1=conda_forge
- zeromq==4.3.4=h9c3ff4c_0
- pyrsistent==0.17.3=py38h497a2fe_2
- cffi==1.14.5=py38ha65f79e_0
- openh264==2.1.1=h780b84a_0
- libtiff==4.2.0=h85742a9_0
- lz4-c==1.9.3=h9c3ff4c_0
- scipy==1.6.2=py38had2a1c9_1
- ipykernel==6.0.2=py38hd0cf306_0
- ninja==1.10.2=h4bd325d_0
- pyqt==5.9.2=py38h05f1152_4
- intel-openmp==2021.2.0=h06a4308_610
- sip==4.19.13=py38he6710b0_0
- zlib==1.2.11=h516909a_1010
- bzip2==1.0.8=h7f98852_4
- ncurses==6.2=h58526e2_4
- libstdcxx-ng==9.3.0=h6de172a_19
- terminado==0.10.1=py38h578d9bd_0
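A possible reading of this log: every entry is pinned to an exact build string (the third `=`-separated field), and build strings are platform-specific, while the prompt shows the command being run on Windows. One workaround that is sometimes tried is to drop the build strings and let conda pick compatible builds. The helper below is a hypothetical sketch of that text transformation, not a guaranteed fix; Linux-only packages such as `ld_impl_linux-64` would likely still need to be removed by hand.

```python
def strip_build(spec: str) -> str:
    """Turn 'name==version=build' or 'name=version=build' into 'name=version'."""
    spec = spec.strip().lstrip("- ").strip()
    name, _, rest = spec.partition("=")
    rest = rest.lstrip("=")        # tolerate the '==' form printed in the error
    version = rest.split("=")[0]   # drop the trailing build string, if any
    return f"{name}={version}" if version else name

print(strip_build("- gmp==6.2.1=h58526e2_0"))        # gmp=6.2.1
print(strip_build("numpy=1.20.2=py38h2d18471_0"))    # numpy=1.20.2
```

On a machine where the environment already exists, `conda env export --no-builds` produces a file without build strings directly.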
Hi,
I was trying to reproduce the benchmark (XGBoost and LightGBM) results, but I can't get the same numbers shown in your paper.
I used this to split the dataset into train, validation, and test sets:
Line 190 in e0ee763
I used early stopping on the validation set, collected the test performance as the final result, and reran the experiment with 5 different seeds (0, 1, ..., 5), as you do for the SAINT model.
I used standard parameters for XGBoost and LightGBM with some regularization.
I used the dataset you provide at the following link:
https://drive.google.com/file/d/1mJtWP9mRP0a10d1rT6b3ksYkp4XOpM0r/view?usp=sharing
The results I get are:
Model\Dataset | Bank | Blastchar | Arrhythmia | Arcene | Forest | Shoppers | Income | Volkert |
---|---|---|---|---|---|---|---|---|
lightgbm | 93.46 | 83.71 | 93.18 | 85.25 | 99.79 | 93.23 | 92.03 | 71.46 |
xgboost | 93.41 | 83.67 | 93.13 | 87.66 | 99.71 | 92.62 | 92.36 | 70.32 |
My experiments show a clear improvement over the reported benchmark results, as shown below:
Model\Dataset | Bank | Blastchar | Arrhythmia | Arcene | Forest | Shoppers | Income | Volkert |
---|---|---|---|---|---|---|---|---|
lightgbm | +0.069 | +0.54 | +4.45 | +4.2 | +6.5 | +0.03 | -0.54 | +3.55 |
xgboost | +0.45 | +1.89 | +11.14 | +6.25 | +4.38 | +0.11 | 0.05 | +1.37 |
Can you share the code used to calculate benchmark results?
I also used quite standard parameters to train XGBoost and LightGBM:
xgboost:
- max_depth=8,
- learning_rate=0.01,
- tree_method = 'hist',
- subsample=0.75,
- colsample_bytree=0.75,
- reg_alpha= 0.5,
- reg_lambda= 0.5,
lightgbm:
- learning_rate= 0.01,
- max_depth= -1,
- num_leaves= 2**8,
- lambda_l1= 0.5,
- lambda_l2= 0.5,
- feature_fraction= 0.75,
- bagging_fraction= 0.75,
- bagging_freq = 1,
I think I can improve these results further by tuning these parameters.
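For anyone who wants to rerun this comparison, the settings listed above can be written as plain keyword dictionaries. The values are copied from the post; the variable names `xgb_params` and `lgbm_params` are illustrative, and nothing here has been tuned or verified against the paper.

```python
# Keyword arguments exactly as listed in the post above.
xgb_params = {
    "max_depth": 8,
    "learning_rate": 0.01,
    "tree_method": "hist",
    "subsample": 0.75,
    "colsample_bytree": 0.75,
    "reg_alpha": 0.5,
    "reg_lambda": 0.5,
}

lgbm_params = {
    "learning_rate": 0.01,
    "max_depth": -1,          # -1 means no depth limit in LightGBM
    "num_leaves": 2 ** 8,
    "lambda_l1": 0.5,
    "lambda_l2": 0.5,
    "feature_fraction": 0.75,
    "bagging_fraction": 0.75,
    "bagging_freq": 1,
}

print(len(xgb_params), len(lgbm_params))  # 7 8
```

These dictionaries should be usable as `xgboost.XGBClassifier(**xgb_params)` and `lightgbm.LGBMClassifier(**lgbm_params)`, assuming both libraries are installed.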
Hi
For inference, the CLS token (L157 and L160 in train.py) is still based on the ground-truth label. Should it be a static CLS token instead?
I applied this great model to regression, but the values become NaN in the model.transformer part.
class RowColTransformer(nn.Module):
    ~~~~~~~~~
    def forward(self, x, x_cont=None, mask=None):
        if x_cont is not None:
            x = torch.cat((x, x_cont), dim=1)
        _, n, _ = x.shape
        print("TRANFOERMR")
        if self.style == 'colrow':
            for attn1, ff1, attn2, ff2 in self.layers:
                x = attn1(x)  # here x == nan
Did you encounter this during your implementation? If anyone has used the model on their own data, please let me know.
These are the hyperparameters:
model_saint = SAINT(
    categories = tuple(cat_dims.values()),  # len(cat_dims) == 2
    num_continuous = len(numerical_features) + 1,
    dim = 128,
    dim_out = 1,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    mlp_hidden_mults = (4, 2),
    continuous_mean_std = None,
    cont_embeddings = "MLP",
    attentiontype = 'col',
    final_mlp_style = 'sep',
    y_dim = 1
)
optim: AdamW(model_saint.parameters(), lr=1e-3, weight_decay=5e-5)
BATCH_size = 256
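One way to narrow down where the NaN first appears is to check the inputs before they reach the transformer. The sketch below assumes the inputs are ordinary PyTorch tensors; the helper name `report_nonfinite` and the toy tensors are hypothetical and not part of the SAINT code. Since `continuous_mean_std` is None in the configuration above, missing or unnormalized continuous values are one possible cause worth ruling out, though that is only a guess.

```python
import torch

def report_nonfinite(**tensors):
    """Return the names of tensors that contain NaN or Inf values."""
    return [name for name, t in tensors.items()
            if not torch.isfinite(t.float()).all()]

# Toy batch: one continuous feature value is missing.
x_cont = torch.tensor([[0.5, float("nan")], [1.0, 2.0]])
y = torch.tensor([[1.0], [2.0]])

print(report_nonfinite(x_cont=x_cont, y=y))  # ['x_cont']
```

If the inputs are clean, `torch.autograd.set_detect_anomaly(True)` can help locate the operation that first produces a NaN during the backward pass.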
Thank you for sharing your great work!
When I tried to evaluate your results on all the datasets listed in your paper (e.g., Bank, Blastchar, Arrhythmia, ...), I ran into a problem with your code in data_openml.py.
The dataset IDs listed in the file (1487, 44, ...) do not match the datasets listed in the paper (Bank, Blastchar, ...).
For example, with ID 1487, calling openml.datasets.get_dataset(1487) returns the ozone-level-8hr dataset.
Could you give me some suggestions on how to evaluate your results on the datasets listed in the paper?
Many thanks!
Hello,
I am currently working on a regression task with multiple label dimensions.
I saw that your code has an implementation for regression, but only for y_dim=1 and not for multiple dimensions (y_dim=n where n>1).
I tried running my regression task by changing the y_dim variable, but it does not work (apparently there is a mix-up with the dimensions, and I am not sure where the right place to change the code is).
Is there a simple way to run the model for a regression task with several output dimensions?
Thank you
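In generic PyTorch terms, multi-output regression usually comes down to a final layer with `y_dim` output units and a float target of shape `(batch, y_dim)`. The sketch below shows only that shape contract with made-up sizes; `features` is a stand-in for whatever representation the model produces, and none of this is the repository's own code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, dim, y_dim = 4, 32, 3            # hypothetical sizes

head = nn.Linear(dim, y_dim)            # one output unit per target dimension
features = torch.randn(batch, dim)      # stand-in for the model's representation
targets = torch.randn(batch, y_dim)     # float targets, shape (batch, y_dim)

preds = head(features)
# preds and targets must have identical shapes; a (batch,) vs (batch, 1)
# mismatch would otherwise broadcast silently inside the loss.
loss = nn.MSELoss()(preds, targets)
loss.backward()
print(preds.shape)  # torch.Size([4, 3])
```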
for i, data in enumerate(trainloader, 0):  #181
The code hangs at this line and never progresses. I think it is stuck in an infinite loop. Can you help resolve it?
One more issue: your implementation of MLP in model.py is just a stack of linear layers with no non-linearities between them. This is mathematically equivalent to a single linear layer with in_dim = dims[0] and out_dim = dims[-1].
Why not use an activation between the layers?
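The equivalence claimed above can be checked numerically. For two stacked layers, W2(W1 x + b1) + b2 = (W2 W1) x + (W2 b1 + b2), so the pair collapses into one linear map. The layer sizes below are arbitrary and the code is a standalone illustration, not the MLP class from model.py.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin1 = nn.Linear(5, 7).double()
lin2 = nn.Linear(7, 3).double()
stacked = nn.Sequential(lin1, lin2)          # no activation in between

# Collapse the two layers into one: W = W2 W1, b = W2 b1 + b2.
merged = nn.Linear(5, 3).double()
with torch.no_grad():
    merged.weight.copy_(lin2.weight @ lin1.weight)
    merged.bias.copy_(lin2.weight @ lin1.bias + lin2.bias)

x = torch.randn(10, 5, dtype=torch.double)
print(torch.allclose(stacked(x), merged(x)))  # True

# Inserting a nonlinearity breaks this collapse in general.
with_activation = nn.Sequential(lin1, nn.ReLU(), lin2)
```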
Hello Gowthami,
Thank you for this project. It shows an uplift in performance over XGBoost for my use case. It would be a great help to get the attention plotting code (both self-attention and intersample attention) for the SAINT implementation, as shown in the paper.
Is it possible to use SAINT for tabular data that contains only continuous variables, with no categorical ones?
We need to pass two parameters to the SAINT model: x_categ and x_cont.
Do I need to pass an empty tensor (e.g. torch.empty) as x_categ?
What should I pass as the "categories" parameter to the SAINT model? An empty tuple?
Hi,
First of all - very impressive project and repository! Chapeau to you all.
I'm trying to reproduce some of the results, but I'm having issues with, e.g., HTRU2. I searched the OpenML datasets and found dataset ID 43377, but it doesn't load with your code (the y values are None).
Maybe I'm looking in the wrong place, but could you provide a list of OpenML dataset IDs to recreate the results in your paper?
Thank you for sharing your work, it has actually been helping me a lot.
I have a question about the Attention module of the Transformer in your code. Am I wrong in thinking that the Attention module should have a dropout layer after the softmax function (link)? For example, the implementations at link and link use a dropout layer in the Attention module.
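For reference, the pattern the question describes, dropout applied to the attention weights right after the softmax, looks roughly like this in a single-head module. `TinyAttention` is a hypothetical, simplified class written for illustration and is not the Attention module from this repository.

```python
import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    """Single-head scaled dot-product attention with dropout on the weights."""

    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.scale = dim ** -0.5

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)       # dropout applied after the softmax
        return attn @ v

x = torch.randn(2, 5, 16)               # (batch, tokens, dim)
out = TinyAttention(16).eval()(x)       # eval() disables dropout
print(out.shape)                        # torch.Size([2, 5, 16])
```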
How can I apply the code if I want to use CSV data rather than OpenML data?
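A rough starting point, assuming the data loader mainly needs to know which columns are categorical, which are continuous, and how many distinct values each categorical column has. The names `cat_idxs`, `con_idxs`, and `cat_dims` mirror names that appear elsewhere in these threads, but how the repository consumes them is an assumption; the inline CSV is made up.

```python
import io
import pandas as pd

csv_text = """age,job,balance,y
30,admin,100.5,0
45,tech,20.0,1
28,admin,,0
"""
df = pd.read_csv(io.StringIO(csv_text))   # in practice: pd.read_csv("your_file.csv")

target = "y"
features = df.drop(columns=[target])

# Treat non-numeric columns as categorical and the rest as continuous.
cat_cols = [c for c in features.columns
            if not pd.api.types.is_numeric_dtype(features[c])]
con_cols = [c for c in features.columns if c not in cat_cols]

cat_idxs = [features.columns.get_loc(c) for c in cat_cols]
con_idxs = [features.columns.get_loc(c) for c in con_cols]
cat_dims = [features[c].nunique() for c in cat_cols]   # cardinality per column

print(cat_idxs, con_idxs, cat_dims)
```

Integer-coded categorical columns would be classified as continuous by this heuristic and would need to be listed by hand.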
Hello,
Thanks for this repo! Do you plan to release the code you used to plot intersample attention and self-attention as in the paper (section 5.2)? I would like to reproduce Figure 3.
Hi,
I am trying to reproduce your experiments as a baseline for my paper, but many of the dependencies in the yml file are deprecated. Would it be possible for you to advise on this, or to provide an updated yml file?
Thank you!
Thanks for the awesome work.
I'm using a tabular dataset for a regression task. I would like to predict the last column (float values) in the picture below.
I'm not sure how I should set up the network, especially these two parameters:
categories = tuple(cat_dims),
num_continuous = len(con_idxs)
For now I'm using
con_idxs = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]
If I change the last column's values to int using train[target] = train[target].astype(int) and use the following as cat_dims, it starts training, but I want to predict floating-point values:
cat_dims = np.append(np.array(cat_dims),np.array([50])).astype(int)
If I don't convert the target to int, it throws the following error:
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.DoubleTensor instead (while checking arguments for embedding)
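The error arises because `nn.Embedding` only accepts integer indices, so a float column cannot be routed through the categorical path. The sketch below shows that distinction with made-up tensors; how the target is wired into SAINT itself is not shown and would need to be checked against the repository.

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=50, embedding_dim=8)

cat_feature = torch.tensor([3, 7, 1])             # integer codes: valid embedding input
float_target = torch.tensor([0.25, 1.5, 3.75])    # regression target: stays float

print(emb(cat_feature).shape)                     # torch.Size([3, 8])

try:
    emb(float_target)                             # float indices are rejected
except Exception as err:
    print("embedding rejected float input:", type(err).__name__)

# The float target is used only in the loss, never as an embedding index.
pred = torch.zeros(3, requires_grad=True)
loss = nn.functional.mse_loss(pred, float_target)
loss.backward()
```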
Hello!
There is an import on line 4 of pretraining.py:
from baselines.data_openml import data_prep_openml,task_dset_ids,DataSetCatCon
However, there is no baselines folder in the repo, so an error occurs when I try to apply pretraining from the train file.
Thanks!
Hi,
Your paper states, “Results are averaged over 5 trials and 14 binary classification datasets.” However, the code lists only 10 IDs under 'binary': [1487,44,1590,42178,1111,31,42733,1494,1017,4134]. Could you provide the other datasets?
Hello, thanks for the awesome work and the code.
I've applied your code to some datasets and have some questions.
During pretraining, line 323 in data.py concatenates the categorical features and the target.
The concatenated categorical data passes through the embedding layer, and the result is used as input to the transformer.
I couldn't find the code that separates the target data before the concatenated data is passed into the transformer.
Is it okay to include the target data while pretraining the model?
Thank you.
Hello, thanks for the code and the awesome work.
I read your paper with great interest and have a question about the code.
In the paper (p. 14, Data preprocessing), it is written:
"Each feature (or columns) has a different missing value token to account for missing data.".
However, I found that the code just fills missing values with the mean for continuous features.
I wonder whether the missing-value token embedding works only for categorical data.
It was very exciting to read the paper and I hope to apply the algorithm to my dataset soon!
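To make the quoted idea concrete for categorical columns, here is one possible encoding that reserves a dedicated code for missing values in each column. It only illustrates the concept; the function name `encode_with_missing_token` is hypothetical and this is not necessarily how the repository or the paper implements it.

```python
def encode_with_missing_token(column):
    """Map category values to integer codes, reserving code 0 for missing values."""
    missing_code = 0
    vocab = {}
    codes = []
    for value in column:
        if value is None:
            codes.append(missing_code)      # this column's own missing-value token
        else:
            codes.append(vocab.setdefault(value, len(vocab) + 1))
    return codes, len(vocab) + 1            # cardinality includes the missing token

codes, n_tokens = encode_with_missing_token(["a", None, "b", "a"])
print(codes, n_tokens)  # [1, 0, 2, 1] 3
```

Because each column is encoded separately and would get its own embedding table of size `n_tokens`, the missing-value token ends up being different for each feature.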