"""
Data loading for security head supervised finetuning on MedMCQA.
"""
from torch.utils.data import DataLoader
from datasets import load_dataset
from datasets.arrow_dataset import Dataset
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from olmo_tap.experiments.utils.config import TrainingConfig
def preprocess_example(
    example: dict[str, str],
    tokenizer: PreTrainedTokenizerBase,
    max_seq_len: int,
    token_ids: list[int],
) -> dict:
    """Build one tokenized training example from a raw MedMCQA row.

    Renders the question plus its four options through the tokenizer's chat
    template, tokenizes to a fixed length, and maps the correct-option index
    ("cop") to the corresponding answer-letter token ID.

    Args:
        example: Raw MedMCQA row with "question", "opa".."opd", and "cop" keys.
        tokenizer: Tokenizer providing the chat template and encoding.
        max_seq_len: Fixed length to pad/truncate the prompt to.
        token_ids: Token IDs for the answer letters, indexed 0..3 to match "cop".

    Returns:
        Dict with 1-D "input_ids" and "attention_mask" tensors and an int "label".
    """
    options = [example[key] for key in ("opa", "opb", "opc", "opd")]
    chat = [{"role": "user", "content": format_question(example["question"], options)}]
    # add_generation_prompt=True appends the assistant turn header so the model
    # predicts the answer letter as its next token.
    rendered = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    enc = tokenizer(
        rendered,
        padding="max_length",
        truncation=True,
        max_length=max_seq_len,
        return_tensors="pt",
    )
    correct_option = int(example["cop"])
    # squeeze(0) drops the batch dim the "pt" tensors carry; the DataLoader
    # re-batches downstream.
    return {
        "input_ids": enc["input_ids"].squeeze(0),
        "attention_mask": enc["attention_mask"].squeeze(0),
        "label": token_ids[correct_option],
    }
def load_shard(
    config: TrainingConfig,
) -> tuple[DataLoader, int, int, int, int]:
    """Load one MedMCQA shard, tokenize prompts, and build a train DataLoader.

    Args:
        config: Training config supplying ``weights_dir`` (tokenizer source),
            ``num_shards``/``shard_id``, ``max_seq_len``, ``batch_size``,
            and ``num_workers``.

    Returns:
        ``(train_dataloader, A_id, B_id, C_id, D_id)`` where the four IDs are
        the single-token encodings of the answer letters "A".."D".

    Raises:
        RuntimeError: If the tokenizer fails to load.
        TypeError: If the HF dataset is not a map-style ``Dataset``.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.weights_dir)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if tokenizer is None:
        raise RuntimeError(f"Failed to load tokenizer from {config.weights_dir}")
    # Single-token IDs for the answer letters; index 0..3 matches MedMCQA "cop".
    token_ids = [
        tokenizer.encode(letter, add_special_tokens=False)[0] for letter in "ABCD"
    ]
    A_id, B_id, C_id, D_id = token_ids
    base_ds = load_dataset("openlifescienceai/medmcqa", split="train", streaming=False)
    if not isinstance(base_ds, Dataset):
        raise TypeError(f"Expected Dataset, got {type(base_ds)}")
    shard_ds = base_ds.shard(num_shards=config.num_shards, index=config.shard_id)
    # Same column list feeds both the projection and map's remove_columns,
    # so the two can't drift apart.
    raw_columns = ["question", "opa", "opb", "opc", "opd", "cop"]
    shard_ds = shard_ds.select_columns(raw_columns)
    shard_ds = shard_ds.map(
        preprocess_example,
        fn_kwargs={
            "tokenizer": tokenizer,
            "max_seq_len": config.max_seq_len,
            "token_ids": token_ids,
        },
        remove_columns=raw_columns,
        # Stale caches from before the attention_mask addition have the same
        # fingerprint on some HF datasets versions; force reprocess.
        load_from_cache_file=False,
    )
    shard_ds.set_format("torch")
    train_dataloader = DataLoader(
        shard_ds,  # type: ignore[arg-type]
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=config.num_workers,
    )
    return train_dataloader, A_id, B_id, C_id, D_id