I have been working on a question-answering model in which a BERT model answers my questions. What I would really like to do is plot something like this:
The problem is that I don't know how. I am stuck on how to represent a part of the context in a plot. I have two variables, answer_start and answer_end, which indicate which part of the context the model took its answer from. Can someone please help me out and tell me which variables I need to pass to pyplot?
My code is below:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
import numpy as np
import pandas as pd

max_seq_length = 512
tokenizer = AutoTokenizer.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")

questions = [
    "Welke soorten gladiatoren waren er?",
    "Wat is een provocator?"
]

for question in questions:  # for each question, iterate over all lines of the file
    print(f"Question: {question}")
    f = open("test.txt", "r")
    for line in f:
        text = str(line)  # the text passed to the tokenizer must be a string
        # tokenize and encode the question together with this line
        inputs = tokenizer.encode_plus(question,
                                       text,
                                       add_special_tokens=True,
                                       max_length=max_seq_length,
                                       truncation=True,
                                       return_tensors="pt")
        input_ids = inputs["input_ids"].tolist()[0]
        # ** unpacks the dict into keyword arguments for the model call
        answer_start_scores, answer_end_scores = model(**inputs, return_dict=False)
        # index of the token with the highest start score
        answer_start = torch.argmax(answer_start_scores)
        # same for the end token; +1 so the slice below includes it
        answer_end = torch.argmax(answer_end_scores) + 1
        answer = tokenizer.convert_tokens_to_string(
            tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
        # skip lines for which the model returns [CLS] or an empty answer
        if answer == '[CLS]':
            continue
        elif answer == '':
            continue
        else:
            print(f"Answer: {answer}")
            print(f"Answer start: {answer_start}")
            print(f"Answer end: {answer_end}")
            f.seek(0)
            break
    # f.seek(0)
    # break
    f.close()
And this is the output:
> Question: Welke soorten gladiatoren waren er?
> Answer: de thraex, de retiarius en de murmillo
> Answer start: 24
> Answer end: 37
> Question: Wat is een provocator?
> Answer: telemachus
> Answer start: 87
> Answer end: 90
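As a rough sketch of which variables a plot like this might need: assuming the goal is a per-token bar chart of the start and end scores with the answer span shaded, the inputs would be the token strings (x-axis labels), the two score vectors (bar heights), and answer_start/answer_end (the shaded region). The helper names softmax, span_plot_data and plot_answer_span, the output path, and the tokens and logits below are all made up for illustration; they are not part of the code above or of any library.

```python
import math


def softmax(scores):
    """Turn raw logits into probabilities that sum to 1."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def span_plot_data(tokens, start_scores, end_scores):
    """Collect what a per-token plot needs: x positions, labels, bar heights,
    and the answer span (end is exclusive, like answer_end in the question)."""
    start_probs = softmax(start_scores)
    end_probs = softmax(end_scores)
    return {
        "x": list(range(len(tokens))),
        "labels": list(tokens),
        "start_probs": start_probs,
        "end_probs": end_probs,
        "answer_start": start_probs.index(max(start_probs)),
        "answer_end": end_probs.index(max(end_probs)) + 1,
    }


def plot_answer_span(data, path="answer_span.png"):
    """Draw start/end probabilities per token and shade the answer span."""
    import matplotlib
    matplotlib.use("Agg")  # render to a file; no display needed
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(max(6, 0.4 * len(data["x"])), 3))
    width = 0.4
    ax.bar([i - width / 2 for i in data["x"]], data["start_probs"], width, label="start")
    ax.bar([i + width / 2 for i in data["x"]], data["end_probs"], width, label="end")
    # Shade the tokens answer_start .. answer_end - 1.
    ax.axvspan(data["answer_start"] - 0.5, data["answer_end"] - 0.5, alpha=0.2, label="answer")
    ax.set_xticks(data["x"])
    ax.set_xticklabels(data["labels"], rotation=90)
    ax.set_ylabel("probability")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


# Hypothetical tokens and logits, only to show the shapes involved.
tokens = ["[CLS]", "wat", "?", "[SEP]", "de", "thraex", "en", "de", "murmillo", "[SEP]"]
start_scores = [0.1, 0.0, 0.0, 0.0, 4.0, 2.0, 0.5, 1.0, 0.5, 0.0]
end_scores = [0.1, 0.0, 0.0, 0.0, 0.2, 1.5, 0.3, 0.4, 5.0, 0.0]
data = span_plot_data(tokens, start_scores, end_scores)
print(data["answer_start"], data["answer_end"])  # 4 9
```

With the real model, the inputs would presumably come from tokenizer.convert_ids_to_tokens(input_ids) for the tokens and answer_start_scores[0].detach().tolist() and answer_end_scores[0].detach().tolist() for the scores; calling plot_answer_span(data) then writes the figure to a file.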
from What do you need for plotting the outcome of a question-answering model