Monday 1 May 2023

Poor rouge metric on CNN DailyMail dataset for pretrained T5 model

I am trying to fine-tune a pre-trained T5 model on CNN/DailyMail dataset with the following code:

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from datasets import load_dataset
from transformers import DefaultDataCollator
from transformers import TrainingArguments, Trainer
from transformers import T5Tokenizer, T5ForConditionalGeneration

import os
import evaluate

tokenizer = T5Tokenizer.from_pretrained("t5-small")
rouge = evaluate.load('rouge')

def process_data_to_model_inputs(batch):
    encoder_max_length = 512
    decoder_max_length = 128

    # tokenize the inputs and labels
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["decoder_attention_mask"] = outputs.attention_mask
    batch["labels"] = outputs.input_ids.copy()

    batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

    return batch

def setup_distributed_environment():
    dist.init_process_group(backend='nccl')

    torch.manual_seed(42)

def generate_summary(batch, model): 
  inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt") 
  inputs = inputs.to(model.device) # Ensure that tensors are on the same device as the model 
  summary_ids = model.generate(inputs.input_ids, num_beams=4, max_length=128, early_stopping=True) 
  batch["predicted_highlights"] = tokenizer.batch_decode(summary_ids, skip_special_tokens=True) 
  return batch

def train():
    setup_distributed_environment()
    
    cnndm = load_dataset("cnn_dailymail", "3.0.0")

    tokenized_cnndm = cnndm.map(
        process_data_to_model_inputs, 
        batched=True, 
        remove_columns=cnndm["train"].column_names
    )
    
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    training_args = TrainingArguments(
        output_dir="./updated_squad_fine_tuned_model",
        evaluation_strategy="epoch",
        learning_rate=5.6e-05,
        lr_scheduler_type="linear",
        warmup_ratio=0.1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=2,
        weight_decay=0.01,
        local_rank=local_rank,
        fp16=True,
        remove_unused_columns=False
    )

    data_collator = DefaultDataCollator()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_cnndm["train"].select(range(50000)),
        eval_dataset=tokenized_cnndm["validation"].select(range(10000)),
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()

    if local_rank == 0:
        model.module.save_pretrained("fine_tuned_squad_model")
        tokenizer.save_pretrained("fine_tuned_squad_model")

    results = cnndm["test"].select(range(5000)).map(lambda batch: generate_summary(batch, model.module), batched=True, remove_columns=["article"], batch_size=16)

    # Compute the metric using the generated summaries and the reference summaries
    rouge_score = rouge.compute(predictions=results["predicted_highlights"], references=results["highlights"])

    print(rouge_score)

def main():
    torch.cuda.empty_cache()
    
    train()

if __name__ == '__main__':
    main()

I am not running it on the entire dataset, but taking the first 50k training examples and 10k validation examples. After training, I'm using the first 10k test examples for computing the rouge metric.

I'm using the t5-small variation from huggingface transformers library. I'm also using a distributed setup, running the program in 4 nodes with the following command:

torchrun --nproc_per_node=gpu --nnodes=4 --node_rank=0 --rdzv_id=456 --rdzv_backend=c10d --rdzv_endpoint=129.82.44.119:30375 cnn_hf_test.py

After training, I'm getting the following output:

{'loss': 2.1367, 'learning_rate': 4.258706467661692e-05, 'epoch': 0.64}                                                                                                                                            
{'eval_runtime': 8.023, 'eval_samples_per_second': 1246.419, 'eval_steps_per_second': 19.569, 'epoch': 1.0}                                                                                                        
{'loss': 0.0305, 'learning_rate': 2.2686567164179102e-05, 'epoch': 1.28}                                                                                                                                           
{'loss': 0.0172, 'learning_rate': 2.7860696517412936e-06, 'epoch': 1.92}                                                                                                                                           
{'eval_runtime': 8.0265, 'eval_samples_per_second': 1245.871, 'eval_steps_per_second': 19.56, 'epoch': 2.0}                                                                                                        
{'train_runtime': 5110.103, 'train_samples_per_second': 19.569, 'train_steps_per_second': 0.306, 'train_loss': 0.6989707885800726, 'epoch': 2.0}                                                                   
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1564/1564 [1:25:08<00:00,  3.27s/it]
{'rouge1': 0.008768024824095142, 'rouge2': 0.000294696538416436, 'rougeL': 0.008527464153847374, 'rougeLsum': 0.00875863140146953}                                                                                 
WARNING:torch.distributed.elastic.rendezvous.dynamic_rendezvous:The node 'jupiter.cs.colostate.edu_805773_0' has failed to send a keep-alive heartbeat to the rendezvous '456' due to an error of type RendezvousTimeoutError.

From my understanding, the rouge score metric is very poor and it should at least be greater than 0.2 for ROGUE-1, but I'm getting 0.008.

My cluster setup does not allow me to load a larger model like t5-base or t5-large.

Could you provide me with some suggestions to improve the rouge score metric? Or is this performance expected for this setup and model? Any insight is much appreciated.



from Poor rouge metric on CNN DailyMail dataset for pretrained T5 model

No comments:

Post a Comment