Hi @chenfu12138, there are two steps to get the attentional weights of feature interactions in AFM.
First, make sure you have installed the latest release of deepctr (currently v0.1.5) from pip. You can install that version with `pip install deepctr==0.1.5`.
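Because Step 1 below edits specific line numbers in `layers.py`, it can help to confirm which deepctr version is installed and where its `layers.py` lives before editing. The following is a minimal sketch, not part of deepctr; the helper names `installed_deepctr_version` and `deepctr_layers_path` are hypothetical. It assumes Python 3.8+ (for `importlib.metadata`) and the single-file `deepctr/layers.py` layout of the 0.1.x releases.

```python
import importlib.util
import os
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

EXPECTED_VERSION = "0.1.5"  # the release whose layers.py lines 134-135 Step 1 refers to


def installed_deepctr_version():
    """Return the installed deepctr version string, or None if it is not installed."""
    try:
        return version("deepctr")
    except PackageNotFoundError:
        return None


def deepctr_layers_path():
    """Return the expected path of deepctr/layers.py without importing deepctr, or None."""
    spec = importlib.util.find_spec("deepctr")
    if spec is None or spec.origin is None:
        return None
    # spec.origin points at deepctr/__init__.py; layers.py sits next to it (0.1.x layout).
    return os.path.join(os.path.dirname(spec.origin), "layers.py")


if __name__ == "__main__":
    found = installed_deepctr_version()
    if found is None:
        print("deepctr not found; install it with: pip install deepctr==0.1.5")
    else:
        if found != EXPECTED_VERSION:
            print(f"deepctr {found} found; lines 134-135 may not match the snippet")
        print("file to edit:", deepctr_layers_path())
```

Running it as a script prints the path to open, which replaces guessing a location such as `xxx\Anaconda3\Lib\site-packages\deepctr\layers.py`.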
Step 1: Modify two lines of the source code
Please modify `layers.py` on your local machine; the path may be something like
xxx\Anaconda3\Lib\site-packages\deepctr\layers.py
In lines 134 and 135, change the following two lines:

```python
attention_weight = tf.nn.softmax(tf.tensordot(attention_temp, self.projection_h, axes=(-1, 0)), dim=1)
attention_output = tf.reduce_sum(attention_weight * bi_interaction, axis=1)
```

to:

```python
# Store the normalized attention scores on the layer so they can be read after training.
self.normalized_att_score = tf.nn.softmax(tf.tensordot(attention_temp, self.projection_h, axes=(-1, 0)), dim=1)
attention_output = tf.reduce_sum(self.normalized_att_score * bi_interaction, axis=1)
```
I will include this change in the next release.
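To see what the two edited lines compute, here is a simplified single-sample sketch in plain Python. It is an illustration, not deepctr's code: it assumes the attention logits (the result of projecting `attention_temp` with `projection_h`) are already given as one float per feature pair, and the names `softmax` and `afm_attention_pool` are hypothetical. In the batched TensorFlow tensor the pair axis is `dim=1`, so the softmax normalizes across feature pairs, and `normalized_att_score` holds one weight per pair that sums to 1 for each sample.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def afm_attention_pool(bi_interaction, attention_logits):
    """Attention-pool one sample's pairwise interaction vectors, AFM style.

    bi_interaction:   num_pairs vectors, each of length embedding_size
    attention_logits: num_pairs floats, one score per feature pair
    Returns (weights, output): weights plays the role of normalized_att_score,
    output plays the role of attention_output for this one sample.
    """
    weights = softmax(attention_logits)  # normalized over the pair axis
    dim = len(bi_interaction[0])
    # Weighted sum of the pair vectors, like tf.reduce_sum(w * bi_interaction, axis=1).
    output = [sum(w * vec[k] for w, vec in zip(weights, bi_interaction)) for k in range(dim)]
    return weights, output


weights, output = afm_attention_pool([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
print(weights, output)  # equal logits give equal weights, so output is the mean vector
```

Saving the softmax result as `self.normalized_att_score` instead of a local variable is what makes these per-pair weights reachable from outside the layer in Step 2.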
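Step 2: Get the attentional weights!
After you have finished training the AFM model, run:

```python
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Lambda

# `model` is your trained AFM model and `model_input` is its input data.
afmlayer = model.layers[-3]  # the AFM layer in this version's model layout
# Sub-model that shares the trained graph and outputs the stored attention scores.
afm_weight_model = Model(model.input, outputs=Lambda(lambda x: afmlayer.normalized_att_score)(model.input))
attentional_weights = afm_weight_model.predict(model_input, batch_size=4096)
```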
Give it a try!
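A quick sanity check on the result can catch a wrong layer index or a stale install. This sketch assumes `attentional_weights` has shape `(num_samples, num_pairs, 1)`, which follows from the `tensordot` with `projection_h` and the softmax over `dim=1`, and that it is passed in as nested lists (for example `attentional_weights.tolist()`). The helper names are hypothetical, and `num_fields` is the number of feature fields that take part in the pairwise interactions, which you supply yourself.

```python
import math


def expected_num_pairs(num_fields):
    """Number of pairwise interactions among num_fields feature fields."""
    return num_fields * (num_fields - 1) // 2


def check_attention_weights(weights, num_fields, tol=1e-4):
    """Sanity-check weights shaped [num_samples][num_pairs][1].

    Returns True if every sample has one weight per feature pair and its
    weights sum to 1 (the softmax is taken over the pair axis). The tolerance
    allows for float32 rounding in the predicted array.
    """
    pairs = expected_num_pairs(num_fields)
    for sample in weights:
        if len(sample) != pairs:
            return False
        total = sum(pair[0] for pair in sample)
        if not math.isclose(total, 1.0, abs_tol=tol):
            return False
    return True
```

For example, `check_attention_weights(attentional_weights.tolist(), num_fields=...)` should return `True` if Step 1 and Step 2 were applied correctly.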
@shenweichen It worked, thanks! I obtained the attentional weights like this:

```
[[0.05291238]
 [0.05570798]
 [0.29247418]
 ...
 [0.2924742 ]
 [0.05001876]
 [0.04996691]]
```

But I don't know how to match each attentional weight in the array with its specific feature interaction. Also, the correspondence between them seems to change with every training run?
Hi @chenfu12138,
You can use the following code:

```python
import itertools

import deepctr
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Lambda

# sparse_feature_dict, dense_feature_list, model_input and target come from your own data preparation.
feature_dim_dict = {"sparse": sparse_feature_dict, "dense": dense_feature_list}
model = deepctr.models.AFM(feature_dim_dict)
model.compile("adam", "binary_crossentropy")  # the model must be compiled before fit
model.fit(model_input, target)

# Sub-model that outputs the AFM layer's normalized attention scores.
afmlayer = model.layers[-3]
afm_weight_model = Model(model.input, outputs=Lambda(lambda x: afmlayer.normalized_att_score)(model.input))
attentional_weights = afm_weight_model.predict(model_input, batch_size=4096)

# All feature pairs, in the same order as the interactions inside the AFM layer.
feature_interactions = list(itertools.combinations(list(feature_dim_dict['sparse'].keys()) + feature_dim_dict['dense'], 2))
```

`attentional_weights[:, i, 0]` is the attentional weight of `feature_interactions[i]` for all samples.
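Building on that mapping, a common next step is to average each pair's weight over all samples and rank the interactions. The sketch below is illustrative only: `rank_interactions` is a hypothetical helper, it takes the weights as nested lists (for example `attentional_weights.tolist()`), and it assumes `feature_names` is exactly the ordered list used above (`list(feature_dim_dict['sparse'].keys()) + feature_dim_dict['dense']`). Passing that list explicitly, rather than re-reading dict keys later, keeps the pair order reproducible; on older Python versions dict key order was not guaranteed, which is one possible (unconfirmed) reason the mapping could appear to change between runs.

```python
import itertools


def rank_interactions(feature_names, attentional_weights):
    """Average each feature pair's attention weight over samples, highest first.

    feature_names:       fields in the same order used to build the model input
    attentional_weights: nested list shaped [num_samples][num_pairs][1]
    Returns a list of ((field_a, field_b), mean_weight) tuples.
    """
    pairs = list(itertools.combinations(feature_names, 2))
    if len(pairs) != len(attentional_weights[0]):
        raise ValueError("number of feature pairs does not match the weight array")
    n = len(attentional_weights)
    means = []
    for i, pair in enumerate(pairs):
        # Column i holds feature_interactions[i]'s weight for every sample.
        mean = sum(sample[i][0] for sample in attentional_weights) / n
        means.append((pair, mean))
    return sorted(means, key=lambda item: item[1], reverse=True)


ranked = rank_interactions(
    ["user", "item", "cate"],
    [[[0.2], [0.5], [0.3]], [[0.4], [0.3], [0.3]]],
)
print(ranked[0])  # the ("user", "cate") pair has the highest mean weight here
```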
Try it and star it if it helps 😉
@shenweichen Perfect solution! That's so kind of you! Starred, of course!
Hi, the latest version, v0.2.0, has been released. Please upgrade with `pip install -U deepctr`.