Monday 4 January 2021

What do you need for plotting the outcome of a question-answering model

I have been working on a question-answering model in which a fine-tuned BERT model returns answers to my questions. I would really like to plot something like this: [image not shown]

The problem is that I don't know how, and I am really stuck. I don't know how to represent a part of the context in a plot. I do have two variables, answer_start and answer_end, which indicate the token positions in the encoded input (question plus context) that the model took its answer from. Can someone please help me out and tell me which variables I need to pass to pyplot?

My code is below:

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
import numpy as np
import pandas as pd

max_seq_length = 512

tokenizer = AutoTokenizer.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")

questions = [
    "Welke soorten gladiatoren waren er?",
    "Wat is een provocator?"
]
for question in questions:  # for each question, iterate over the lines of the file
    print(f"Question: {question}")
    with open("test.txt", "r") as f:  # the file is closed automatically
        for line in f:
            text = str(line)  # the context must be a string
            # tokenize and encode the question together with the context line
            inputs = tokenizer.encode_plus(question,
                                           text,
                                           add_special_tokens=True,
                                           max_length=max_seq_length,
                                           truncation=True,
                                           return_tensors="pt")
            input_ids = inputs["input_ids"].tolist()[0]

            # ** unpacks the inputs dict into keyword arguments for the model
            answer_start_scores, answer_end_scores = model(**inputs, return_dict=False)

            # index of the token with the highest start score
            answer_start = torch.argmax(answer_start_scores)
            # index just past the token with the highest end score
            answer_end = torch.argmax(answer_end_scores) + 1
            answer = tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

            # skip lines for which the model returns only [CLS] or an empty answer
            if answer == '[CLS]' or answer == '':
                continue
            print(f"Answer: {answer}")
            print(f"Answer start: {answer_start}")
            print(f"Answer end: {answer_end}")
            break  # stop at the first line that yields an answer
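
To make clear what answer_start and answer_end hold, here is a minimal sketch of the span selection on plain Python lists, without the model. The helper names softmax and best_span and the toy scores are made up for illustration; only the argmax logic mirrors the loop above. Turning the raw scores into probabilities with a softmax is an assumption on my part about what would be convenient to plot, not something the code above does.

```python
import math


def softmax(scores):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    largest = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def best_span(start_scores, end_scores):
    """Return (answer_start, answer_end) as in the loop above.

    answer_start is the argmax of the start scores; answer_end is the
    argmax of the end scores plus one, so it can be used as a slice end.
    """
    start = max(range(len(start_scores)), key=start_scores.__getitem__)
    end = max(range(len(end_scores)), key=end_scores.__getitem__) + 1
    return start, end


# Toy values, invented for illustration only (one score per token).
toy_start_scores = [0.1, 2.0, 0.3, -1.0]
toy_end_scores = [0.0, 0.5, 3.0, -0.5]

answer_start, answer_end = best_span(toy_start_scores, toy_end_scores)
start_probs = softmax(toy_start_scores)
print(answer_start, answer_end)
```

Both indices refer to positions in input_ids, so they count the question tokens and the special tokens as well as the context tokens.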

And the output:

> Question: Welke soorten gladiatoren waren er?
> Answer: de thraex, de retiarius en de murmillo
> Answer start: 24
> Answer end: 37
> Question: Wat is een provocator?
> Answer: telemachus
> Answer start: 87
> Answer end: 90

