# Source code for olmo_tap.experiments.utils.config

"""
Config classes to support training and inference.
"""

from dataclasses import dataclass, field
from olmo_tap.constants import LORA_TARGETS, VOCAB_SIZE, WEIGHTS_DIR


@dataclass
class HydraLoRAConfig:
    """
    Supports loading Hydra model for inference or training.

    NOTE: n_heads_final is for book-keeping the number of heads the final
    Hydra model is intended to have; n_heads_training is the actual number
    loaded at training time.
    """

    # architecture
    weights_dir: str = WEIGHTS_DIR
    model_size: str = "7b"  # "1b" or "7b"
    n_heads_final: int = 5
    n_heads_training: int = 1  # number of heads instantiated in training
    heads_depth: int = 3
    vocab_size: int = VOCAB_SIZE
    # LoRA hyperparameters
    lora_r: int = 16
    lora_alpha: int = 32
    # Copy LORA_TARGETS so each config instance owns its own list:
    # `lambda: LORA_TARGETS` would hand every instance the SAME shared
    # object, letting one config's mutation leak into the module constant
    # and all other configs.
    target_modules: list[str] = field(default_factory=lambda: list(LORA_TARGETS))
    device: str = "cuda"
@dataclass
class TrainingConfig:
    """
    Config to store training specific parameters.
    """

    # optimizer hyperparams
    learning_rate: float = 1e-4
    batch_size: int = 16
    num_epochs: int = 1  # GPU poor :(
    # max generated sequence length
    max_seq_len: int = 256
    # DataLoader workers for CPU-side preprocessing in parallel with GPU
    num_workers: int = 4
    # which head finetunes on which shard
    shard_id: int = 0
    # populated externally (ExperimentConfig.__post_init__), not via __init__
    num_shards: int = field(init=False)
    # required for tokenizer
    weights_dir: str = WEIGHTS_DIR
    # LR schedule
    warmup_steps: int = 100
    lr_schedule: str = "cosine"  # "cosine" or "linear"
    # checkpointing
    output_dir: str = "experiments/uncertainty/outputs"
    checkpoint_every_n_steps: int = 250
    # seed (propagated from ExperimentConfig)
    seed: int = field(init=False)
    # token IDs, assigned after construction
    # convention: A/B used for correct/incorrect in uncertainty
    A_token_id: int = field(init=False)
    B_token_id: int = field(init=False)
    C_token_id: int = field(init=False)
    D_token_id: int = field(init=False)
@dataclass
class ExperimentConfig:
    """
    Master config to store the HydraLoRAConfig and TrainingConfig in training.
    """

    # random seed for experiment tracking
    # NOTE: no default value to avoid disagreements
    seed: int
    model: HydraLoRAConfig = field(default_factory=HydraLoRAConfig)
    train: TrainingConfig = field(default_factory=TrainingConfig)
    # W&B
    wandb_project: str = "hydra"
    wandb_run_name: str | None = None
    device: str = "cuda"

    def __post_init__(self):
        # Propagate derived values into the training sub-config:
        # one data shard per head (num_shards = n_heads), shared seed.
        self.train.num_shards = self.model.n_heads_final
        self.train.seed = self.seed