Repeat the 3.1 analysis with two additional embedding methods: (1) pretrained BERT, and (2) BERT fine-tuned with LoRA. Once you have per-word embeddings, the rest of the pipeline is identical to 3.1: downsample → delays → ridge → evaluate CC.
BERT takes text and produces a vector per subword token, not per word. The word "playing" might be tokenized into ["play", "##ing"], each getting its own 768-dimensional vector. You need to aggregate subword vectors back to one vector per original word.
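```python
import torch
from transformers import BertTokenizerFast

bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

@torch.no_grad()
def get_word_embeddings(words, bert_encoder):
    # words = ["the", "quick", "brown", "fox"]; assumes <= 512 subtokens (see the windowing note that follows)
    # 1. Tokenize (is_split_into_words tells the tokenizer these are pre-split words)
    subtokens = bert_tokenizer(words, is_split_into_words=True, return_tensors="pt")
    # e.g., [CLS] the qu ##ick brown fox [SEP]
    # 2. Get the subtoken-to-word mapping ([CLS]/[SEP] map to None)
    word_ids = subtokens.word_ids()
    # e.g., [None, 0, 1, 1, 2, 3, None]: "qu" and "##ick" both map to word 1
    # 3. Forward pass through the BERT encoder (a BertModel, not BertForMaskedLM)
    subtokens = subtokens.to(bert_encoder.device)
    hidden_states = bert_encoder(**subtokens).last_hidden_state[0]  # (num_subtokens, 768)
    # 4. Aggregate: mean-pool subtokens sharing the same word_id
    sums = torch.zeros(len(words), hidden_states.shape[-1], device=hidden_states.device)
    counts = torch.zeros(len(words), 1, device=hidden_states.device)
    for k, w in enumerate(word_ids):
        if w is not None:
            sums[w] += hidden_states[k]
            counts[w] += 1
    # Words with no subtokens (e.g. empty strings) are left as zero vectors
    return (sums / counts.clamp(min=1)).cpu()  # shape: (len(words), 768)
```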
Most stories have ~800 words → ~1000+ subtokens, exceeding BERT's 512-token limit (which includes [CLS] and [SEP], leaving 510 positions for text). Process them in overlapping windows: slide a 512-token window across the subtoken sequence with some stride (e.g., 256). For words that appear in multiple windows, average their embeddings across windows.
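A minimal sketch of that windowing scheme, kept library-agnostic so the bookkeeping is easy to check. It assumes the story was tokenized once without special tokens, and `embed_window` is a hypothetical callable that wraps your encoder: it receives one window of subtoken ids and returns one vector per subtoken. Each word's vector is mean-pooled within a window, and those per-window means are then averaged over all windows containing the word.

```python
import numpy as np

def window_starts(n_tokens, window=510, stride=256):
    """Start offsets of overlapping windows that together cover all n_tokens."""
    starts = [0]
    while starts[-1] + window < n_tokens:
        starts.append(starts[-1] + stride)
    return starts

def windowed_word_embeddings(token_ids, word_ids, n_words, embed_window,
                             window=510, stride=256):
    """Average each word's embedding over every window that contains it.

    token_ids:    subtoken ids for the whole story (no [CLS]/[SEP])
    word_ids:     word index for each subtoken (None entries are skipped)
    embed_window: hypothetical callable, list of ids -> (len(ids), dim) array
    window:       510 text tokens leaves room for [CLS] and [SEP] in 512
    """
    sums, counts = None, np.zeros(n_words)
    for start in window_starts(len(token_ids), window, stride):
        hidden = np.asarray(embed_window(token_ids[start:start + window]))
        if sums is None:
            sums = np.zeros((n_words, hidden.shape[1]))
        # Mean-pool subtokens per word within this window
        win_sum, win_cnt = {}, {}
        for k, w in enumerate(word_ids[start:start + window]):
            if w is None:
                continue
            win_sum[w] = win_sum.get(w, 0) + hidden[k]
            win_cnt[w] = win_cnt.get(w, 0) + 1
        # Add each word's per-window mean to its running total
        for w, s in win_sum.items():
            sums[w] += s / win_cnt[w]
            counts[w] += 1
    return sums / np.maximum(counts, 1)[:, None]
```

With a fast tokenizer, `enc = bert_tokenizer(words, is_split_into_words=True, add_special_tokens=False)` provides `enc["input_ids"]` and `enc.word_ids()`; an `embed_window` for BERT would prepend `cls_token_id`, append `sep_token_id`, run the encoder, and drop the first and last rows of `last_hidden_state`.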
BERT has ~110M parameters. Full fine-tuning updates all of them, risking overfitting on our ~24K words of story data. LoRA freezes all original weights and adds small trainable low-rank matrices (A and B) to selected weight matrices, here the query and value projections in each attention layer.
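To make "small trainable matrices" concrete: for a frozen weight $W \in \mathbb{R}^{d \times k}$, LoRA learns $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ with $r \ll d$, and the layer computes $h = Wx + \frac{\gamma}{r} BAx$, where $\gamma$ is lora_alpha. The NumPy sketch below illustrates the idea and is not PEFT's implementation. It also counts the trainable parameters for rank 8 on query and value in bert-base (12 layers, 768 × 768 projections): 294,912, about 0.27% of BERT's ~110M.

```python
import numpy as np

class LoRALinear:
    """Frozen linear layer W plus a trainable low-rank update (alpha / r) * B @ A."""

    def __init__(self, W, r=8, alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        d_out, d_in = W.shape
        self.W = W                                    # frozen, never updated
        self.A = rng.normal(0, 0.01, size=(r, d_in))  # trainable, small random init
        self.B = np.zeros((d_out, r))                 # trainable, zero at init
        self.scale = alpha / r

    def __call__(self, x):
        # x: (batch, d_in) -> (batch, d_out); B = 0 means output starts equal to x @ W.T
        return x @ self.W.T + self.scale * (x @ self.A.T) @ self.B.T

    def n_trainable(self):
        return self.A.size + self.B.size

def count_lora_params(n_layers=12, d_model=768, r=8, n_targets=2):
    # Each targeted d_model x d_model projection gets A (r x d) and B (d x r)
    return n_layers * n_targets * (r * d_model + d_model * r)
```

Because B starts at zero, fine-tuning begins exactly at pretrained BERT and only drifts as far as the data pushes it.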
The fine-tuning objective is Masked Language Modeling (MLM): the same task BERT was originally pretrained on, but applied to our story transcripts. This adapts BERT's contextual representations to the narrative domain.
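The MLM loss is cross-entropy computed only at the masked positions. In the HuggingFace convention every other position carries the label -100, which BertForMaskedLM ignores when computing its loss. A NumPy sketch of that computation (illustrative, not the library code):

```python
import numpy as np

def mlm_loss(logits, labels, ignore_index=-100):
    """Mean cross-entropy (in nats) over positions whose label is not ignore_index.

    logits: (seq_len, vocab_size) unnormalized scores
    labels: (seq_len,) original token id at masked positions, -100 elsewhere
    """
    keep = labels != ignore_index
    z = logits[keep]
    # Numerically stable log-softmax over the vocabulary
    z = z - z.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(z)), labels[keep]].mean()
```

So the loss values quoted in the training code are average negative log-likelihoods per masked token.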
```python
# Pseudocode overview: load_bert_for_masked_lm, apply_lora, etc. are placeholders;
# the runnable version follows.
# ---- Setup ----
# Load BERT with its MLM prediction head (needed for the MLM loss)
mlm_model = load_bert_for_masked_lm("bert-base-uncased")
# Apply LoRA: freeze original weights, add trainable A, B to
# the query and value projections in each attention layer
apply_lora(mlm_model, rank=8, target=["query", "value"])

# ---- Prepare data ----
stories = load_story_transcripts()  # from raw_text.pkl
# Tokenize into chunks ≤ 512 tokens with overlap
chunks = sliding_window_tokenize(stories, max_length=512, stride=256)

# ---- Train ----
for epoch in range(num_epochs):
    for chunk in shuffle(chunks):
        # Mask 15% of tokens:
        #   80% → [MASK], 10% → random token, 10% → unchanged
        masked_chunk, labels = apply_mlm_masking(chunk, prob=0.15)
        # Forward: predict masked tokens from context
        loss = mlm_model(masked_chunk, labels)
        # Backward: only A, B matrices receive gradients
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # clear gradients before the next chunk

# Save adapter (~1-5 MB, just the A and B matrices)
save_lora_adapter("lora_adapter")
```
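apply_mlm_masking in the overview is a placeholder; in the runnable code, DataCollatorForLanguageModeling does this step. One way to implement the 15% / 80-10-10 rule is sketched below in NumPy. The extra arguments are assumptions of this sketch: special-token positions ([CLS], [SEP], padding) arrive as a boolean mask so they are never selected, and unselected positions get label -100 so no loss is taken there.

```python
import numpy as np

def apply_mlm_masking(input_ids, mask_token_id, vocab_size, special_mask,
                      prob=0.15, rng=None):
    """BERT-style masking. Returns (masked_ids, labels); labels are -100 where no loss is taken."""
    rng = np.random.default_rng() if rng is None else rng
    ids = np.array(input_ids, copy=True)
    labels = np.full_like(ids, -100)
    # 1. Select ~prob of the non-special positions as prediction targets
    selected = (rng.random(ids.shape) < prob) & ~special_mask
    labels[selected] = ids[selected]
    # 2. Of the selected positions: 80% -> [MASK]
    u = rng.random(ids.shape)
    to_mask = selected & (u < 0.8)
    # 10% -> random token; the remaining 10% keep their original token
    to_random = selected & (u >= 0.8) & (u < 0.9)
    ids[to_mask] = mask_token_id
    ids[to_random] = rng.integers(0, vocab_size, size=int(to_random.sum()))
    return ids, labels
```

Keeping 10% of targets unchanged means the model cannot assume that every unmasked token is correct, which matters because [MASK] never appears when you later extract embeddings.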
```python
import pickle
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import (BertTokenizerFast, BertForMaskedLM,
                          DataCollatorForLanguageModeling)
from peft import LoraConfig, get_peft_model

# --- Data ---
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
with open('raw_text.pkl', 'rb') as f:
    raw_text = pickle.load(f)

# One string per story, built from the transcript's word sequence
story_texts = []
for name, seq in raw_text.items():
    words = [str(w) for w in seq.data if isinstance(w, str)]
    story_texts.append(" ".join(words))

class StoryMLMDataset(Dataset):
    """Overlapping chunks of at most max_length tokens from each story."""
    def __init__(self, texts, tokenizer, max_length=512, stride=256):
        self.examples = []
        for text in texts:
            # A sequence-length warning here is expected; chunking happens below
            ids = tokenizer(text, truncation=False)["input_ids"]
            for start in range(0, len(ids), stride):
                chunk = ids[start:start + max_length]
                if len(chunk) >= 32:  # skip very short fragments
                    self.examples.append(torch.tensor(chunk))
                if start + max_length >= len(ids):
                    break  # this chunk already reaches the end of the story
    def __len__(self): return len(self.examples)
    def __getitem__(self, idx): return {"input_ids": self.examples[idx]}

dataset = StoryMLMDataset(story_texts, tokenizer)
# Pads each batch and applies 15% MLM masking on the fly
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                           mlm=True, mlm_probability=0.15)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=collator)

# --- Model + LoRA ---
mlm_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
lora_config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.1,
    target_modules=["query", "value"]
)
lora_model = get_peft_model(mlm_model, lora_config)
lora_model.print_trainable_parameters()

# --- Train ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lora_model.to(device)
lora_model.train()
# Only the LoRA A/B matrices require gradients, so optimize just those
optimizer = optim.AdamW([p for p in lora_model.parameters() if p.requires_grad], lr=2e-4)
for epoch in range(3):
    total_loss, n = 0, 0
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = lora_model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
        n += 1
    print(f"Epoch {epoch+1}, Loss: {total_loss/n:.4f}")
    # Expected: starts ~2.5-3.5, drops to ~2.0-2.5. Below 1.5 = overfitting.
lora_model.save_pretrained("lora_adapter_r8")
```
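```python
import torch
from transformers import BertForMaskedLM
from peft import PeftModel

# Rebuild the training architecture (BertForMaskedLM) so the adapter's
# parameter names match, then load the LoRA adapter saved above
mlm_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
peft_model = PeftModel.from_pretrained(mlm_model, "lora_adapter_r8")
# Fold B @ A into the frozen weights, then keep only the encoder (NOT the MLM head)
finetuned_encoder = peft_model.merge_and_unload().bert.eval()
finetuned_encoder.to("cuda" if torch.cuda.is_available() else "cpu")

# Same extraction function as Part 1
embeddings = get_word_embeddings(words, finetuned_encoder)
# → (W_s, 768)
# Then: downsample → delays → ridge → CC (identical to 3.1)
```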
BertModel outputs last_hidden_state (shape: seq_len × 768): these are the embeddings you want.
BertForMaskedLM = BertModel + an MLM prediction head (dense → GELU → layer norm → projection to the 30,522-token vocabulary). It outputs logits (shape: seq_len × 30,522), which are token predictions, not embeddings. You train with BertForMaskedLM (because you need the MLM loss), but extract embeddings from BertModel (because you want the hidden states). The LoRA adapter modifies the encoder's attention layers, which both models share. If a student accidentally extracts from the MLM model, they'll get vocabulary logits instead of embeddings, and the shapes won't match.
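A cheap guard against this mistake, and against a word-count mismatch that would break the fMRI alignment, is to validate the array before it enters the downsample → delays → ridge pipeline. A minimal sketch; check_word_embeddings is a hypothetical helper, and the constants are the bert-base-uncased sizes given above (768 hidden units, 30,522 vocabulary entries).

```python
import numpy as np

HIDDEN_SIZE = 768    # bert-base hidden size
VOCAB_SIZE = 30522   # bert-base-uncased vocabulary size

def check_word_embeddings(emb, words, hidden_size=HIDDEN_SIZE):
    """Raise early if emb is not exactly one hidden-state vector per word."""
    emb = np.asarray(emb)
    if emb.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {emb.shape}")
    if emb.shape[1] == VOCAB_SIZE:
        raise ValueError("last dim equals the vocabulary size: these look like MLM "
                         "logits, not hidden states (extract from BertModel)")
    if emb.shape[1] != hidden_size:
        raise ValueError(f"expected {hidden_size} features, got {emb.shape[1]}")
    if emb.shape[0] != len(words):
        raise ValueError(f"{emb.shape[0]} embeddings for {len(words)} words: "
                         "subtoken-to-word aggregation is misaligned")
    return emb
```

Pass a CPU array (for a torch tensor, `emb.cpu().numpy()`), e.g. `check_word_embeddings(embeddings, words)` right after extraction.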
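You can also access the base encoder inside BertForMaskedLM via model.bert. Loading the adapter onto a fresh BertModel can look cleaner, but if you do, verify that the LoRA weights were actually applied: the adapter was saved from BertForMaskedLM, so its parameter names carry a bert. prefix that a bare BertModel lacks, and PEFT may skip mismatched keys with only a warning. If the "fine-tuned" embeddings are identical to pretrained BERT's, the adapter did not load.

| Parameter | What it controls | Start with | Explore |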
|---|---|---|---|
| Rank ($r$) | Capacity of the adaptation | 8 | 4, 8, 16 |
| Scaling ($\gamma$, lora_alpha) | Update magnitude: the low-rank update $BA$ is multiplied by $\gamma / r$ | $2r$ | $r$ to $2r$ |
| Target modules | Which attention weights get adapters | query + value | Try adding key |
| Learning rate | Optimizer step size | $2 \times 10^{-4}$ | $10^{-4}$ to $5 \times 10^{-4}$ |
| Epochs | Passes over the data | 3 | 3 to 5 |
| Window stride | Step between training chunks (overlap = 512 − stride tokens) | 256 | 128 to 384 |
Compare at least 2-3 rank values and report how downstream CC changes. Does higher rank help, or does it overfit?
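If you tie lora_alpha to the rank (lora_alpha = 2r, as the table suggests), the scale $\gamma / r$ stays at 2 for every rank, so a rank comparison is not confounded by a change in the scaling factor. A small sketch that builds such a sweep; the adapter directory naming is a hypothetical scheme, and train_and_evaluate is a hypothetical stand-in for your own fine-tune → extract → ridge → CC pipeline, left commented out.

```python
from itertools import product

def lora_scale(r, alpha):
    """Multiplier applied to the low-rank update B @ A."""
    return alpha / r

def build_sweep(ranks=(4, 8, 16), alpha_ratio=2.0, lrs=(2e-4,)):
    """One hyperparameter dict per run; alpha is tied to rank via alpha_ratio."""
    runs = []
    for r, lr in product(ranks, lrs):
        alpha = int(alpha_ratio * r)
        runs.append({
            "r": r,
            "lora_alpha": alpha,
            "scale": lora_scale(r, alpha),
            "lr": lr,
            "adapter_dir": f"lora_adapter_r{r}_lr{lr:g}",  # hypothetical naming scheme
        })
    return runs

for run in build_sweep():
    print(run)
    # cc = train_and_evaluate(**run)  # hypothetical: fine-tune, extract, ridge, mean CC
```

Each run's r and lora_alpha map directly onto LoraConfig(r=..., lora_alpha=...) in the training code.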
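The word_ids() method gives the mapping. If your embedding count doesn't match your word count, the alignment with the fMRI data will be wrong and everything downstream breaks silently.
… max_length. Submit via sbatch with --gres=gpu:1 --mem=32G.
```python
# Gradient accumulation (one epoch): micro-batches of 2, optimizer step every 4
accum_steps = 4  # effective batch = 2 * 4 = 8
dataloader = DataLoader(dataset, batch_size=2, shuffle=True,
                        collate_fn=collator)
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    # Divide so the accumulated gradient approximates one batch of 8
    loss = lora_model(**batch).loss / accum_steps
    loss.backward()  # gradients add up across micro-batches
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
# Apply leftover gradients if the batch count isn't a multiple of accum_steps
if (i + 1) % accum_steps != 0:
    optimizer.step()
    optimizer.zero_grad()
```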
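Why dividing the loss by accum_steps works: for a loss that is a mean over examples, averaging the gradients of 4 equal-sized micro-batches of 2 gives exactly the gradient of one batch of 8. The NumPy check below uses a toy least-squares loss for illustration. With MLM the number of masked tokens varies between micro-batches, so there the match is approximate rather than exact.

```python
import numpy as np

def grad_mse(w, X, y):
    """Gradient of mean((X @ w - y)**2) with respect to w."""
    return 2 * X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X, y, w = rng.normal(size=(8, 3)), rng.normal(size=8), rng.normal(size=3)

full_batch = grad_mse(w, X, y)  # one batch of 8

accum_steps, micro = 4, 2
accumulated = np.zeros_like(w)
for i in range(accum_steps):
    sl = slice(i * micro, (i + 1) * micro)
    # Scaling each micro-batch gradient by 1/accum_steps mirrors loss / accum_steps
    accumulated += grad_mse(w, X[sl], y[sl]) / accum_steps
```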
Key libraries: transformers, peft, torch. See the PEFT docs and the HuggingFace LoRA tutorial linked in the lab instructions.