Decoding/MVPA - ML

3 minute read

Published:

  • 注:本文基于Bae 2018年J. Neurosci上的这篇Dissociable Decoding of Spatial Attention and Working Memory from EEG Oscillations and Sustained Potentials,强烈建议阅读原文。
  • 本文是仅针对懂EEG和fMRI实验并有处理经验、能看懂code、但无decoding经验的读者撰写的说明书;如无相关经验,不建议阅读;转载请注明出处

#结尾有彩蛋

Decoding在做什么、怎么做?

人们在看不同类别图片的时候,大脑是有不同的反应模式的。如果大脑encode了这种刺激,两个种类的刺激的活动应该是不一样的。如果有一个很好的classifier,那么这两种不一样的活动就会被分类出来(所以这里我们可以看出来,decoding本质上就是在做representations)。比如我1个session 8个run,就可以training前7个run去test最后1个run。Decoding似乎fMRI用的比较早(不确定,印象中是这样),那时候还叫MVPA。这俩本质上是一个东西。那个V就是voxel的意思,到EEG中就变成了time course。如果EEG不叫decoding,那就得叫MTPA。仅此而已。

Interim Summary:decoding的核心原理就是我有一堆data。这堆data是二元(或者多元也行)可分的。我可以从中分出一大部分,用一个很好的classifier去training data怎么去分类,然后用trained data去test剩下的data。就是这么简单。

下一个问题就是怎么做decoding?

首先我们loading data:


import mne
epochs = mne.read_epochs('preprocessed_epochs_1.fif')
t_list = epochs.times
labels = epochs.events[:,2]
subj_data = epochs.get_data()
# data structure: 640 trials * 27 channels * 750 times

Bae这一篇是每个trial有750个time points,每个被试做了640个trials,decode 27个channels。Decoding可以ERP based也可以band power based。Bae两个都做了,我这里只展示ERP based,因为步骤都一样。我们先对已经loaded的subject 1做decoding。这里插播一个步骤crop data,也就是选取特定的time window。这一步不是强制的,但非常推荐——尤其是当你marginally显著的时候。为啥?因为我们是针对time point解码。现在,我们手上的数据是250 Hz的,也就是1个时间点对应0.004”,750个time points就是3”。Bae文章中写了是onset前后1.5”。假设我们只关注1”前后,我们可以选0到1.5这个范围。相当于把baseline给剔除了。我们不做crop的话,就是把这3”全喂进去decode,效果肯定不如只做0-1.5这一段的。但话说回来,如果你差很多,这个操作没什么意义。数据这玩意,该显著怎么做都会显著,不显著的怎么做都不显著


import munpy as np
# crop data
t0,t1 = 0.0,1.5
t0_indx,t1_indx = np.where(t_list>=t0)[0][0],np.where(t_list<=t1)[0][-1]

这样就剩375个time points了。我们就针对这375个time points做decoding。

接下来就是Bae decoding方案的第一步:down sample。什么是down sample?就是缩减time points。Bae是每5个time points平均一下。这里平均步长没有限制,你想更平滑就多平均几个。有文献算过,无论怎么average,最后结果该显著的还会显著,不会影响最终结果。如果你实在不放心,可以0-5平均一下,1-6平均一下,2-7平均一下……以此类推给自己一个心理安慰。毕竟这样你的time resolution没有降低,但结果平滑了。

为什么要resampling?其实这一步不做也不会对结果有什么实质性的影响。有影响的是:我们是用marker的方式对每个trial进行了对齐,当样本很低的时候,并不是每个trial(1/250Hz=0.004秒)都有用一个完全的对齐,所以取一个time window(5个time points平均)把没有完全对齐的损失抹去。而肉眼可见的影响是,down sample会让结果看起来很平滑。另外还可以节能,属实降本增效了(除非你为了心理安慰只做平滑,那样计算时间会暴增)。如果不resample的话,data画出来会很毛躁,一堆锯齿看着不好看而已。有兴趣可以自己画画。还是那句话,该显著怎么做都会显著,不显著的怎么做都不显著

我这里写了个sub function,也建议大家写code时候这么做,一个是看着清晰,另一方面方便重复利用——尤其是我现在是针对单个被试进行decoding。最后做统计检验肯定是group level,不可能单独被试自己去统计。


# down sample

t_space = 5
def resamp_t(data):
    data_avg_t = np.zeros([data.shape[0],data.shape[1],data.shape[2]])
    for t in range(t_points):
        data_avg_t[:,:,t] = np.average(data[:,:,t*t_space:t*t_space+t_space],axis=2)

    return data_avg_t

subj_data_resamp = resamp_t(subj_data)

接下来我们要降噪。如何降噪?average trials。这一步要分条件去做。之前有讲到,decoding是在做分类,我们需要告诉classifier哪些data是类别1的,哪些data是类别2的……以此类推。我们不可能把不同类别的trials平均到一起。以Bae的data为例,有16个orientation和16个location。我们手上的数据是16个orientation的,那么data structure就是16 orientations * 40 trials * 27 channels * 75 times。这一步是用来label的。这部分代码会和averge trials的放在一下,所以我们先继续往下看。

分类完毕后,就可以将trials每13个一平均了。也就是40个(实际上可以整除的是39个)trials分成3段,每段内部进行平均。为了让数据更稳定可靠,平均前最好将trials随机打乱一下。另外步长这个问题和down sample一样,取多少随你。按照Bae的做法,最后会得到16 orientations * 3 trials * 27 channels * 75 times这样一个data structure。


# average trials

condN = 16
trialN = 40
trial_space = 13
import math
trialN_final = math.floor(trialN/trial_space)
chans = 27
t_points = len(t_list)

def dict_to_arr(d):
    return np.array([item for item in d.values()])

def avg_trials(old_data):
    data_final,label_final,data,label = \
        dict(),dict(),dict(),dict()

    for n in range(condN):
        n_indx = np.where(labels==n)[0]
        data[n] = old_data[n_indx]
        # 40 trials * 27 channels * 75 times
        label[n] = labels[n_indx]
        # 40 labels

        # permutation 1st dimension
        permu_data = np.random.permutation(data[n])
        # 40 trials * 27 channels * 75 times
        '''
        permu_indx = np.random.permutation(trialN)
        permu_data = data[n][permu_indx,:,:]
        '''
        n_list = []
        for k in range(trialN_final):
            n_list.append(
                np.mean(permu_data[k*trial_space:k*trial_space+trial_space],axis=0))
        data_final[n] = np.array(n_list)
        label_final[n] = np.repeat(n,k)

        # 3 trials * 27 channels * 75 times
    data_final = dict_to_arr(data_final)
    # 16 conditions * 3 trials * 27 channels * 75 times
    label_final = dict_to_arr(label_final)
    # 16 conditions * 3 labels

    data_final = np.reshape(
        data_final,[int(condN*trialN_final),chans,t_points])
    # 48 trials * 27 channels * 75 times
    label_final = np.reshape(label_final,[int(condN*trialN_final)])
    # 48 labels

    return data_final,label_final

至此,data预处理完成。我们可以正式开始decoding了。decoding最核心的原理就是把data拆成两个set,一个做traing,一个做testing。实际操作中,我们不可能只有两个data set,一般会用K-fold拆成几组——比如3折,然后针对每个time point用前两组test后一组;再重分3折,用后两组test前一组;最后分3折,用两边test中间那组。这样叫做cross validation,其实就是每次都重新拆分的LOOCV版本(cross validation有好多种,有兴趣可以查一下,一搜一大把)。每次test都会得到一个accuracy,得到3个ACC平均后就是decoding的正确率。最后得到75个time points的正确率。这里其实可以用AUC(个人常用)。3折是Bae原文操作,也是一般training的默认项(有时候是5)。个人一般用10折,结果基本就很稳定了。

这里需要注意一点,用K-fold一定用StratifiedKFold(不要用KFold这个function!!!),这样可以保证组别平衡。


from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import roc_auc_score
import scipy.stats
from mne.stats import permutation_cluster_1samp_test, \
    permutation_cluster_test

fdN = 3
rpN = 10
n_permutations = 1000
chance_crit = 0.5
p_crit = 0.05

def deco_in(old_data):
    accs= np.zeros([rpN,t_points,fdN])
    for n in range(rpN):
        # average trials
        data_final,label_final = avg_trials(old_data)

        for t in range(t_points):
            data_t = data_final[:,:,t]
            # 48 trials * 27 channels
            seed_list = np.random.randint(t_points)
            kf = StratifiedKFold(n_splits=fdN,shuffle=True,random_state=seed_list)
            # shuffle=True if no randomly averaging trials

            fdN_indx = 0
            for train_indx,test_indx in kf.split(data_t,label_final):
                data_train,data_test,label_train,label_test = \
                data_t[train_indx],data_t[test_indx],label_final[train_indx],label_final[test_indx]

                # normalization
                scaler = StandardScaler()
                data_train = scaler.fit_transform(data_train)
                data_test = scaler.transform(data_test)
                svm = SVC(kernel='linear')
                svm.fit(data_train,label_train)
                acc = svm.score(data_test,label_test)
                accs[n,t,fdN_indx] = acc

                fdN_indx += 1

    acc_list = np.mean(accs,axis=(0,2))
    return acc_list

acc_list = deco_in(subj_data_resamp)

这是个within task的decoding,一般会比较推荐做cross-task。原理都是一样的,甚至更简单。


# across-task decoding

def deco_cx(task1,task2,labels1,labels2):
    acc_list = np.zeros([t_points])

    # decoding based on epoch
    for t in range(t_points):
        data_train = task1[:,:,t]
        data_test = task2[:,:,t]

        scaler = StandardScaler()
        x_train = scaler.fit_transform(data_train)
        x_test = scaler.transform(data_test)
        clf = SVC(kernel='rbf',class_weight='balanced',max_iter=rpN)
        clf.fit(x_train,labels1)
        pred = clf.predict(x_test)
        acc_list[t] = roc_auc_score(labels2,pred)
    return acc_list

最后一步,怎么看group-level的结果?我们现在做的是单被试的decoding。你可以写个for loop,这样把全部被试都decode一遍。30个被试的话每个time points就有30笔data,60个被试就60个data points,你有多少个被试就是多少个data points。然后我们就可以针对每个time point做1-sample t test,跟chance level比较。二分的话就是0.5了,Bae那篇是0.0625。推荐用cluster t test去做,我这里的code是2个分类的,所以那个chance level是0.5。


def find_sig(clu,clu_p):
    acc_sig,grp_sig = t_points*[0],t_points*[0]
    grp_label = 0

    for c,p in zip(clu,clu_p):
        if p<=p_crit:
            grp_label += 1
            acc_sig[c[0][0]:(c[0][-1]+1)] = \
                [1]*len(c[0])
            grp_sig[c[0][0]:(c[0][-1]+1)] = \
                [grp_label]*len(c[0])

    return acc_sig,grp_sig

def clu_permu_1samp_t(acc_data):
    threshold = None
    tail = 0
    degrees_of_freedom = len(acc_data) - 1
    t_thresh = scipy.stats.t.ppf(1-p_crit/2,df=degrees_of_freedom)

    t_obs,clu,clu_p,H0 = permutation_cluster_1samp_test(
        acc_data-chance_crit,n_permutations=n_permutations,
        threshold=t_thresh,tail=tail,out_type='indices',verbose=True)
    acc_sig,grp_sig = find_sig(clu,clu_p)

    return acc_sig,grp_sig

这样就可以得到最终结果了。acc_sig是一个label list,找哪几个index对应的values是1,那一段就是显著的time window。这么写是我data习惯存成dataframe。我会直接加一个column,后面画图比较好画。没别的意思。你也可以直接print time points,看哪个time windows显著。grp_sig是存有几个cluster的。有时候就一段time window显著,它的label就只有1,这个不重要。

以上就是整个decoding的教程。All done!

关于彩蛋

第一个彩蛋就其实MNE有几个专门decoding的functions,所以我们其实可以直接拿来用,不用在这搞人工智障了。同样within和cross的code我都放出来了:


# within-task decoding

def gen_decoding_in(X,y):
    clf = make_pipeline(StandardScaler(),LinearModel(LogisticRegression(solver='liblinear')))
    time_gen = GeneralizingEstimator(clf,n_jobs=jobN,scoring=scoring,verbose=True)

    scores = cross_val_multiscore(time_gen,X,y,cv=fdN,n_jobs=jobN)
    mean_score = np.mean(scores,axis=0)
    mean_score_diag = np.diag(mean_score)

    return scores,mean_score,mean_score_diag

# across-task decoding

def gen_decoding_cx(task1,task2,labels1,labels2):
    clf = make_pipeline(StandardScaler(),LogisticRegression(solver='liblinear'))
    time_gen = GeneralizingEstimator(clf,scoring='roc_auc',n_jobs=None,verbose=True)

    time_gen.fit(X=task1,y=labels1)
    scores = time_gen.score(X=task2,y=labels2)
    scores_diag = np.diag(scores)

    return scores,scores_diag

彩蛋2也是个教程。可以搜一下微信公众号“路同学”。先声明我不认识这个人,不是给他打广告!!!路同学有一篇就是讲如何decoding的全网独一无二基于Python逐行代码手把手实现与讲解的EEG Decoding视频教程,付费的,但不贵且值得。他也是基于Bae这篇讲的,真的是hand in hand。我是真的做不到那么耐心……就一暴躁死废宅……如果你看我这篇有点吃力,去看他的教程准没错。

最后,强烈建议去看Bae那篇原文,写的非常详细。我就是根据他那篇写出来的code——当年还是0基础去写的。所以你知道我为什么会反复强调一定一定一定去看原文了吧。

全文完结