
Comments (8)

FangShancheng commented on August 29, 2024

Your understanding is basically correct. The Position Attention used by the vision model in the paper is a reformulation of existing attention methods; in particular, it emphasizes that further abstraction of q, k, and v in the attention affects model performance.
In terms of implementation, the differences from Parallel Attention and other attention variants are:

  1. q uses a non-learnable position encoding as its parameters, projected through an fc layer
  2. k is further abstracted with a UNet, which is effective (see the sketch after this list).
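A minimal sketch of how these two points can slot into a standard scaled dot-product attention (an illustration only, not the repository code; the fc layer, the UNet placeholder, and the shapes are assumed):

import torch
import torch.nn as nn
import torch.nn.functional as F

T, L, E = 25, 256, 512              # assumed: max text length, feature positions, feature dim

fc = nn.Linear(E, E)                # projects the (non-learnable) position encoding into the query space
unet = nn.Identity()                # stand-in for the small UNet that further abstracts the key

pos_enc = torch.randn(T, E)         # stand-in for the fixed Transformer position encoding (point 1)
feat = torch.randn(1, L, E)         # backbone feature map flattened to (batch, L, E)

q = fc(pos_enc).unsqueeze(0)        # (1, T, E): query built from the position encoding via fc
k = unet(feat)                      # (1, L, E): key further abstracted (point 2)
v = feat                            # (1, L, E): value kept as the backbone features

att_weight = F.softmax(q @ k.transpose(1, 2) / E ** 0.5, dim=-1)  # (1, T, L)
out = att_weight @ v                                               # (1, T, E)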

GarrettLee commented on August 29, 2024

Your understanding is basically correct. The Position Attention used by the vision model in the paper is a reformulation of existing attention methods; in particular, it emphasizes that further abstraction of q, k, and v in the attention affects model performance.

In terms of implementation, the differences from Parallel Attention and other attention variants are:

  1. q uses a non-learnable position encoding as its parameters, projected through an fc layer

  2. k is further abstracted with a UNet, which is effective.

In the code I posted above, the q is basically the same in both methods, isn't it? Or is my understanding still off?

GarrettLee commented on August 29, 2024

The UNet indeed looks like it could be effective; we will also give it a try.

FangShancheng commented on August 29, 2024

Your understanding is basically correct. The Position Attention used by the vision model in the paper is a reformulation of existing attention methods; in particular, it emphasizes that further abstraction of q, k, and v in the attention affects model performance.
In terms of implementation, the differences from Parallel Attention and other attention variants are:

  1. q uses a non-learnable position encoding as its parameters, projected through an fc layer
  2. k is further abstracted with a UNet, which is effective.

In the code I posted above, the q is basically the same in both methods, isn't it? Or is my understanding still off?

For q, the logic is a bit different. To be precise, the implementation should be q = fc(position encoding), or even directly q = position encoding, where the position encoding is the one used in the Transformer, rather than a raw learnable weight w. The main purpose of this is to improve interpretability. As for its actual effect, and other alternatives to the UNet, we did not run many experiments on the vision model; our experiments mainly focused on the language model.
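For illustration, a rough sketch of the two q variants mentioned here, built on the standard sinusoidal Transformer position encoding (the helper name and shapes are assumptions, not the repository's code):

import math
import torch
import torch.nn as nn

def sinusoidal_position_encoding(max_len, d_model):
    # Standard non-learnable Transformer position encoding, shape (max_len, d_model)
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

max_seq_len, d_model = 25, 512
pe = sinusoidal_position_encoding(max_seq_len, d_model)  # fixed, unlike a learnable nn.Embedding

fc = nn.Linear(d_model, d_model)
q_projected = fc(pe)   # variant 1: q = fc(position encoding)
q_direct = pe          # variant 2: q = position encoding used directly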

GarrettLee commented on August 29, 2024

Got it.

simplify23 commented on August 29, 2024

Further abstraction of q, k, and v in the attention affects model performance.

I'd like to ask whether any experiments were done to verify this statement, and what type of method "further abstraction" mainly refers to here. In RobustScanner, adding a BiLSTM+CNN enhancement to the key also seems effective. Could you elaborate a bit more?

FangShancheng commented on August 29, 2024

Further abstraction of q, k, and v in the attention affects model performance.

I'd like to ask whether any experiments were done to verify this statement, and what type of method "further abstraction" mainly refers to here. In RobustScanner, adding a BiLSTM+CNN enhancement to the key also seems effective. Could you elaborate a bit more?

The vision-model ablation experiments in Table 1 of our paper touch on this point. "Further abstraction" means applying an abstraction function to each of q, k, and v. In ABINet, k is abstracted with a UNet. Beyond that, abstracting v amounts to strengthening the backbone, and our additional experiments did not find much gain from it; enhancing q also brings some benefit.
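As a rough illustration of abstracting k with a small UNet-style module over the feature map (channel counts, sizes, and the module name are assumed; this is not the exact ABINet implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyKeyUNet(nn.Module):
    # Downsample then upsample the backbone feature map to further abstract k.
    # Depth and channel counts here are illustrative only.
    def __init__(self, channels=512, mid=64):
        super().__init__()
        self.down1 = nn.Conv2d(channels, mid, 3, stride=2, padding=1)
        self.down2 = nn.Conv2d(mid, mid, 3, stride=2, padding=1)
        self.up1 = nn.ConvTranspose2d(mid, mid, 3, stride=2, padding=1, output_padding=1)
        self.up2 = nn.ConvTranspose2d(mid, channels, 3, stride=2, padding=1, output_padding=1)

    def forward(self, x):                      # x: (B, C, H, W) backbone features
        d1 = F.relu(self.down1(x))             # (B, mid, H/2, W/2)
        d2 = F.relu(self.down2(d1))            # (B, mid, H/4, W/4)
        u1 = F.relu(self.up1(d2)) + d1         # skip connection, (B, mid, H/2, W/2)
        return self.up2(u1)                    # (B, C, H, W): abstracted key features

feat = torch.randn(2, 512, 8, 32)              # assumed feature map size
k = TinyKeyUNet()(feat)                        # same shape as feat, used as the attention key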

YanShuang17 commented on August 29, 2024

@FangShancheng Hello, author!

I find it hard to understand the logic for computing att_weight in the Attention class in https://github.com/FangShancheng/ABINet/blob/main/modules/attention.py.

Here I compare it with the attention process in the visual part (PVAM) of SRN:
(1) The attention process in SRN-PVAM (pseudocode, assuming q, k, and v all have dimension d_model):

import torch
import torch.nn as nn
import torch.nn.functional as F

# e.g. d_model = 512, max_seq_len = seq_len_q = 25, vocab_size = 37
d_model, max_seq_len, vocab_size = 512, 25, 37
batch_size, seq_len_k = 2, 256  # seq_len_k = number of flattened visual feature positions

key2att = nn.Linear(d_model, d_model)
query2att = nn.Linear(d_model, d_model)
embedding = nn.Embedding(max_seq_len, d_model)
score = nn.Linear(d_model, 1)
classifier = nn.Linear(d_model, vocab_size)

# input is encoder_out
encoder_out = torch.randn(batch_size, seq_len_k, d_model)
reading_order = torch.arange(max_seq_len, dtype=torch.long)
Q = embedding(reading_order)  # (max_seq_len, d_model)
K = V = encoder_out  # (batch_size, seq_len_k, d_model)

# The att_weight computation here is easy to follow; it is the same as the attention
# process in classic attention models such as ASTER.
######
att_q = query2att(Q).unsqueeze(0).unsqueeze(2)  # (1, seq_len_q, 1, d_model)
att_k = key2att(K).unsqueeze(1)  # (batch_size, 1, seq_len_k, d_model)
att_weight = score(torch.tanh(att_q + att_k)).squeeze(3)  # (batch_size, seq_len_q, seq_len_k)
######

att_weight = F.softmax(att_weight, dim=-1)
decoder_out = torch.bmm(att_weight, V)  # (batch_size, seq_len_q, d_model)
logits = classifier(decoder_out)  # (batch_size, seq_len_q, vocab_size)

(2) The implementation in the Attention class in https://github.com/FangShancheng/ABINet/blob/main/modules/attention.py:
(I noticed that the attention in your group's VisionLAN is the same; see the PP_Layer class in https://github.com/wangyuxin87/VisionLAN/blob/main/modules/modules.py)

# e.g. d_model = 512, max_seq_len = seq_len_q = 25, vocab_size = 37 (same values as above)
embedding = nn.Embedding(max_seq_len, d_model)
w0 = nn.Linear(max_seq_len, seq_len_k)
wv = nn.Linear(d_model, d_model)
we = nn.Linear(d_model, max_seq_len)
classifier = nn.Linear(d_model, vocab_size)

# input is encoder_out
K = V = encoder_out  # (batch_size, seq_len_k, d_model)
reading_order = torch.arange(max_seq_len, dtype=torch.long)

# How should the att_weight computation below be understood?
######
reading_order = embedding(reading_order)  # (seq_len_q, d_model)
reading_order = reading_order.unsqueeze(0).expand(K.size(0), -1, -1)  # (batch_size, seq_len_q, d_model)
t = w0(reading_order.permute(0, 2, 1))  # (batch_size, d_model, seq_len_q) ==> (batch_size, d_model, seq_len_k)
t = torch.tanh(t.permute(0, 2, 1) + wv(K))  # (batch_size, seq_len_k, d_model)
att_weight = we(t)  # (batch_size, seq_len_k, d_model) ==> (batch_size, seq_len_k, seq_len_q)
att_weight = att_weight.permute(0, 2, 1)  # (batch_size, seq_len_q, seq_len_k), so softmax and bmm below line up
######

att_weight = F.softmax(att_weight, dim=-1)
decoder_out = torch.bmm(att_weight, V)  # (batch_size, seq_len_q, d_model)
logits = classifier(decoder_out)  # (batch_size, seq_len_q, vocab_size)

Please help clarify my confusion, thank you!
