
12wang3 / rrl

Stars: 90 · Watchers: 4 · Forks: 23 · Size: 574 KB

The code for the NeurIPS 2021 paper "Scalable Rule-Based Representation Learning for Interpretable Classification" and the TPAMI paper "Learning Interpretable Rules for Scalable Data Representation and Classification".

License: MIT License

Python 100.00%
interpretable-classification rule-based-model representation-learning scalability interpretable-ai interpretable-ml explainable-ai explainable-ml neurips nips

rrl's People

Contributors: 12wang3


rrl's Issues

A basic question

Could you explain what each field in the rrl.txt results means? I understand that RID is the rule ID, but what do its values mean? For the class columns | class_negative(b=-2.1733) | class_positive(b=1.9689), does the b in parentheses represent a mean? Is the value under each class label the probability of that outcome? And what does Support mean?

A question about runtime efficiency

I can run the code with your test data, but the execution seems quite slow. Do you know what the problem might be and how to troubleshoot it?

  • On the tic-tac-toe dataset with 998 rows, one epoch appears to take about 5 minutes.

[screenshot]

  • The GPU does not appear to be used at all: utilization stays at 0%, although GPU memory is allocated.

[screenshot]

  • The loss is decreasing normally, but training is just too slow.
    [screenshot]

Is this level of performance normal, or could the slowdown be caused by an environment problem or a misconfiguration?
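Not an answer from the authors, just a minimal diagnostic sketch for ruling out the most common cause (model or data staying on the CPU); model and data_loader are placeholders for whatever objects experiment.py actually builds:

import torch

def check_device(model, data_loader):
    """Print where the model parameters and one batch actually live."""
    print("CUDA available:", torch.cuda.is_available())
    print("Model device:", next(model.parameters()).device)
    x, y = next(iter(data_loader))
    print("Batch device before the training loop:", x.device)
    # If either prints "cpu", the heavy work runs on the CPU and nvidia-smi
    # will show 0% utilization even though some GPU memory is allocated.
    # A very small batch size (e.g. -bs 32) on a tiny dataset can also leave
    # the GPU mostly idle, since each kernel launch does very little work.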

A fairly basic question about the experimental results

I'd like to ask about the results in rrl.txt. My understanding is that higher support means a learned rule is more trustworthy; how can I tell from the weight and bias whether a rule is more trustworthy or more important? Also, what does it mean for a rule to be "activated"? These questions are rather naive, so please bear with me!

A few small questions I hope can be answered

I have recently been studying this paper and its source code, and I have a few small questions I hope you can answer.

(1) I don't quite understand what the left and right parameters of the RRL class constructor do.
[screenshot]

(2) Also in the RRL constructor, the use_not parameter lets the extracted rule set contain some "~" (negated) rules, but this detail does not seem to be mentioned in the original paper.

(3) From the fourth layer on, the input of each layer includes the outputs of the previous two layers. This detail is shown in a figure in the paper, but the process does not seem to be described in the text; I'd like to know the reason for combining the inputs this way.

(4) The estimated_grad parameter selects the activation used for the outputs of conjunction_layer and disjunction_layer. I noticed that EstimatedProduct and Product differ only in their backward() function. Is this difference exactly what the paper describes in the sentence about wrapping the derivative in an additional custom activation during EstimatedProduct's backward pass?
[screenshot]

(5) mllp is the version with continuous weights, and rrl is the discrete version obtained by discretizing the mllp weights. During backpropagation, the first step is to compute the derivative of the rrl loss with respect to the rrl y_pred, and then the parameters are updated through the usual derivatives of the mllp y_pred, which implements gradient grafting (see the sketch after this issue):
[screenshot]
In the code implementation, is the argument passed to backward() exactly the derivative of the rrl loss with respect to the rrl y_pred? And how is that derivative obtained?

(6) mllp is the continuous-weight version of the model (all weights lie in [0, 1], used for training), and rrl is the binarized version obtained by discretizing the mllp weights (weights in {0, 1}, with 0.5 as the threshold, used for training, testing, and extracting interpretable rules). The discrete version only plays a small part in backpropagation; most of the parameter updates are driven by the continuous mllp. However, in practice I find that the mllp loss is hard to converge and is much higher than the rrl loss. Intuitively, the continuous mllp should perform better than the discretized rrl, so I don't understand this.

I'm not sure whether my understanding of the paper is correct, and my descriptions may not be very clear. I would really appreciate it if you could find time to answer these questions. Thanks a lot!
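On question (5): a minimal sketch of how gradient grafting can be expressed in PyTorch, assuming mllp_out and rrl_out are the continuous and discretized outputs for the same batch; the tensor names and the cross-entropy loss are illustrative and not taken from the repo's actual code.

import torch
import torch.nn.functional as F

def grafted_step(mllp_out, rrl_out, target, optimizer):
    """One gradient-grafting update: the loss gradient is evaluated at the
    discrete (rrl) output but propagated back through the continuous (mllp) graph."""
    # Derivative of the loss w.r.t. the *discrete* prediction.
    rrl_out = rrl_out.detach().requires_grad_(True)
    loss = F.cross_entropy(rrl_out, target)
    grad_at_discrete = torch.autograd.grad(loss, rrl_out)[0]

    # Graft: pass that gradient into backward() of the *continuous* output,
    # so the continuous weights receive the discrete model's loss signal.
    optimizer.zero_grad()
    mllp_out.backward(grad_at_discrete)
    optimizer.step()
    return loss.item()

Under this reading, the argument passed to backward() is indeed d(loss_rrl)/d(y_pred_rrl), obtained by ordinary autograd on the discrete prediction.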

How to calculate the weights of the rule and class?

I'm a little confused by the output. rrl.txt is as follows:

RID final_status_0(b=0.3015) final_status_1(b=-0.0915) Support Rule
(-1, 0) 1.6603 -1.6603 0.6216 a_1 & c > 0.002
(-1, 3) -1.4654 1.4654 0.0541 a_3 & b <= 2.324 & d <= 0.013
(-1, 2) -0.6511 0.6511 0.3357 b > 2.324 & d <= 0.013
(-1, 4) 0.6015 -0.6015 0.0356 a_3 & b > 2.324
(-1, 1) 0.5441 -0.5441 0.9325 c > 0.002
(-1, 5) -0.2566 0.2566 0.7992 d <= 0.013
############################################################

Why are the absolute values of the final_status_* columns the same? How should I calculate the weight of each rule under different classes?
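A minimal sketch of one way to read such a table, assuming (as the paper describes) that the final layer is linear over rule activations, so each class score is its bias b plus the weights of the rules a sample activates. The numbers are copied from the excerpt above; the function itself is purely illustrative, not the repo's code.

# rule_id -> (weight for final_status_0, weight for final_status_1)
rule_weights = {
    (-1, 0): ( 1.6603, -1.6603),
    (-1, 3): (-1.4654,  1.4654),
    (-1, 2): (-0.6511,  0.6511),
    (-1, 4): ( 0.6015, -0.6015),
    (-1, 1): ( 0.5441, -0.5441),
    (-1, 5): (-0.2566,  0.2566),
}
bias = (0.3015, -0.0915)  # the b values in the header row

def class_scores(activated_rules):
    """Score each class as bias + sum of the weights of the rules a sample fires."""
    s0, s1 = bias
    for rid in activated_rules:
        w0, w1 = rule_weights[rid]
        s0 += w0
        s1 += w1
    return s0, s1

# Example: a sample that satisfies rules (-1, 1) and (-1, 5)
print(class_scores([(-1, 1), (-1, 5)]))  # the larger score is the predicted class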

Cannot run it on windows

Hi,

I was trying to give this implementation a try after reading the paper. I installed all the dependencies in a Conda env on a Windows PC. However, I get the following error when I run the experiment:

$ python experiment.py -d tic-tac-toe -bs 32 -s 1@16 -e401 -lrde 200 -lr 0.002 -ki 0 -wd 0.0001 --print_rule -i 0
C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\distributed_c10d.py:608: UserWarning: Attempted 
to get default timeout for nccl backend, but NCCL support is not compiled
  warnings.warn("Attempted to get default timeout for nccl backend, but NCCL support is not compiled")
[W socket.cpp:697] [c10d] The client socket has failed to connect to [A2207000547.china.huawei.com]:47339 (system error: 10049 - The requested address is not valid in its context.).
Traceback (most recent call last):
  File "C:\Users\m00827298\Codes\RRL\experiment.py", line 174, in <module>
    train_main(rrl_args)
  File "C:\Users\m00827298\Codes\RRL\experiment.py", line 167, in train_main
    mp.spawn(train_model, nprocs=args.gpus, args=(args,))
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\multiprocessing\spawn.py", line 241, in spawn       
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\multiprocessing\spawn.py", line 197, in start_processes
    while not context.join():
              ^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\multiprocessing\spawn.py", line 158, in join        
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\multiprocessing\spawn.py", line 68, in _wrap        
    fn(i, *args)
  File "C:\Users\m00827298\Codes\RRL\experiment.py", line 57, in train_model
    dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\c10d_logger.py", line 86, in wrapper    
    func_return = func(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\distributed_c10d.py", line 1177, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\rendezvous.py", line 246, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\rendezvous.py", line 174, in _create_c10d_store
    return TCPStore(
           ^^^^^^^^^
torch.distributed.DistNetworkError: Unknown error
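Not a fix confirmed by the authors, but the traceback shows experiment.py initializing the process group with the nccl backend, and PyTorch does not ship NCCL support on Windows. One workaround to try is switching the distributed backend to gloo (the backend PyTorch does support on Windows) at the line shown in the traceback:

import torch.distributed as dist

# experiment.py, train_model(): replace backend='nccl' with 'gloo',
# the distributed backend available in Windows builds of PyTorch.
dist.init_process_group(backend='gloo', init_method='env://',
                        world_size=args.world_size, rank=rank)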

Can RRL be used for regression?

Hi authors,

Thanks for this great work! I think it is very helpful for data analysis.
I wonder whether it can be used for regression and, if so, how?

Best,

Mx
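The repository itself targets classification, so the following is only a sketch of one possible direction rather than supported functionality: since RRL's final layer is linear over binary rule activations, one could treat the learned rules as binary features and fit a standard regressor on top of them. The rule_activations helper below is hypothetical; nothing in the repo provides it directly.

from sklearn.linear_model import Ridge

def fit_regressor_on_rules(rule_activations, X_train, y_train):
    """Fit a linear regressor on top of binary rule activations.

    rule_activations is a hypothetical callable that maps raw samples to a
    {0, 1} matrix whose columns indicate which learned rules each sample fires.
    """
    Z = rule_activations(X_train)            # (n_samples, n_rules) binary matrix
    reg = Ridge(alpha=1.0).fit(Z, y_train)   # continuous target instead of class labels
    return reg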

Parameter question

Hello, I'm a senior computer science student reproducing these experiments. I have not been able to get the expected results on the other datasets. Could you please provide the parameter configurations for the other datasets? Thank you very much!

Hello, a question about reproducing the experiments

Hello, I'm very interested in this work and have recently been reproducing the experimental section. I have a few questions:

1) Appendix C (Parameter Settings) of your paper says:

The number of nodes in logical layers ranges from 16 to 4096 depending on the number of binary features of the data set and the model complexity we need.
How do you decide the number of nodes in the logical layers from the number of binary features? In other words, I struggle when tuning this parameter because the range feels quite large; could you share some experience?

2) Section 4.3 of your paper shows the relationship between model complexity and performance, where model complexity is measured by log(#edges). I can't seem to find the implementation that counts the number of edges in the code; maybe I missed it, so I'd appreciate it if you could point it out, thanks! (A rough sketch of one way such a count might be done follows this issue.)

3) Also, could I ask for your parameter settings on the chess and bank-marketing datasets? When training on the chess dataset I tried many parameter combinations but still could not get the model to converge. On bank-marketing, the results are close to those reported in your paper, but the learned rules are very different from those presented in Figure 4(b) (the rules I learn do not contain the balance feature at all).

Thank you for your time :)
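Regarding question 2), a minimal sketch of how the number of edges of the discretized model might be counted, assuming an edge exists wherever a binarized weight crosses the 0.5 threshold mentioned in the issues above; the function is illustrative and not the repo's actual implementation.

import math
import torch

def count_edges(model, threshold=0.5):
    """Count connections whose binarized weight is 1 (i.e. weight > threshold)."""
    edges = 0
    for name, param in model.named_parameters():
        if 'weight' in name:                          # connection weights only, not biases
            edges += int((param.data > threshold).sum().item())
    return edges

# Model complexity as plotted in Section 4.3: log(#edges)
# complexity = math.log(count_edges(model))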
