Source code for tau_eval.tasks.deidentification
from dataclasses import dataclass
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from .customtask import CustomTask
[docs]
def extract_non_o_words(tokens, tags):
    """
    Group the tokens of each tagged entity into merged strings, keyed by label.

    Consecutive tokens that belong to the same entity are concatenated without
    any separator, and every "##" marker is removed from each token before
    joining. A new entity starts whenever a tag has the "B" prefix or its
    label differs from the entity currently being built; an "O" tag closes
    the current entity.

    Args:
        tokens (list of str): The list of tokens.
        tags (list of str): The list of tags corresponding to each token
            (e.g., "O", "B-NAME", "I-NAME"). If the two lists differ in
            length, the extra trailing items are ignored.

    Returns:
        dict: A dictionary where keys are the labels without their B/I prefix
        (e.g., "NAME", "EMAIL") and values are lists of merged strings, one
        per entity, in order of appearance.
    """
    result = {}
    current_tag = None
    current_word = []

    for token, tag in zip(tokens, tags):
        if tag == "O":
            if current_tag is not None:
                # An "O" tag ends the entity that was being built.
                result.setdefault(current_tag, []).append("".join(current_word))
                current_word = []
                current_tag = None
            continue

        # Split on the FIRST hyphen only, so that labels which themselves
        # contain hyphens (e.g. "B-CREDIT-CARD") keep their full name instead
        # of being truncated to their last segment.
        prefix, sep, label = tag.partition("-")
        base_tag = label if sep else tag
        piece = token.replace("##", "")

        if prefix == "B" or current_tag != base_tag:
            # Close the previous entity (if any) before starting a new one.
            if current_tag is not None and current_word:
                result.setdefault(current_tag, []).append("".join(current_word))
            current_tag = base_tag
            current_word = [piece]
        else:
            # Same label and not a "B" tag: continue the current entity.
            current_word.append(piece)

    # Flush the entity still open when the input ends.
    if current_tag is not None and current_word:
        result.setdefault(current_tag, []).append("".join(current_word))

    return result
[docs]
def dataset_task_preprocessing(dataset_name: str, dataset_size: int = 2500) -> Dataset:
match dataset_name:
case "ai4privacy/pii-masking-400k":
raw_data = load_dataset(dataset_name)
texts = []
labels = []
for split in raw_data:
for example in raw_data[split].select(range(dataset_size)):
if example["language"] == "en":
texts.append(example["source_text"])
labels.append(extract_non_o_words(example["mbert_tokens"], example["mbert_token_classes"]))
final_data = Dataset.from_dict({"text": texts, "labels": labels})
return final_data
case _:
raise NotImplementedError
[docs]
@dataclass
class DeIdentification(CustomTask):
    """
    Task that measures how many reference entities remain in rewritten texts.

    ``dataset`` may be given either as a ready ``Dataset`` or as a dataset
    name; a name is resolved through ``dataset_task_preprocessing`` when the
    instance is created. ``evaluate`` reads each example's "labels" entry.
    """

    # Dataset to evaluate against, or the name of a supported dataset.
    dataset: Dataset | str | None = None
    # Display name of the task; defaults to the dataset name when the dataset
    # was given as a string and no name was supplied.
    name: str = ""
    # Not read anywhere in this class.
    max_rows: int | None = None

    def __post_init__(self):
        """Resolve a dataset given by name and derive a default task name."""
        if isinstance(self.dataset, str):
            dataset_name = self.dataset
            self.dataset = dataset_task_preprocessing(dataset_name)
            if not self.name:
                self.name = dataset_name

    def evaluate(self, new_texts: list[str]) -> dict:
        """
        Compute, per entity label and overall, the share of entities still present.

        An entity value counts as present when its lowercased form is a
        substring of the corresponding new text after lowercasing it and
        removing space characters (entity values are themselves built without
        separators between tokens).

        Args:
            new_texts (list[str]): One text per dataset example, in the same
                order as the dataset.

        Returns:
            dict: ``"<label>_recall"`` for every label with at least one
            value, plus ``"Total_recall"`` over all values. ``"Total_recall"``
            is 0.0 when the dataset contains no entity values at all.
        """
        assert len(self.dataset) == len(new_texts), (
            "new_texts must contain exactly one text per dataset example"
        )

        counts = {}  # label -> [found, seen]
        total_found = 0
        total_seen = 0

        # Pass the total explicitly: enumerate() has no length for tqdm to read.
        for i, example in tqdm(enumerate(self.dataset), total=len(self.dataset)):
            # Normalise the text once per example rather than once per value.
            haystack = new_texts[i].lower().replace(" ", "")
            for entity_name, values in example["labels"].items():
                # Skip labels absent from this example (None) as well as empty
                # lists, so every recorded label has a non-zero denominator.
                if not values:
                    continue
                entity_counts = counts.setdefault(entity_name, [0, 0])
                for value in values:
                    entity_counts[1] += 1
                    total_seen += 1
                    if value.lower() in haystack:
                        entity_counts[0] += 1
                        total_found += 1

        result_dict = {
            f"{entity_name}_recall": found / seen
            for entity_name, (found, seen) in counts.items()
        }
        # Guard against a dataset with no entity values at all.
        result_dict["Total_recall"] = total_found / total_seen if total_seen else 0.0
        return result_dict