[NeurIPS 2024] Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking
Paper: [2409.19407] Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking
Code: GitHub – Eric-LRL/Brain-JEPA: Official codebase for "Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking" (NeurIPS 2024, Spotlight).
The English here is typed entirely by hand, summarizing and paraphrasing the original paper. Spelling and grammar mistakes are hard to avoid; if you spot any, feel free to point them out in the comments. This post leans toward personal notes, so read with caution.
Contents
1. Thoughts
2. Paper Walkthrough
2.1. Abstract
2.2. Introduction
2.3. Related Work
2.4. Method
2.4.1. Brain Gradient Positioning
2.4.2. Spatiotemporal Masking
2.5. Experiments
2.5.1. Datasets
2.5.2. Implementation details
2.5.3. Main results
2.5.4. Performance scaling
2.5.5. Linear probing
2.5.6. Ablation study
2.5.7. Interpretation
2.6. Conclusion
2.7. Limitation and future work
3. Supplementary Knowledge
3.1. Inter-task and external evaluation
4. Reference
1. Thoughts
(1)What a magical masking scheme ~~why is everyone masking these days~~
(2)ROI features: before these predictions, I had no idea so many things could be predicted from them
2. Paper Walkthrough
2.1. Abstract
①They proposed a brain dynamics foundation model built on the Joint-Embedding Predictive Architecture (JEPA)
②Tasks: demographic prediction, disease diagnosis/prognosis, and trait prediction
2.2. Introduction
①Existing problem: large amounts of unlabeled fMRI data have not been utilized
②Non-neuronal interference significantly reduces the signal-to-noise ratio (SNR) of BOLD signals
③BrainLM considered ROI locations but did not consider functional parcellation
2.3. Related Work
(1)Task-specific Models for fMRI (state-of-the-art)
①Lists BrainNetCNN, BrainGNN, BNT, and SwiFT
(2)The fMRI Foundation Model
①Lists BrainLM and BrainMass
2.4. Method
①Framework of Brain-JEPA:
2.4.1. Brain Gradient Positioning
①Non-negative affinity matrix:

$$A_{ij} = 1 - \frac{\arccos\big(\mathrm{cossim}(\mathbf{f}_i, \mathbf{f}_j)\big)}{\pi}$$

where $\mathbf{f}_i$ and $\mathbf{f}_j$ are the FC features of ROI $i$ and ROI $j$
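A minimal numpy sketch of this step, assuming the normalized-angle kernel that is standard in the cortical-gradient literature (e.g., BrainSpace); the function and variable names are mine, not from the official codebase:

```python
import numpy as np

def nonneg_affinity(fc: np.ndarray) -> np.ndarray:
    """Non-negative affinity between ROIs from their FC feature rows.

    Normalized-angle kernel: 1 - arccos(cosine similarity) / pi maps
    cosine similarity in [-1, 1] to an affinity in [0, 1].
    """
    norms = np.linalg.norm(fc, axis=1, keepdims=True)
    cos_sim = (fc @ fc.T) / (norms * norms.T)
    cos_sim = np.clip(cos_sim, -1.0, 1.0)  # guard against rounding drift outside [-1, 1]
    return 1.0 - np.arccos(cos_sim) / np.pi

# e.g. fc = np.corrcoef(bold)  # bold: (n_roi, n_timesteps) -> FC features per ROI
```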
②The gradients are derived from a diffusion map (a non-linear dimensionality reduction method):

$$L = D^{-\alpha} A D^{-\alpha}, \qquad M = D_L^{-1} L$$

where $L$ denotes the diffusion matrix, $M$ denotes the diffusion operator, $D$ is the degree matrix of $A$ ($D_L$ that of $L$), and they set $\alpha$ to 0.5
③The diffusion map at time $t$ is calculated by:

$$\Psi_t(i) = \big(\lambda_1^t\,\psi_1(i),\ \lambda_2^t\,\psi_2(i),\ \ldots,\ \lambda_k^t\,\psi_k(i)\big), \quad i = 1, \ldots, N$$

where $N$ denotes the number of ROIs, $k$ denotes the number of gradients, $\Psi_t \in \mathbb{R}^{N \times k}$ denotes the gradient matrix, $\lambda_j$ denotes the eigenvalues of $M$, and $\psi_j$ denotes the corresponding eigenvectors (the gradients)
④"Here we estimated the eigenvalues at time t by dividing it by
to enhance robustness against noisy eigenvalues."(这啥玩意儿?为啥还要估计一下)
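Putting ②–④ together, a numpy sketch of the whole gradient computation under the reconstruction above (anisotropic normalization with $\alpha = 0.5$, eigenvalue scaling $\lambda/(1-\lambda)$); this is my reading of the recipe, not the official implementation:

```python
import numpy as np

def diffusion_gradients(A: np.ndarray, k: int = 3, alpha: float = 0.5) -> np.ndarray:
    """Top-k diffusion-map gradients of an (n_roi x n_roi) affinity matrix A."""
    # Anisotropic normalization: L = D^-alpha A D^-alpha
    d = A.sum(axis=1)
    L = A / np.outer(d**alpha, d**alpha)
    # Row-normalize to get the diffusion operator M = D_L^-1 L
    M = L / L.sum(axis=1, keepdims=True)
    # Eigendecompose M and sort eigenpairs by descending eigenvalue
    eigvals, eigvecs = np.linalg.eig(M)
    order = np.argsort(-eigvals.real)
    eigvals, eigvecs = eigvals.real[order], eigvecs.real[:, order]
    # Skip the trivial first eigenpair (eigenvalue 1, constant eigenvector);
    # scale by lambda / (1 - lambda) for robustness to noisy eigenvalues
    lam = eigvals[1:k + 1]
    psi = eigvecs[:, 1:k + 1]
    return psi * (lam / (1.0 - lam))  # (n_roi, k) gradient matrix
```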
⑤$\Psi_t$ is transformed to the gradient embedding $E_g \in \mathbb{R}^{N \times D}$ by a trainable linear layer, where $D$ denotes the embedding dimension
⑥The temporal positional embedding $E_t$ is predefined and obtained by sine and cosine functions
⑦The final positional embedding combines the gradient embedding $E_g$ with the temporal embedding $E_t$; a sketch follows below
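⑤–⑦ in a small PyTorch sketch under my reading of the note: a trainable linear layer maps the $k$ gradients of each ROI to dimension $D$, a standard sinusoidal encoding gives the temporal term, and the two are combined over the (ROI, time) patch grid. The summation at the end is my assumption, since the note does not specify the combination:

```python
import math
import torch
import torch.nn as nn

def sinusoidal_embedding(n_steps: int, dim: int) -> torch.Tensor:
    """Standard sine/cosine positional encoding, shape (n_steps, dim); dim must be even."""
    pos = torch.arange(n_steps, dtype=torch.float32).unsqueeze(1)
    freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32)
                     * (-math.log(10000.0) / dim))
    emb = torch.zeros(n_steps, dim)
    emb[:, 0::2] = torch.sin(pos * freq)
    emb[:, 1::2] = torch.cos(pos * freq)
    return emb

class GradientPositioning(nn.Module):
    """Trainable linear map from the k gradients of each ROI to the embedding dim."""
    def __init__(self, k: int, dim: int):
        super().__init__()
        self.proj = nn.Linear(k, dim)

    def forward(self, gradients: torch.Tensor, n_steps: int) -> torch.Tensor:
        e_g = self.proj(gradients)                         # (n_roi, dim) spatial term
        e_t = sinusoidal_embedding(n_steps, e_g.size(-1))  # (n_steps, dim) temporal term
        # Combine over the (ROI, time) patch grid; summation is assumed here
        return e_g[:, None, :] + e_t[None, :, :]           # (n_roi, n_steps, dim)
```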
⑧Visualization of the top 3 gradients:
2.4.2. Spatiotemporal Masking
(1)Observation
①The observation block is randomly sampled with a scale $(s_{\mathrm{ROI}}, s_t)$, where $s_{\mathrm{ROI}}$ spans along the ROI dimension and $s_t$ spans time-step patches (within 10)
②They fed the observation block $x$ to the observation encoder $f_\theta$ to get:

$$s_x = f_\theta(x) = \{\,s_{x_j}\,\}_{j \in B_x}$$

where $B_x$ denotes the mask (the set of observed patch indices) and $s_{x_j}$ denotes the representation of the $j$-th patch
(2)Targets
①Randomly sampling targets might induce shortcut learning or reliance on simple, repetitive patterns, which heavily decreases generalizability
②Now the really clever part. As I read it, the authors mark the columns of the observation block blue, i.e., Cross-ROIs (α); its rows green, i.e., Cross-Time (β); and all remaining regions red, i.e., Double-Cross (γ). The masks of these regions are defined as $B_i$, where $i \in \{\alpha, \beta, \gamma\}$. Finally, target blocks are sampled separately from each of the three colored regions
(3)Overlapped sampling
①This is just the sampling strategy I described above: targets are drawn only from the corresponding rows / columns / remaining region, respectively:
②But sampled blocks can overlap the observation block. For example, a block sampled rightward starting one patch to the left of the observation block would clearly reach into it; the authors simply delete the overlapping part. A toy sketch follows below
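A toy numpy sketch of how I understand the region-constrained sampling plus overlap removal; the grid size, block size, and observation extent are made-up values for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
N_ROI, N_T = 45, 16              # toy patch grid: 45 ROI patches x 16 time patches

# Observation block: a contiguous range of ROI rows and time columns (toy sizes)
obs_roi = np.arange(10, 30)
obs_t = np.arange(4, 10)
obs = {(r, t) for r in obs_roi for t in obs_t}

def sample_block(roi_anchors, t_anchors, h=5, w=3):
    """Sample one h x w target block anchored in the given region, then drop
    any patches overlapping the observation block (as the note describes)."""
    r0 = int(rng.choice(roi_anchors))
    t0 = int(rng.choice(t_anchors))
    block = {(r, t) for r in range(r0, min(r0 + h, N_ROI))
                    for t in range(t0, min(t0 + w, N_T))}
    return block - obs

other_roi = np.setdiff1d(np.arange(N_ROI), obs_roi)
other_t = np.setdiff1d(np.arange(N_T), obs_t)

cross_roi = sample_block(other_roi, obs_t)        # blue: other ROIs, same time window
cross_time = sample_block(obs_roi, other_t)       # green: same ROIs, other time windows
double_cross = sample_block(other_roi, other_t)   # red: both differ
```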
(4)Training
①The loss is the average L2 distance between the target representations $s_{y_j}$ and their corresponding predictions $\hat{s}_{y_j}$:

$$\mathcal{L} = \frac{1}{M} \sum_{i=1}^{M} \sum_{j \in B_i} \big\| \hat{s}_{y_j} - s_{y_j} \big\|_2^2$$

where the predictions are conditioned on the positional embeddings of the target patches, $M$ is the number of target blocks, and $g_\phi$ is the trained predictor producing $\hat{s}_{y_j}$
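A PyTorch sketch of this loss in the I-JEPA-style notation used above; the encoder and predictor names in the usage comment are placeholders, not the official API:

```python
import torch

def jepa_loss(pred_blocks, target_blocks):
    """Average L2 distance between predicted and target patch representations.

    pred_blocks / target_blocks: lists of (n_patches_i, dim) tensors, one per
    target block; targets come from the (EMA) target encoder, so gradients
    are stopped with .detach().
    """
    losses = [((p - t.detach()) ** 2).sum(dim=-1).mean()
              for p, t in zip(pred_blocks, target_blocks)]
    return torch.stack(losses).mean()

# Usage sketch (all names are placeholders):
# s_x   = observation_encoder(x_obs)                      # context representations
# preds = [predictor(s_x, pos_embed[idx]) for idx in target_indices]
# targs = [target_repr[idx] for idx in target_indices]
# loss  = jepa_loss(preds, targs)
```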
2.5. Experiments
2.5.1. Datasets
(1)Inter task
①Self-supervised pre-training dataset: the large-scale public UK Biobank (UKB) dataset with 40,162 subjects; 80% is used for pretraining and 20% for downstream evaluation
②Pretraining is done without a [cls] token
③Pre-training epochs: 300
(2)External evaluation
①HCP-Aging: 656 subjects for trait prediction (Neuroticism and Flanker score) and demographics (age and sex)
②Alzheimer's Disease Neuroimaging Initiative (ADNI): 189 subjects for NC vs. MCI classification, and 100 cognitively normal subjects for amyloid-positive vs. amyloid-negative classification
③Memory, Ageing and Cognition Centre (MACC): 539 subjects to classify NC vs. MCI
(3)Settings
①ROI: 450 with Schaefer-400 for cortical regions and Tian-Scale III for subcortical regions
②⭐Robust scaling: implemented by subtracting the median and dividing by the interquartile range, calculated across participants for each ROI (see the sketch after this list)
③Input time steps: 160
④TR: 0.7 for UKB and HCP-Aging (multiband high temporal resolution) and 2 for ADNI and MACC (single-band lower resolution)
⑤TR alignment: downsampling the high-resolution signals with a temporal stride of 3, aligning all datasets to roughly 2 seconds per step
⑥Downstream dataset split: 6/2/2 for training/val/test
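A numpy sketch of the preprocessing in ② and ⑤ as I read them; the array layout (subjects × ROI × time) and pooling the time axis into the median/IQR statistics are my assumptions:

```python
import numpy as np

def robust_scale(bold: np.ndarray) -> np.ndarray:
    """(x - median) / IQR per ROI, computed across participants (and time).

    bold: (n_subjects, n_roi, n_timesteps)
    """
    med = np.median(bold, axis=(0, 2), keepdims=True)
    q75 = np.percentile(bold, 75, axis=(0, 2), keepdims=True)
    q25 = np.percentile(bold, 25, axis=(0, 2), keepdims=True)
    return (bold - med) / (q75 - q25)

def align_tr(bold: np.ndarray, stride: int = 3) -> np.ndarray:
    """Downsample along time with a fixed stride: TR 0.7 s x 3 ~ 2.1 s per step."""
    return bold[:, :, ::stride]
```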
2.5.2. Implementation details
①Utilizing ViT architectures for the observation encoder, target encoder, and predictor
②Self-attention: FlashAttention
③Observation encoder: ViT-Small (ViT-S) (22M), ViT-Base (ViT-B) (86M), and ViT-Large (ViT-L) (307M)
④Embedding dimension: 192 for ViT-S, 384 for ViT-B and ViT-L
⑤Depth: 6 for S and B, 12 for L
⑥Readout: average
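For reference, the three variants as a small config table (a hypothetical structure that just restates the numbers above, not the official code):

```python
# Hypothetical config dicts restating the note's numbers (not the official code)
BRAIN_JEPA_VARIANTS = {
    "ViT-S": {"params": "22M",  "embed_dim": 192, "depth": 6},
    "ViT-B": {"params": "86M",  "embed_dim": 384, "depth": 6},
    "ViT-L": {"params": "307M", "embed_dim": 384, "depth": 12},
}
READOUT = "average"  # mean over patch representations for downstream heads
```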
2.5.3. Main results
①Internal-task performance on the held-out 20% of UKB for age and sex prediction:
②External task of demographics and trait prediction on HCP-Aging:
③External tasks of brain disease diagnosis and prognosis on ADNI and MACC:
2.5.4. Performance scaling
①Larger models achieve better results; no surprise, otherwise why make them that big:
2.5.5. Linear probing
①Fine-tuning of BrainLM vs. linear probing of Brain-JEPA:
2.5.6. Ablation study
①Method ablation:
②Epoch ablation:
2.5.7. Interpretation
①Attention distribution over the 7 sub-networks:
the attention score of each ROI is the mean over its 10 patches; the ROI scores within each sub-network are then averaged and normalized to obtain the final attention distribution (a sketch follows below)
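A small numpy sketch of that aggregation, assuming a `(n_roi, 10)` attention-score array and an ROI-to-network assignment vector (both names are mine):

```python
import numpy as np

def network_attention(scores: np.ndarray, roi2net: np.ndarray, n_nets: int = 7):
    """scores: (n_roi, 10) attention per ROI patch; roi2net: (n_roi,) ids in 0..6."""
    roi_scores = scores.mean(axis=1)          # mean over the 10 patches of each ROI
    net_scores = np.array([roi_scores[roi2net == n].mean() for n in range(n_nets)])
    return net_scores / net_scores.sum()      # normalize into a distribution
```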
2.6. Conclusion
~
2.7. Limitation and future work
①Lack of testing with bigger models
②More diverse data is needed, especially for pretraining, such as cohorts of different ethnicities collected across various sites, scanning protocols, behavioral tasks, and disease groups
③They want to further compare cortical and subcortical regions
④Multi-modality is required
3. Supplementary Knowledge
3.1. Inter-task and external evaluation
(1)Inter-task
"Inter-task" is a term usually seen in multi-task learning (MTL), referring to the relationships or interactions between different tasks. The goal of multi-task learning is to let one model handle several tasks at once while sharing knowledge across them.
Inter-task learning refers to information sharing or mutual influence between tasks. For example, in vision, one task might be object classification and another object detection; the model can transfer knowledge between the tasks by sharing intermediate representations or features.
Inter-task interference (or task interference) refers to negative effects between tasks: if the features learned for one task are harmful to another, overall performance can degrade.
(2)External evaluation
"External evaluation" usually means assessing a model against external standards, datasets, or benchmarks rather than the internal evaluation used during training. It commonly takes the following forms:
External-dataset evaluation: the model is evaluated not only on its training or validation data but also on external, unseen datasets (often from different domains or tasks) to test generalization. For example, an image classifier might be evaluated on several datasets (e.g., CIFAR-10, ImageNet) to verify its generality.
Benchmark evaluation: the model is evaluated with standard metrics or industry benchmarks; beyond computing accuracy, recall, F1 score, and so on, this may include comparison against peer research or other published methods.
Out-of-domain evaluation: the model is evaluated under conditions that differ from the training task or environment, to test transferability and robustness to unseen conditions. For example, a speech recognition system trained on one language or dialect might be externally evaluated on different speech data (e.g., other accents or noisy environments).
4. Reference
Dong, Z. et al. (2024) 'Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking', NeurIPS 2024 Spotlight. doi: https://doi.org/10.48550/arXiv.2409.19407
Author: 夏莉莉iy