Source code for tau_eval.metrics.nli

import torch

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)


[docs] def load_nli( model_name: str = "alisawuffles/roberta-large-wanli", device: str = "cuda" ) -> tuple[PreTrainedTokenizer, PreTrainedModel]: """ Loads the NLI (Natural Language Inference) model and tokenizer. Args: model_name: HuggingFace model to load device: The device to load the model onto ("cuda" or "cpu"). Returns: A tuple containing the loaded tokenizer and model. """ # Download necessary models for evaluation nli_tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_name, map_location=device) nli_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(model_name).to(device) return (nli_tokenizer, nli_model)
[docs] def compute_nli( input_texts: str | list[str], output_texts: str | list[str], nli_tokenizer: PreTrainedTokenizer, nli_model: PreTrainedModel, batch_size: int = 16, device: str = "cuda", max_length: int = 128, ) -> dict[str, list[float]]: """ Computes the probability of entailment between two texts using the NLI model. Args: input_text: The premise text. output_text: The hypothesis text. nli_tokenizer: The tokenizer for the NLI model. nli_model: The NLI model. Returns: A dictionary containing the probability of entailment. The dictionary has the key "entailment" with a float value. """ if not isinstance(input_texts, list): input_texts = [input_texts] # type: ignore if not isinstance(output_texts, list): output_texts = [output_texts] # type: ignore assert len(input_texts) == len(output_texts), "inputs are different lengths" entailment_prob = [] for start_index in range(0, len(output_texts), batch_size): end_index = min(start_index + batch_size, len(output_texts)) input_batch = input_texts[start_index:end_index] output_batch = output_texts[start_index:end_index] x = nli_tokenizer( input_batch, output_batch, return_tensors="pt", max_length=max_length, padding=True, truncation=True, ).to(device) with torch.no_grad(): logits = nli_model.to(device)(**x).logits entailment_prob += logits.softmax(dim=1)[:, 1].tolist() return {"entailment": entailment_prob}