
Comments (3)

Junpliu commented on August 24, 2024

utterances = ' '.join(["Jane: Hello",
"Vegano Resto: Hello, how may I help you today?",
"Jane: I would like to make a reservation.",
"Jane: For 6 people, tonight around 20:00",
"Vegano Resto: Let me just check.",
"Vegano Resto: Ah, I'm afraid that there is no room at 20:00.",
"Vegano Resto: However, I could offer you a table for six at 18:30 or at 21:00",
"Vegano Resto: Would either of those times suit you?",
"Jane: Oh dear.",
"Jane: Let me just ask my friends.",
"Vegano Resto: No problem.",
"Jane: 21:00 will be ok.",
"Vegano Resto: Perfect. So tonight at 21:00 for six people under your name.",
"Jane: great, thank you!"])

from bertviz.

Junpliu commented on August 24, 2024

I ran the code and the program crashed. However, the attention weights were displayed as expected.
(screenshot of the attention visualization attached)


jessevig commented on August 24, 2024

Hi @Junpliu, the visualization may fail for longer inputs like the one in your example; see https://github.com/jessevig/bertviz#%EF%B8%8F-limitations. In a future version I will add a warning message for these cases.
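Until bertviz warns about over-long inputs itself, a rough pre-check before rendering can save a crashed notebook cell. The helper below is hypothetical (not part of bertviz), and both the 1.3 words-to-subwords ratio and the 256-token threshold are assumptions for illustration:

```python
def estimate_tokens(text, ratio=1.3):
    # Rough estimate: BPE usually yields somewhat more tokens than words.
    # The 1.3 ratio is an assumed heuristic, not a measured constant.
    return int(len(text.split()) * ratio)

def safe_to_visualize(text, max_tokens=256):
    # max_tokens=256 is an arbitrary illustrative threshold.
    est = estimate_tokens(text)
    if est > max_tokens:
        print(f"Warning: ~{est} tokens; model_view may fail to render this input.")
        return False
    return True

print(safe_to_visualize("test"))         # True
print(safe_to_visualize("word " * 300))  # False, after printing a warning
```

For a precise count, tokenize the text first and check `len(input_ids)` instead of estimating.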

I was able to get this to work with a shorter input as a test; does it work for you?

from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view

utils.logging.set_verbosity_error()  # Remove line to see warnings

# Initialize tokenizer and model. Be sure to set output_attentions=True.
# Load BART fine-tuned for summarization on CNN/Daily Mail dataset
model_name = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

# Get encoder input ids for the source text
utterances = "test"
encoder_input_ids = tokenizer(utterances, return_tensors="pt", add_special_tokens=True).input_ids

# Get decoder input ids from the target summary
decoder_input_ids = tokenizer("Jane made a 9 PM reservation for 6 people tonight at Vegano Resto .", return_tensors="pt", add_special_tokens=True).input_ids

outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text
)
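One way to catch mismatches before calling model_view is to sanity-check the attention shapes the model returned. The helper below is a hypothetical sketch (not part of bertviz or transformers); it only assumes each layer's attention exposes a `.shape` of (batch, num_heads, query_len, key_len), which is what the tensors returned by a transformers model provide:

```python
def check_attention_shapes(attentions, query_len, key_len):
    # Each element of the model's attention tuple is one layer's tensor
    # with shape (batch, num_heads, query_len, key_len).
    for i, layer in enumerate(attentions):
        _batch, _heads, q, k = layer.shape
        if (q, k) != (query_len, key_len):
            raise ValueError(
                f"layer {i}: expected ({query_len}, {key_len}), got ({q}, {k})")
    return len(attentions)

# Dummy stand-ins for the real tensors, to keep the demo self-contained:
class _Dummy:
    def __init__(self, shape):
        self.shape = shape

cross = tuple(_Dummy((1, 16, 5, 7)) for _ in range(12))
print(check_attention_shapes(cross, 5, 7))  # 12: all layers pass
```

For the cross-attention above, `query_len` is the decoder token count and `key_len` is the encoder token count; the encoder and decoder self-attentions use their own length for both.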

