
Comments (4)

HalimSD commented on June 26, 2024

I think you're feeding the linear layer a tensor whose dimension 1 is 512. That is the prefix you produced during preprocessing, when you parsed the data: you encoded the images with CLIP's encode_image function, which outputs a tensor with dim 1 = 512. But you then trained the model with a prefix size of 640, so the layer expects dim 1 = 640.
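A minimal PyTorch sketch of this mismatch, assuming 512-dim ViT-B/32 features. Random tensors replace real CLIP outputs, and a single nn.Linear (with an output width of 768, GPT-2's embedding size, chosen for illustration) stands in for the mapping network's input layer:

```python
import torch
from torch import nn

# Stand-in for CLIP image features: ViT-B/32's encode_image returns 512-dim
# vectors (RN50x4 returns 640-dim). Random values replace a real CLIP call.
clip_embeddings = torch.randn(4, 512)

# A single nn.Linear stands in for the mapper's input layer. Sized for
# prefix_size=640, it cannot consume the 512-dim features.
mismatched = nn.Linear(640, 768)
try:
    mismatched(clip_embeddings)
    error_message = None
except RuntimeError as err:
    error_message = str(err)  # a matmul shape-mismatch error
print(error_message)

# Taking prefix_size from the embeddings themselves keeps both sides in sync.
prefix_size = clip_embeddings.shape[-1]
matched = nn.Linear(prefix_size, 768)
projected = matched(clip_embeddings)
print(projected.shape)  # torch.Size([4, 768])
```

In other words, prefix_size has to follow whichever CLIP backbone was used at preprocessing time, not the other way around.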

from clip_prefix_caption.

ScottishFold007 commented on June 26, 2024

> prefix size

Did you reach this conclusion from the Colab notebook above? I have already changed the prefix size to 512 and I still get this error. Do you have a good solution?
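If the error persists after setting the prefix size to 512, one thing worth checking is the width of the embeddings that preprocessing actually stored, since they may come from a different CLIP backbone. The sketch below assumes the preprocessing step pickled a dict whose "clip_embedding" entry is an (N, D) tensor, which is how the repo's parse_coco.py appears to save its output; treat that layout, and the helper name stored_prefix_dim, as assumptions. The demo writes a fake 640-dim file instead of reading real data:

```python
import os
import pickle
import tempfile

import torch


def stored_prefix_dim(pkl_path: str) -> int:
    """Hypothetical helper: return the feature width D of the pickled embeddings.

    Assumes the file holds a dict with a "clip_embedding" tensor of shape (N, D);
    D is the value prefix_size must match.
    """
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)
    return data["clip_embedding"].shape[-1]


# Demo: fake preprocessing output with 640-dim embeddings (the width an RN50x4
# CLIP backbone produces), then read the width back.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "fake_clip_embeddings.pkl")
    with open(path, "wb") as f:
        pickle.dump({"clip_embedding": torch.randn(3, 640), "captions": []}, f)
    dim = stored_prefix_dim(path)

print(dim)  # 640
```

If this reports 640 for your real file, the data was encoded with a 640-dim backbone, and either the prefix size or the preprocessing has to change to match.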


HalimSD commented on June 26, 2024

I don't have access to your notebook.
I came to that conclusion because I'm facing the exact same error and came here to open a similar issue.

> Do you have any good solution?

No


MachineLearning11 commented on June 26, 2024

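I'm trying to rebuild your model on top of (Hugging Face) transformers, but I've run into a problem: there is always an error somewhere. I've tried many solutions, but I can't figure it out. [screenshot]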

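```python
from typing import Optional

import torch
from transformers import GPT2LMHeadModel, PreTrainedModel

# TransformerMapper is the mapping network from the CLIP_prefix_caption repo
# (train.py); it is assumed to be defined earlier in the notebook.


class ClipCaptionModel(PreTrainedModel):
    def __init__(self, config):
        # config must be a transformers.PretrainedConfig carrying the fields below.
        super().__init__(config)
        self.prefix_length = config.prefix_length
        self.clip_length = config.clip_length
        # Must equal the last dim of the CLIP embeddings saved during preprocessing
        # (512 for ViT-B/32, 640 for RN50x4), or the mapper's nn.Linear will fail.
        self.prefix_size = config.prefix_size
        self.num_layers = config.num_layers
        self.mapping_type = config.mapping_type
        self.gpt = GPT2LMHeadModel.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        self.clip_project = TransformerMapper(self.prefix_size, self.gpt_embedding_size,
                                              self.prefix_length, self.clip_length,
                                              self.num_layers)  # (512, 768, 10, 8)

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        # Label filler for the prefix positions; -100 is ignored by GPT-2's LM loss.
        return torch.full((batch_size, self.prefix_length), -100,
                          dtype=torch.int64, device=device)

    def forward(self,
                tokens: torch.Tensor,
                prefix: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        # Fail early with a readable message instead of a shape error inside the mapper.
        if prefix.shape[-1] != self.prefix_size:
            raise ValueError(f"prefix has last dim {prefix.shape[-1]}, "
                             f"but config.prefix_size is {self.prefix_size}")
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length,
                                                            self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            # Only the presence of labels matters: targets are rebuilt from tokens,
            # prepended with ignored positions so they align with embedding_cat.
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
            if mask is not None:
                # Also keep padding positions (mask == 0) out of the loss.
                labels = labels.masked_fill(mask == 0, -100)
        return self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)


class ClipCaptionPrefix(ClipCaptionModel):
    # Trains only the mapping network; GPT-2 is kept frozen in eval mode.

    def parameters(self, recurse: bool = True):
        # Optimizers built from named_parameters() (as transformers.Trainer does)
        # bypass this override; set requires_grad=False on self.gpt in that case.
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super().train(mode)
        self.gpt.eval()
        return self
```

Here is the Colab link: https://colab.research.google.com/drive/1sEg9HbDwRPs9_SNVjjsPE_sk449P9Svc#scrollTo=3pP_n5oQrXPg&uniqifier=1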

Have you solved it? My problem is the same as yours: it also fails at the linear layer.
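ClipCaptionModel above reads prefix_length, clip_length, prefix_size, num_layers and mapping_type from its config, and PreTrainedModel expects that config to be a PretrainedConfig instance. A minimal sketch of such a config class follows; the class name ClipCaptionConfig, the model_type string and the default values are hypothetical placeholders, with prefix_size=512 chosen to match ViT-B/32 embeddings:

```python
from transformers import PretrainedConfig


class ClipCaptionConfig(PretrainedConfig):
    """Hypothetical config carrying the fields ClipCaptionModel reads."""

    model_type = "clip_caption"  # hypothetical identifier, not registered anywhere

    def __init__(self, prefix_length: int = 10, clip_length: int = 10,
                 prefix_size: int = 512, num_layers: int = 8,
                 mapping_type: str = "transformer", **kwargs):
        # prefix_size must equal the width of the stored CLIP embeddings.
        self.prefix_length = prefix_length
        self.clip_length = clip_length
        self.prefix_size = prefix_size
        self.num_layers = num_layers
        self.mapping_type = mapping_type
        super().__init__(**kwargs)


config = ClipCaptionConfig(prefix_size=512)
print(config.prefix_size, config.prefix_length)  # 512 10
```

Setting config_class = ClipCaptionConfig as a class attribute on the model is what lets save_pretrained and from_pretrained rebuild this config when the model is reloaded.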
