Lab 3.2: LoRA Fine-Tuning Guide

Stat 214 · Spring 2026 · Conceptual overview with implementation details

What 3.2 asks you to do

Repeat the 3.1 analysis with two additional embedding methods: (1) pretrained BERT, and (2) BERT fine-tuned with LoRA. Once you have per-word embeddings, the rest of the pipeline is identical to 3.1: downsample → delays → ridge → evaluate CC.

Part 1: Extracting BERT embeddings

BERT takes text and produces a vector per subword token, not per word. The word "playing" might be tokenized into ["play", "##ing"], each getting its own 768-dimensional vector. You need to aggregate subword vectors back to one vector per original word.

Pseudocode: BERT embedding extraction
function get_word_embeddings(words, bert_encoder):
    # words = ["the", "quick", "brown", "fox"]
    
    # 1. Tokenize (is_split_into_words tells the tokenizer the input is already split into words)
    subtokens = bert_tokenizer(words, is_split_into_words=True)
    # e.g., ["the", "qu", "##ick", "brown", "fox"]
    
    # 2. Get the word-to-subtoken mapping
    word_ids = subtokens.word_ids()
    # e.g., [0, 1, 1, 2, 3]; subtokens 1 and 2 both map to word 1

    # 3. Forward pass through BERT encoder
    hidden_states = bert_encoder(subtokens)  # (num_subtokens, 768)
    
    # 4. Aggregate: mean-pool subtokens sharing the same word_id
    for each unique word_id w:
        word_embeddings[w] = mean(hidden_states[k] for k where word_ids[k] == w)
    
    return word_embeddings  # shape: (len(words), 768)
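
The aggregation in step 4 can be sketched without any model at all. The snippet below is a minimal pure-Python illustration, not the lab's reference implementation: the function name `pool_subtokens_to_words` and the toy 2-dimensional vectors are made up for this example. It assumes `word_ids` has the form returned by a fast tokenizer's `word_ids()`, where special tokens such as [CLS] and [SEP] are marked with `None`.

```python
def pool_subtokens_to_words(hidden_states, word_ids):
    """Mean-pool subtoken vectors that share a word id.

    hidden_states: one vector (list of floats) per subtoken.
    word_ids: one entry per subtoken; None marks special tokens
              such as [CLS] and [SEP], which are skipped.
    Returns one vector per word, ordered by word id.
    """
    sums, counts = {}, {}
    for vec, wid in zip(hidden_states, word_ids):
        if wid is None:
            continue
        if wid not in sums:
            sums[wid] = [0.0] * len(vec)
            counts[wid] = 0
        sums[wid] = [s + v for s, v in zip(sums[wid], vec)]
        counts[wid] += 1
    return [[s / counts[wid] for s in sums[wid]] for wid in sorted(sums)]


# Toy example with 2-dimensional "embeddings" instead of 768.
# Subtokens: [CLS] the qu ##ick brown fox [SEP]
word_ids = [None, 0, 1, 1, 2, 3, None]
hidden_states = [
    [9.0, 9.0],   # [CLS], ignored
    [1.0, 0.0],   # the
    [2.0, 2.0],   # qu
    [4.0, 0.0],   # ##ick
    [0.0, 5.0],   # brown
    [3.0, 3.0],   # fox
    [9.0, 9.0],   # [SEP], ignored
]
word_vectors = pool_subtokens_to_words(hidden_states, word_ids)
print(len(word_vectors))  # one vector per word
print(word_vectors[1])    # mean of the "qu" and "##ick" rows
```

With real BERT outputs the same logic would normally be written with tensor operations; the list version is only meant to make the bookkeeping visible.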

Handling long stories (512-token limit)

Most stories have ~800 words → ~1000+ subtokens, exceeding BERT's 512-token limit. Process in overlapping windows: slide a 512-token window across the subtoken sequence with some stride (e.g., 256). For words appearing in multiple windows, average their embeddings across windows.
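
A rough sketch of the window bookkeeping, under stated assumptions: the helper names `window_starts` and `average_over_windows` are hypothetical, the stride is assumed to be no larger than the window length, and averaging is done per subtoken position before pooling to words (averaging per word after pooling is an equally reasonable order). Because BERT's 512 limit includes [CLS] and [SEP], in practice you would likely pass `max_len=510` for the story tokens themselves.

```python
def window_starts(n_tokens, max_len=512, stride=256):
    """Start offsets of windows that together cover all n_tokens."""
    if n_tokens <= max_len:
        return [0]
    starts = list(range(0, n_tokens - max_len + 1, stride))
    # Add a final window flush with the end so the tail is not dropped.
    if starts[-1] + max_len < n_tokens:
        starts.append(n_tokens - max_len)
    return starts


def average_over_windows(n_tokens, window_vectors, starts):
    """Average each token's vectors over every window that contains it.

    window_vectors[i][j] is the vector for token starts[i] + j.
    """
    dim = len(window_vectors[0][0])
    sums = [[0.0] * dim for _ in range(n_tokens)]
    counts = [0] * n_tokens
    for start, vectors in zip(starts, window_vectors):
        for offset, vec in enumerate(vectors):
            pos = start + offset
            sums[pos] = [s + v for s, v in zip(sums[pos], vec)]
            counts[pos] += 1
    return [[s / c for s in row] for row, c in zip(sums, counts)]


starts = window_starts(1000)          # e.g. a story with 1000 subtokens
print(starts)

# Tiny check with 1-dimensional vectors: 5 tokens, windows of 3, stride 2.
toy_starts = window_starts(5, max_len=3, stride=2)
toy_windows = [[[float(w)] for _ in range(3)] for w in range(len(toy_starts))]
print(average_over_windows(5, toy_windows, toy_starts))
```

In the toy check, the middle token sits in both windows, so its value is the average of the two window values, while the other tokens keep the value of the single window that contains them.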

Part 2: What LoRA does

BERT has ~110M parameters. Full fine-tuning updates all of them, risking overfitting on our ~24K words of story data. LoRA freezes all original weights and adds small trainable matrices to specific layers.

The key idea. Each attention layer has weight matrices $W \in \mathbb{R}^{d_1 \times d_2}$ (e.g., $d_1 = d_2 = 768$ for the query projection). LoRA adds a low-rank update:

$$W_{\text{new}} = W_{\text{frozen}} + \frac{\gamma}{r} AB^\top$$

where $A \in \mathbb{R}^{d_1 \times r}$, $B \in \mathbb{R}^{d_2 \times r}$, $r$ is the rank (e.g., 4 or 8), and $\gamma$ is a scaling factor. Only $A$ and $B$ are trained; $W_{\text{frozen}}$ stays frozen.
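
Applying the update to a single input makes the low-rank structure explicit. This is a short worked step in the notation above; the input vector $x \in \mathbb{R}^{d_2}$ is introduced here purely for illustration:

$$W_{\text{new}}\,x = W_{\text{frozen}}\,x + \frac{\gamma}{r}\,A\,(B^\top x), \qquad B^\top x \in \mathbb{R}^{r}, \qquad \operatorname{rank}(AB^\top) \le r$$

The correction term first projects $x$ down to $r$ coordinates and then maps the result back to $d_1$ dimensions, so it can only shift the output within a subspace of dimension at most $r$.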

Intuition: a full $768 \times 768$ weight matrix has ~590K parameters. A rank-8 factorization of the update has $768 \times 8 \times 2 = 12{,}288$ parameters per adapted matrix (a ~98% reduction). The assumption is that the task-specific adaptation lives in a low-dimensional subspace. Think PCA: a few principal directions capture most of the needed change.
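
The arithmetic can be checked in a few lines. The helper name `lora_param_count` is made up for this sketch; it simply counts the entries of $A$ ($d_1 \times r$) and $B$ ($d_2 \times r$).

```python
def lora_param_count(d1, d2, rank):
    """Trainable parameters in one LoRA pair: A is (d1 x r), B is (d2 x r)."""
    return rank * (d1 + d2)


d = 768
full = d * d                              # one full weight matrix
lora = lora_param_count(d, d, rank=8)     # its rank-8 adapter
reduction = 1 - lora / full
print(full, lora, f"{reduction:.1%}")
```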

Part 3: Fine-tuning with LoRA

The fine-tuning objective is Masked Language Modeling (MLM): the same task BERT was originally pretrained on, but applied to our story transcripts. This adapts BERT's contextual representations to the narrative domain.

Pseudocode: LoRA fine-tuning
# ---- Setup ----
# Load BERT with its MLM prediction head (needed for the MLM loss)
mlm_model = load_bert_for_masked_lm("bert-base-uncased")

# Apply LoRA: freeze original weights, add trainable A,B to
# the query and value projections in each attention layer
apply_lora(mlm_model, rank=8, target=["query", "value"])

# ---- Prepare data ----
stories = load_story_transcripts()  # from raw_text.pkl

# Tokenize into chunks ≤ 512 tokens with overlap
chunks = sliding_window_tokenize(stories, max_length=512, stride=256)

# ---- Train ----
for epoch in range(num_epochs):
    for chunk in shuffle(chunks):
        # Mask 15% of tokens:
        #   80% → [MASK], 10% → random token, 10% → unchanged
        masked_chunk, labels = apply_mlm_masking(chunk, prob=0.15)
        
        # Forward: predict masked tokens from context
        loss = mlm_model(masked_chunk, labels)
        
        # Backward: only A, B matrices receive gradients
        loss.backward()
        optimizer.step()

# Save adapter (~1-5 MB, just the A and B matrices)
save_lora_adapter("lora_adapter")
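
To make the masking step in the loop concrete, here is a minimal standard-library sketch of the 15% selection and the 80/10/10 replacement rule. It is an illustration under assumptions, not what the HuggingFace collator literally executes: the function signature is hypothetical, and the toy token ids are arbitrary apart from 101, 102, and 103, which are the [CLS], [SEP], and [MASK] ids in bert-base-uncased.

```python
import random

IGNORE_INDEX = -100  # label value that the cross-entropy loss skips by default


def apply_mlm_masking(token_ids, mask_id, vocab_size, special_ids,
                      prob=0.15, rng=None):
    """Return (masked_ids, labels) following the 80/10/10 rule."""
    rng = rng or random.Random()
    masked, labels = list(token_ids), [IGNORE_INDEX] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if tok in special_ids or rng.random() >= prob:
            continue                      # not selected: no loss at this position
        labels[i] = tok                   # the model must recover the original token
        roll = rng.random()
        if roll < 0.8:
            masked[i] = mask_id           # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: leave the token unchanged
    return masked, labels


# Toy ids; 101/102/103 are [CLS]/[SEP]/[MASK] in bert-base-uncased.
ids = [101] + list(range(2000, 2040)) + [102]
masked_ids, labels = apply_mlm_masking(
    ids, mask_id=103, vocab_size=30522, special_ids={101, 102},
    rng=random.Random(0))
n_selected = sum(label != IGNORE_INDEX for label in labels)
print(n_selected, "of", len(ids) - 2, "tokens selected for prediction")
```

Only the selected positions carry a label, so the loss is computed on those tokens alone; every other position is set to the ignore value.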

Implementation details (HuggingFace + PEFT code)
import pickle, torch, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import (BertTokenizerFast, BertForMaskedLM,
                          DataCollatorForLanguageModeling)
from peft import LoraConfig, get_peft_model

# --- Data ---
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
with open('raw_text.pkl', 'rb') as f:
    raw_text = pickle.load(f)

story_texts = []
for name, seq in raw_text.items():
    words = [str(w) for w in seq.data if isinstance(w, str)]
    story_texts.append(" ".join(words))

class StoryMLMDataset(Dataset):
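    def __init__(self, texts, tokenizer, max_length=512, stride=256):
        self.examples = []
        body_len = max_length - 2  # leave room for [CLS] and [SEP]
        for text in texts:
            # Tokenize without special tokens so they can be added per window
            ids = tokenizer(text, truncation=False,
                            add_special_tokens=False)["input_ids"]
            for start in range(0, len(ids), stride):
                body = ids[start:start + body_len]
                if len(body) >= 32:  # skip very short trailing windows
                    # Wrap every window in [CLS] ... [SEP], as BERT expects
                    chunk = ([tokenizer.cls_token_id] + body
                             + [tokenizer.sep_token_id])
                    self.examples.append(torch.tensor(chunk))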
    def __len__(self): return len(self.examples)
    def __getitem__(self, idx): return {"input_ids": self.examples[idx]}

dataset = StoryMLMDataset(story_texts, tokenizer)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                            mlm=True, mlm_probability=0.15)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=collator)

# --- Model + LoRA ---
mlm_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
lora_config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.1,
    target_modules=["query", "value"]
)
lora_model = get_peft_model(mlm_model, lora_config)
lora_model.print_trainable_parameters()

# --- Train ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lora_model.to(device)
lora_model.train()
optimizer = optim.AdamW(lora_model.parameters(), lr=2e-4)

for epoch in range(3):
    total_loss, n = 0, 0
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = lora_model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
        n += 1
    print(f"Epoch {epoch+1}, Loss: {total_loss/n:.4f}")
    # Expected: starts ~2.5-3.5, drops to ~2.0-2.5. Below 1.5 = overfitting.

lora_model.save_pretrained("lora_adapter_r8")

Part 4: Extracting embeddings from the fine-tuned model

Pseudocode
# Load base BERT encoder (NOT the MLM head)
base_encoder = load_bert_model("bert-base-uncased")

# Load LoRA adapter on top of the encoder
finetuned_encoder = load_lora_adapter(base_encoder, "lora_adapter")

# Same extraction function as Part 1
embeddings = get_word_embeddings(words, finetuned_encoder)
# → (W_s, 768)

# Then: downsample → delays → ridge → CC (identical to 3.1)

BertModel vs BertForMaskedLM. These are different classes. BertModel outputs last_hidden_state (shape: seq_len × 768): these are the embeddings you want. BertForMaskedLM = BertModel + an MLM prediction head (dense → layer norm → projection to vocab size 30,522). It outputs logits (shape: seq_len × 30,522), which are token predictions, not embeddings. You train with BertForMaskedLM (because you need the MLM loss), but extract embeddings from BertModel (because you want the hidden states). The LoRA adapter modifies the attention layers, which are shared between both. If you accidentally extract from the MLM model's logits, you'll get vocabulary predictions instead of embeddings, and the shapes won't match. You can also access the base encoder inside BertForMaskedLM via model.bert, but loading the adapter onto a fresh BertModel is cleaner. Either way, confirm that the adapter weights actually loaded, for example by checking that the fine-tuned embeddings differ from the pretrained ones; an adapter saved from BertForMaskedLM may store its parameter names under a different prefix than a bare BertModel expects.

Hyperparameters to explore

| Parameter | What it controls | Start with | Explore |
|---|---|---|---|
| Rank ($r$) | Capacity of the adaptation | 8 | 4, 8, 16 |
| Scaling ($\gamma$, lora_alpha) | Update magnitude. Effective rate scales as $\gamma / r$ | $2r$ | $r$ to $2r$ |
| Target modules | Which attention weights get adapters | query + value | Try adding key |
| Learning rate | Optimizer step size | $2 \times 10^{-4}$ | $10^{-4}$ to $5 \times 10^{-4}$ |
| Epochs | Passes over the data | 3 | 3–5 |
| Window stride | Overlap between training chunks | 256 | 128–384 |

Compare at least 2-3 rank values and report how downstream CC changes. Does higher rank help, or does it overfit?
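
When comparing ranks, it can help to know how much capacity each setting adds. The sketch below is a back-of-the-envelope count and assumes bert-base dimensions (hidden size 768, 12 encoder layers), float32 storage, and adapters on the listed attention projections only; `adapter_size` is a hypothetical helper, not a PEFT function.

```python
HIDDEN = 768         # bert-base hidden size
N_LAYERS = 12        # bert-base encoder layers
BYTES_PER_PARAM = 4  # float32


def adapter_size(rank, n_targets=2):
    """Total trainable LoRA parameters; n_targets=2 means query + value."""
    per_matrix = rank * (HIDDEN + HIDDEN)
    return N_LAYERS * n_targets * per_matrix


for rank in (4, 8, 16):
    n_params = adapter_size(rank)
    size_mb = n_params * BYTES_PER_PARAM / 1e6
    print(f"r={rank:>2}: {n_params:>7,} trainable params, ~{size_mb:.2f} MB")
```

The count for your chosen rank should agree with what `print_trainable_parameters()` reports; if it does not, the adapters are probably attached to different modules than you intended.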

Questions to address in the report

Common pitfalls

Subword alignment. BERT tokenizes "playing" into ["play", "##ing"]. You must aggregate these back to one embedding per original word. The word_ids() method gives the mapping. If your embedding count doesn't match your word count, the alignment with the fMRI data will be wrong and everything downstream breaks silently.
512-token limit. If you truncate to 512 and ignore the rest, you lose the end of every story. Use overlapping windows.
GPU memory. If you get CUDA OOM errors: reduce batch size, use gradient accumulation, or reduce max_length. Submit via sbatch with --gres=gpu:1 --mem=32G.

Gradient accumulation (if batch size 8 doesn't fit in memory)
accum_steps = 4  # effective batch = 2 * 4 = 8
dataloader = DataLoader(dataset, batch_size=2, shuffle=True,
                        collate_fn=collator)
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    loss = lora_model(**batch).loss / accum_steps
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Training on test data. If using leave-one-story-out for ridge CV, consider whether the held-out story should also be excluded from LoRA fine-tuning. The simplest approach: fine-tune on all stories (MLM doesn't use fMRI labels, so it's not label leakage), then do LOSO only at the ridge step. Document your choice.
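
For reference, the leave-one-story-out split at the ridge step amounts to the loop below. This is only a sketch of the bookkeeping: the story names are placeholders and `leave_one_story_out` is a hypothetical helper. If you decide to exclude the held-out story from fine-tuning as well, the same split would also select which adapter to load for each fold.

```python
def leave_one_story_out(story_names):
    """Yield (train_stories, held_out_story) pairs, one per story."""
    for held_out in story_names:
        train = [name for name in story_names if name != held_out]
        yield train, held_out


stories = ["story_a", "story_b", "story_c"]   # placeholder names
splits = list(leave_one_story_out(stories))
for train, test in splits:
    print(f"ridge fit on {train}, evaluated on {test}")
```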

Key libraries: transformers, peft, torch. See the PEFT docs and the HuggingFace LoRA tutorial linked in the lab instructions.