[NeurIPS 2024] Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking

Paper: [2409.19407] Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking

Code: GitHub – Eric-LRL/Brain-JEPA: Official codebase for "Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking" (NeurIPS 2024, Spotlight).

The English here is typed entirely by hand, summarizing and paraphrasing the original paper. Spelling and grammar mistakes are hard to avoid; if you spot any, corrections in the comments are welcome. This post reads more like study notes, so take it with a grain of salt.

Contents

1. Takeaways

2. Section-by-Section Reading

2.1. Abstract

2.2. Introduction

2.3. Related Work

2.4. Method

2.4.1. Brain Gradient Positioning

2.4.2. Spatiotemporal Masking

2.5. Experiments

2.5.1. Datasets

2.5.2. Implementation details

2.5.3. Main results

2.5.4. Performance scaling

2.5.5. Linear probing

2.5.6. Ablation study

2.5.7. Interpretation

2.6. Conclusion

2.7. Limitation and future work

3. Background Notes

3.1. Inter-task and external evaluation

4. Reference


1. Takeaways

(1)Such an intriguing masking scheme~~why is everyone masking~~

(2)ROI features: before reading this I had no idea so many things could be predicted from them

2. Section-by-Section Reading

2.1. Abstract

        ①They proposed a brain dynamics foundation model built on the Joint Embedding Predictive Architecture (JEPA)

        ②Tasks: demographic prediction, disease diagnosis/prognosis, and trait prediction

2.2. Introduction

        ①Existing problem: large amounts of unlabeled fMRI data have not been utilized

        ②Non-neuronal interference significantly reduces the signal-to-noise ratio (SNR) of BOLD signals

        ③BrainLM considered ROI locations but did not consider functional parcellation

2.3. Related Work

(1)Task-specific Models for fMRI (state-of-the-art)

        ①Lists BrainNetCNN, BrainGNN, BNT, and SwiFT

(2)The fMRI Foundation Model

        ①Lists BrainLM and BrainMass

2.4. Method

        ①Framework of Brain-JEPA:

2.4.1. Brain Gradient Positioning

        ①Non-negative affinity matrix:

\boldsymbol{A}(i,j)=1-\frac{1}{\pi}\mathrm{cos}^{-1}(\frac{\mathbf{c}_{i}\mathbf{c}_{j}^{T}}{\|\mathbf{c}_{i}\|\|\mathbf{c}_{j}\|})

where c_i and c_j are the FC features of ROI i and j
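As a sketch, the affinity computation above can be written in NumPy (the function name and shapes are my own assumptions, not the authors' code):

```python
import numpy as np

def affinity_matrix(C):
    """Non-negative affinity between ROI FC features.

    C: (n_roi, n_features) matrix whose rows c_i are the FC features of
    each ROI. Returns A with A[i, j] = 1 - arccos(cos_sim(c_i, c_j)) / pi.
    """
    # Row-normalise so c_i c_j^T / (||c_i|| ||c_j||) becomes a plain dot product
    Cn = C / np.linalg.norm(C, axis=1, keepdims=True)
    cos = np.clip(Cn @ Cn.T, -1.0, 1.0)  # guard arccos against rounding error
    return 1.0 - np.arccos(cos) / np.pi
```

Identical feature vectors give affinity 1 and opposite vectors give 0, so A is guaranteed non-negative.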

        ②The gradients are derived from a diffusion map (a non-linear dimensionality reduction method):

M_\delta=D^{-1}L_\delta,\quad L_\delta=D^{-\delta}AD^{-\delta}

where M_\delta denotes the diffusion matrix, L_\delta the diffusion operator, and D the degree matrix of A; they set \delta to 0.5

        ③The diffusion map at time t is calculated by:

\Phi_{t}=[\lambda_{1}^{t}\psi_{1},\lambda_{2}^{t}\psi_{2},...,\lambda_{m}^{t}\psi_{m}] \in \mathbb{R}^{n \times m},G=[\psi_{1},\psi_{2},...,\psi_{m}]\in \mathbb{R}^{n \times m}

where n denotes the number of ROIs, m the number of gradients, G the gradient matrix, \lambda_k the eigenvalues, and \psi_k the corresponding eigenvectors (the gradients)

        ④"Here we estimated the eigenvalues \lambda_k at time t by dividing it by 1-\lambda_k to enhance robustness against noisy eigenvalues."(What on earth is this? Why does it still need estimating?)
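A minimal NumPy sketch of steps ②–④, assuming a symmetric non-negative affinity matrix and using the \lambda/(1-\lambda) scaling from ④ (names and normalization details are my assumptions):

```python
import numpy as np

def brain_gradients(A, m=3, delta=0.5):
    """Diffusion-map gradients from a non-negative affinity matrix A.

    Anisotropic normalisation L = D^{-delta} A D^{-delta}, row-normalised
    diffusion matrix M, then the top-m non-trivial eigenvectors, with
    each eigenvalue made robust via lambda / (1 - lambda).
    """
    d = A.sum(axis=1)
    L = A / np.outer(d**delta, d**delta)   # D^{-delta} A D^{-delta}
    M = L / L.sum(axis=1, keepdims=True)   # row-stochastic diffusion matrix
    evals, evecs = np.linalg.eig(M)
    order = np.argsort(-evals.real)
    evals, evecs = evals.real[order], evecs.real[:, order]
    # Skip the trivial constant eigenvector (eigenvalue 1)
    lam, G = evals[1:m + 1], evecs[:, 1:m + 1]
    return G, lam / (1.0 - lam)            # robust eigenvalue scaling
```

Since M is similar to a symmetric matrix, its spectrum is real, which is why taking the real part is safe here.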

        ⑤Transform G\in \mathbb{R}^{n \times m} to \hat{G}\in \mathbb{R}^{n \times d/2} by trainable linear layer, where d denotes the dimension of embedding

        ⑥T\in \mathbb{R}^{n \times d/2} is the predefined temporal positioning and is obtained by sine and cosine functions

        ⑦The final positional embedding: P=[T,\hat{G}]\in\mathbb{R}^{n\times d}
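Steps ⑤–⑦ can be sketched as follows; the trainable linear layer is stood in for by a fixed matrix W, and the sinusoid follows the standard transformer recipe (an assumption, since the text only says "sine and cosine functions"):

```python
import numpy as np

def positional_embedding(G, W, d):
    """P = [T, G_hat]: sin/cos temporal half plus projected-gradient half.

    G: (n, m) gradient matrix; W: (m, d//2) stand-in for the trainable
    linear layer (fixed here, since this is only a sketch).
    """
    n, half = G.shape[0], d // 2
    pos = np.arange(n)[:, None]
    i = np.arange(half)[None, :]
    # Transformer-style sinusoid: even dims sine, odd dims cosine
    angle = pos / (10000 ** (2 * (i // 2) / half))
    T = np.where(i % 2 == 0, np.sin(angle), np.cos(angle))
    G_hat = G @ W                               # trainable projection in the paper
    return np.concatenate([T, G_hat], axis=1)   # (n, d)
```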

        ⑧The top 3 gradients visualization:

2.4.2. Spatiotemporal Masking

(1)Observation

        ①The observation block x is randomly sampled with ratios \left \{ \eta ^o_R,\eta ^o_T\right \}, where \eta ^o_R is along the ROI axis and \eta ^o_T along the time axis (within 10 time-step patches)

        ②They feed x into the observation encoder f_{\theta} to get:

s_x=\left \{ s_{x_j} \right \}_{j \in \mathcal{B}_x}

where \mathcal{B}_x denotes the mask, s_{x_j} denotes the j-th patch

(2)Targets

        ①Random sampling of targets may induce shortcut learning or reliance on simple, repetitive patterns, which heavily degrades generalizability

        ②Here comes the really neat part. I believe the authors mean: the columns of the observation block are marked blue, i.e., Cross-ROIs (α); its rows green, i.e., Cross-time (β); and all remaining regions red, Double-cross (γ).

The masks of these regions are then defined as \mathcal{B}^r_y , where r \in \left \{ \alpha, \beta, \gamma \right \}. Finally, K blocks are sampled from each of the three colored regions.

(3)Overlapped sampling

        ①This is the sampling strategy I described above... samples are drawn only from the corresponding rows/columns/remaining region:

s^\alpha_y\sim \mathcal{B}_x\cup \mathcal{B}_y^\alpha,s^\beta_y\sim \mathcal{B}_x\cup \mathcal{B}_y^\beta,s^\gamma_y\sim \mathcal{B}_y^\gamma

        ②However, the sampled blocks may overlap with the observation: a block starting one patch to the left of observation x and extending rightward will obviously cover part of x. Wherever the authors found such overlap, they simply removed it.
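The three-region partition can be sketched as boolean masks over the ROI-by-time patch grid (a sketch under my own naming; the per-region sampling of K blocks and the overlap removal are only noted, not implemented):

```python
import numpy as np

def target_regions(n_roi, n_t, obs_rows, obs_cols):
    """Partition the ROI-by-time grid around an observation block.

    obs_rows / obs_cols: index lists of the ROIs and time patches of the
    observation block x. Returns boolean masks for the observation and
    the three target regions: cross-ROI (same columns, other rows),
    cross-time (same rows, other columns), double-cross (neither).
    """
    in_rows = np.zeros(n_roi, bool); in_rows[obs_rows] = True
    in_cols = np.zeros(n_t, bool);   in_cols[obs_cols] = True
    R, C = in_rows[:, None], in_cols[None, :]
    obs        = R & C
    cross_roi  = ~R & C    # alpha: shares time columns with x
    cross_time = R & ~C    # beta:  shares ROI rows with x
    double     = ~R & ~C   # gamma: shares neither
    return obs, cross_roi, cross_time, double
```

The four masks tile the grid exactly, which is what makes the later "delete overlapping patches" step well-defined.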

(4)Training

        ①The loss is the average L2 distance between s^r_y and the corresponding predictions:

\mathcal{L}=\frac{1}{3K}\sum_{r}\left\|\hat{s}_{y}^{r}-s_{y}^{r}\right\|_{2}^{2},\hat{s}_{y}^{r}=g_{\phi}(s_{x}|P)

where P denotes the positional embedding and g_\phi the predictor
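The objective can be sketched with plain NumPy stand-ins (in the paper the predictions come from the predictor g_\phi conditioned on P, and the targets from the separate target encoder; here both are just arrays):

```python
import numpy as np

def jepa_loss(pred_blocks, target_blocks):
    """Average L2 loss over the 3K predicted target blocks.

    pred_blocks / target_blocks: dicts mapping region r in
    {'alpha', 'beta', 'gamma'} to arrays of shape (K, n_patches, d).
    Implements L = (1 / 3K) * sum_r ||s_hat_y^r - s_y^r||_2^2.
    """
    K = next(iter(pred_blocks.values())).shape[0]
    total = 0.0
    for r in pred_blocks:
        total += np.sum((pred_blocks[r] - target_blocks[r]) ** 2)
    return total / (3 * K)
```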

2.5. Experiments

2.5.1. Datasets

(1)Inter-task

        ①Self-supervised pre-training dataset: large-scale public dataset – UK Biobank (UKB) with 40,162 subjects; they use 80% for pretraining and 20% for downstream measurement

        ②The pretraining is done without a [cls] token

        ③Pre-training epoch: 300

(2)External evaluation

        ①HCP-Aging: 656 subjects for traits prediction (Neuroticism and Flanker score) and demographics (age and sex)

        ②Alzheimer’s Disease Neuroimaging Initiative (ADNI): 189 subjects for NC vs. MCI classification, 100 cognitively normal subjects for amyloid-positive vs. amyloid-negative classification

        ③Memory, Ageing and Cognition Centre (MACC): 539 subjects to classify NC and MCI

(3)Settings

        ①ROI: 450 with Schaefer-400 for cortical regions and Tian-Scale III for subcortical regions

        ②⭐Robust scaling: implemented by subtracting the median and dividing by the interquartile range, calculated across participants for each ROI

        ③Input time steps: 160

        ④TR: 0.7 for UKB and HCP-Aging (multiband high temporal resolution) and 2 for ADNI and MACC (single-band lower resolution)

        ⑤TR alignment: downsampling the high-temporal-resolution signals with a temporal stride of 3, aligning all datasets to roughly 2 seconds per step

        ⑥Downstream dataset split: 6/2/2 for training/val/test
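The robust scaling (②) and the stride-3 TR alignment (⑤) can be sketched together; the (subjects, ROIs, time) axis layout and the order of the two steps are my assumptions:

```python
import numpy as np

def preprocess(ts, stride=1):
    """Robust-scale fMRI time series, then align TR by downsampling.

    ts: (n_subjects, n_roi, n_t). Each ROI is scaled with the median and
    interquartile range computed across participants; stride=3 would be
    used for the TR-0.7 datasets so every step is ~2 s.
    """
    med = np.median(ts, axis=(0, 2), keepdims=True)           # per-ROI median
    q75 = np.percentile(ts, 75, axis=(0, 2), keepdims=True)
    q25 = np.percentile(ts, 25, axis=(0, 2), keepdims=True)
    scaled = (ts - med) / (q75 - q25)                         # robust scaling
    return scaled[:, :, ::stride]                             # temporal stride
```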

2.5.2. Implementation details

        ①Utilizing ViT architectures for the observation encoder, target encoder, and predictor

        ②Self-attention: FlashAttention

        ③Observation encoder: ViT-Small (ViT-S) (22M), ViT-Base (ViT-B) (86M), and ViT-Large (ViT-L) (307M)

        ④Embedding dimension: 192 for ViT-S, 384 for ViT-B and ViT-L

        ⑤Depth: 6 for S and B, 12 for L

        ⑥Readout: average

2.5.3. Main results

        ①Internal task performance table on UKB 20% for age and sex prediction:

        ②External task of demographics and trait prediction on HCP-Aging:

        ③External tasks of brain disease diagnosis and prognosis on ADNI and MACC:

2.5.4. Performance scaling

        ①Larger models achieve better results, which is hardly surprising; why else make them so big:

2.5.5. Linear probing

        ①Fine-tuning of BrainLM vs. linear probing of Brain-JEPA:

2.5.6. Ablation study

        ①Method ablation:

        ②Epoch ablation:

2.5.7. Interpretation

        ①Attention distribution on 7 sub networks:

the attention score of each ROI is the mean over its 10 patches; ROI scores within each sub-network are then averaged and normalized to obtain the final attention distribution
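The aggregation described above might look like this (variable names and the network-label encoding are my assumptions):

```python
import numpy as np

def network_attention(scores, labels, n_networks=7):
    """Aggregate ROI attention into a sub-network distribution.

    scores: (n_roi, n_patches) attention score per ROI and time patch
    (10 patches here); labels: (n_roi,) sub-network index per ROI.
    Mean over patches per ROI, then mean over ROIs per network,
    normalised to sum to 1.
    """
    roi = scores.mean(axis=1)
    net = np.array([roi[labels == k].mean() for k in range(n_networks)])
    return net / net.sum()
```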

2.6. Conclusion

        ~

2.7. Limitation and future work

        ①Larger models were not tested

        ②More diverse data are needed, especially for pretraining, such as cohorts of different ethnicities collected across sites, scanning protocols, behavioral tasks, and disease groups

        ③They want to further compare cortical and subcortical regions

        ④Multi-modality is needed

3. Background Notes

3.1. Inter-task and external evaluation

(1)Inter-task

"Inter-task" is a term usually encountered in multi-task learning (MTL), where it refers to the relationships or interactions between different tasks. The goal of multi-task learning is for a single model to handle multiple tasks at once and share knowledge across them.

Inter-task learning refers to sharing information or exerting mutual influence between tasks. For example, in vision, one task may be object classification and another object detection; the model can promote knowledge transfer between the tasks by sharing intermediate-layer representations or features.

Inter-task interference (or task interference) refers to the negative effects between tasks: if the features learned for one task are harmful to another, overall performance may degrade.

(2)External evaluation

"External evaluation" usually refers to assessing a model's performance with external standards, datasets, or benchmarks, rather than with the internal evaluation the model relied on during training. It commonly takes the following forms:

External-dataset evaluation: the model's performance is measured not only on training or validation data but on external, unseen datasets (often from different domains or tasks) to test generalization. For example, an image classification model may be evaluated on several datasets (such as CIFAR-10 and ImageNet) to verify its generality.

Benchmark evaluation: the model is assessed with standard metrics or industry benchmarks. Besides computing accuracy, recall, F1 score, and so on, results may be compared with peer research or other published methods.

Out-of-domain evaluation: the model is evaluated under conditions that differ from its training task or environment, to test its transferability and robustness in unseen conditions. For example, a speech recognition system trained on a particular language or dialect may be externally evaluated on different speech data (e.g., different accents or noise environments).

4. Reference

Dong, Z. et al. (2024) 'Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking', NeurIPS 2024 (Spotlight). doi: 10.48550/arXiv.2409.19407

Author: 夏莉莉iy
