Model card for boldgpt_small_patch10.cont

[Figure: Example training predictions]

A Vision Transformer (ViT) model trained on BOLD activation maps from NSD-Flat. The training objective was autoregressive next-patch prediction: patches were visited in a shuffled order, and the model was trained with a mean squared error (MSE) loss to predict each next patch. Activation maps for the NSD shared1000 stimuli were held out as the validation set.
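The shuffled-order objective can be sketched in a few lines. This is a minimal NumPy illustration under stated assumptions, not boldgpt's training code. `patchify`, `shuffled_next_patch_mse`, and the `predict_next` callback are hypothetical names. The 10×10 patch size is inferred from the `patch10` model name, and each activation map is treated as a single 2-D array. We also assume the predictor is told which position comes next, since with a shuffled order the target location is otherwise ambiguous.

```python
import numpy as np


def patchify(activity, patch_size):
    """Split an (H, W) map into an (N, patch_size**2) array of flattened patches.

    Patches are returned in row-major (raster) order.
    """
    h, w = activity.shape
    if h % patch_size or w % patch_size:
        raise ValueError("map size must be divisible by patch_size")
    gh, gw = h // patch_size, w // patch_size
    patches = activity.reshape(gh, patch_size, gw, patch_size).transpose(0, 2, 1, 3)
    return patches.reshape(gh * gw, patch_size * patch_size)


def shuffled_next_patch_mse(activity, patch_size, predict_next, rng):
    """MSE of autoregressive next-patch predictions under a random patch order.

    predict_next(context_patches, context_positions, next_position) is a
    hypothetical stand-in for the model: it sees the patches visited so far
    (possibly none) and the position of the patch it must predict.
    """
    patches = patchify(activity, patch_size)
    order = rng.permutation(len(patches))  # shuffled visiting order
    preds = []
    for t, pos in enumerate(order):
        # Context is every patch that precedes step t in the shuffled order
        preds.append(predict_next(patches[order[:t]], order[:t], pos))
    preds = np.stack(preds)
    targets = patches[order]
    return float(np.mean((preds - targets) ** 2))


def zero_predictor(context_patches, context_positions, next_position):
    """Trivial hypothetical predictor that always guesses an all-zero 10x10 patch."""
    return np.zeros(10 * 10)


# Usage on a random 20x30 "activation map" (6 patches of size 10x10)
rng = np.random.default_rng(0)
activity = rng.standard_normal((20, 30))
loss = shuffled_next_patch_mse(activity, 10, zero_predictor, rng)
```

With the all-zero predictor the loss reduces to the mean squared activation, because every patch is predicted exactly once regardless of the order. In a trained model, `predict_next` would be replaced by the transformer conditioned on the context patches and their positions.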

Dependencies

Usage

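import torch

from boldgpt.data import ActivityTransform
from boldgpt.models import create_model
from datasets import load_dataset

# Load the pretrained model and switch to inference mode
model = create_model("boldgpt_small_patch10.cont", pretrained=True)
model.eval()

# Load NSD-Flat from the Hugging Face Hub and return PyTorch tensors
dataset = load_dataset("clane9/NSD-Flat", split="train")
dataset.set_format("torch")

# Take a batch of one example and transform its BOLD activity map
transform = ActivityTransform()
batch = dataset[:1]
batch["activity"] = transform(batch["activity"])

# output: (B, N + 1, D) predicted next patches, where B is the batch size
with torch.no_grad():
    output, state = model(batch)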

Reproducing

  • Training command:

    torchrun --standalone --nproc_per_node=4 \
      scripts/train.py \
      --out_dir results \
      --model boldgpt_small_patch10 \
      --no_cat --shuffle --epochs 1000 --bs 512 \
      --workers 0 --amp --compile --wandb
    
  • Commit: e0b29adc8d5b3ed2f1a555d7de4754ba96a3bb3e

Safetensors

  • Model size: 22M params
  • Tensor types: I64, F32, BOOL
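As a rough sanity check on the 22M figure, the sketch below counts the parameters in the transformer blocks of a standard ViT. The configuration is an assumption: the model card does not state the width or depth, so this uses the common ViT-Small settings (width d = 384, depth 12, MLP ratio 4), for which each block has 12d² + 13d parameters. `vit_block_params` is a hypothetical helper, not part of boldgpt.

```python
def vit_block_params(width, mlp_ratio=4):
    """Parameter count of one standard ViT encoder block (biases included)."""
    hidden = width * mlp_ratio
    qkv = width * 3 * width + 3 * width  # fused query/key/value projection
    proj = width * width + width         # attention output projection
    fc1 = width * hidden + hidden        # MLP expansion
    fc2 = hidden * width + width         # MLP contraction
    norms = 2 * (2 * width)              # two LayerNorms (scale + shift)
    return qkv + proj + fc1 + fc2 + norms


# Assumed ViT-Small configuration: width 384, depth 12
per_block = vit_block_params(384)
blocks_total = 12 * per_block
print(per_block, blocks_total)  # 1774464 21293568
```

Under these assumptions the blocks account for about 21.3M parameters. Patch and position embeddings, the final norm, and the output head would presumably make up the remainder, but their exact sizes depend on details not given here.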

Dataset used to train clane9/boldgpt_small_patch10.cont: clane9/NSD-Flat