Basic usage of MNE for EEG data processing

Recent notes, updated from time to time.

Introduction

Common tools for processing EEG data are EEGLAB for MATLAB and the related Python packages; there are of course many similar tools, but these are the ones I have worked with so far.

EEGLAB is not covered in detail here. Its main convenience is the GUI, although I later found that Python also has EEG-processing packages that come with a GUI.

This post focuses on MNE, combined with some machine learning algorithms (see the MNE website).

An introduction to some basic concepts is linked in the references.

Hands-on

Data preprocessing

The main steps are setting electrode positions (the montage), re-referencing, downsampling, filtering, and removing artifacts with ICA.

These steps can be done either in MATLAB (EEGLAB) or directly in MNE; a minimal MNE sketch follows.
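Below is a minimal sketch of these steps with MNE, assuming the raw EEG sits in an EDF file; the file name, montage, sampling rate, filter band and ICA settings are placeholders rather than values from this post.

```python
import mne
from mne.preprocessing import ICA

raw = mne.io.read_raw_edf("subject01.edf", preload=True)        # hypothetical recording
raw.set_montage(mne.channels.make_standard_montage("standard_1020"),
                on_missing="ignore")                            # electrode positions
raw.set_eeg_reference("average")                                # re-reference
raw.resample(160)                                               # downsample to 160 Hz
raw.filter(l_freq=1.0, h_freq=40.0)                             # band-pass filter

ica = ICA(n_components=15, random_state=42)                     # ICA for artifact removal
ica.fit(raw)
ica.exclude = []   # fill in the artifact components found by inspection (or find_bads_eog)
raw_clean = ica.apply(raw.copy())
```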

Splitting the dataset

Split the data into a training set and a test set, mainly with the scikit-learn toolbox.

The simplest option is train_test_split, which performs a plain proportional split and can shuffle the data first.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Create a dataset X with 100 samples and the corresponding labels y
X, y = np.arange(200).reshape((100, 2)), range(100)

# Split into training and test sets; the test set takes 33% of the samples
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# Print the sizes of the original, training and test sets
print("The length of original data X is:", X.shape[0])
print("The length of train Data is:", X_train.shape[0])
print("The length of test Data is:", X_test.shape[0])
```

In practice, though, cross-validation splits such as KFold and ShuffleSplit are used more often.

These splitters live in sklearn.model_selection. Use KFold and its relatives to split a single subject's motor-imagery data, and GroupKFold and its relatives for between-subject splits; in other words, GroupKFold and LeaveOneGroupOut can handle cross-subject data.

They can also be combined with GridSearchCV to tune hyperparameters and obtain a better model; a minimal sketch is given below.
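A minimal sketch of these splitters and of GridSearchCV; the features, labels and subject ids below are randomly generated placeholders rather than real EEG data.

```python
import numpy as np
from sklearn.model_selection import (KFold, GroupKFold, LeaveOneGroupOut,
                                     GridSearchCV, cross_val_score)
from sklearn.svm import SVC

rng = np.random.default_rng(0)
X = rng.standard_normal((120, 8))     # e.g. 120 trials x 8 features (after CSP or similar)
y = rng.integers(0, 2, 120)           # two motor-imagery classes
groups = np.repeat(np.arange(6), 20)  # subject id of each trial

# within-subject style: plain K-fold over one subject's trials
print(cross_val_score(SVC(), X, y, cv=KFold(n_splits=5, shuffle=True, random_state=42)).mean())

# cross-subject style: every fold holds out whole subjects
print(cross_val_score(SVC(), X, y, cv=GroupKFold(n_splits=3), groups=groups).mean())
# LeaveOneGroupOut() works the same way, leaving out one subject at a time

# GridSearchCV tunes hyperparameters with the chosen splitter
grid = GridSearchCV(SVC(), param_grid={"C": [0.1, 1, 10]}, cv=5)
grid.fit(X, y)
print(grid.best_params_)
```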

Multi-domain analysis

EEG feature extraction can draw on several domains (time, frequency, time-frequency and spatial features); a minimal frequency-domain sketch is shown below, followed by a full script.
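As a hedged example of one frequency-domain feature, the snippet below computes band power per channel with mne-features (the same 'pow_freq_bands' feature used in the full script later); the sampling rate, band edges and the random data standing in for epochs are assumptions.

```python
import numpy as np
from mne_features.feature_extraction import extract_features

sfreq = 160                             # assumed sampling rate
X = np.random.randn(30, 14, 2 * sfreq)  # dummy (n_epochs, n_channels, n_times) array
bands = np.array([4., 8., 13., 30.])    # theta / alpha / beta band edges

# power in each band, one value per channel and band
feats = extract_features(X, sfreq, ['pow_freq_bands'],
                         funcs_params={'pow_freq_bands__freq_bands': bands})
print(feats.shape)                      # (n_epochs, n_channels * n_bands)
```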

The script below defines a collection of commonly used EEG feature-extraction functions and compares LDA, KNN and SVM pipelines on a motor-imagery dataset with moabb:

```python
from pywt import wavedec, waverec
from scipy import signal
from scipy.signal import welch
import math
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pywt

from mne import time_frequency
from mne.decoding import CSP
from mne_features.feature_extraction import FeatureExtractor

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.svm import SVC

from spectrum import arburg

from moabb.datasets import PhysionetMI
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import LeftRightImagery


# def csp_params(data):
#     csp_ins = CSP(n_components=8)
#     X_new = csp_ins.fit_transform(data, label)
#     return X_new

"""
BCI Competition IV 2b classification results.
"""


def plot_sub_result(acc_lda_scores, acc_knn_scores, acc_svm_scores):
    """Plot per-subject accuracies for LDA/KNN/SVM plus their mean."""
    barWidth = 0.3

    acc_lda_scores.append(np.mean(acc_lda_scores))
    acc_knn_scores.append(np.mean(acc_knn_scores))
    acc_svm_scores.append(np.mean(acc_svm_scores))

    r1 = np.arange(len(acc_lda_scores))
    r2 = [x + barWidth for x in r1]
    r3 = [x + barWidth for x in r2]
    plt.bar(r1, acc_lda_scores, width=barWidth, edgecolor='white', label='LDA')
    plt.bar(r2, acc_knn_scores, width=barWidth, edgecolor='white', label='KNN')
    plt.bar(r3, acc_svm_scores, width=barWidth, edgecolor='white', label='SVM')
    plt.xticks([r + barWidth for r in range(len(acc_lda_scores))],
               ['sub' + str(i) for i in range(1, len(acc_lda_scores))] + ['mean'])
    # y-axis label
    plt.ylabel("Accuracy")
    # legend
    plt.legend()
    plt.title("Frequency-domain classification results")
    # show the figure
    plt.show()


def ar_coeff(data):
    """Return Burg AR coefficients averaged over channels."""
    model_order = 8
    features = [arburg(data[i], order=model_order)[0].real for i in range(data.shape[0])]
    features = np.mean(features, axis=0)
    return features


def differ_en(data):
    """Differential entropy of the 1-45 Hz band, estimated from the Welch PSD."""
    freqs, psd = welch(data, 160, nperseg=320)
    idx_band = np.logical_and(freqs >= 1, freqs <= 45)
    band_psd = psd[:, idx_band]
    # ddof=1 divides by (n - 1); ddof=0 (the default) divides by n
    variance = np.var(band_psd, ddof=1, axis=-1)
    result = np.log(2 * math.pi * math.e * variance) / 2
    return result


def dwt_coeff(data):
    """Mean and variance of selected DWT coefficients (db4, level 4)."""
    cA4, cD4, cD3, _, _ = pywt.wavedec(data, 'db4', level=4)
    cA4_mean = np.mean(cA4, axis=-1)
    cA4_var = np.var(cA4, axis=-1)
    cD4_mean = np.mean(cD4, axis=-1)
    cD4_var = np.var(cD4, axis=-1)
    cD3_mean = np.mean(cD3, axis=-1)
    cD3_var = np.var(cD3, axis=-1)
    feature = np.concatenate((cA4_mean, cA4_var, cD4_mean, cD4_var, cD3_mean, cD3_var), axis=-1)
    return feature


def wavepack_coeff_energy(data):
    """Wavelet-packet energy ratios plus mean/variance of selected level-3 nodes."""
    n = 3
    wp = pywt.WaveletPacket(data, wavelet='db4', maxlevel=n)
    # decomposition coefficients of every node at level n
    re = []
    for i in [node.path for node in wp.get_level(n, 'freq')]:
        re.append(wp[i].data)
    # energy of each level-n node
    energy = []
    for i in re:
        energy.append(pow(np.linalg.norm(i, axis=-1), 2))
    D1 = energy[0] / np.sum(energy)
    D2 = energy[1] / np.sum(energy)
    D3 = energy[2] / np.sum(energy)
    D1_mean = np.mean(wp['aaa'].data, axis=-1)
    D2_mean = np.mean(wp['aad'].data, axis=-1)
    D3_mean = np.mean(wp['add'].data, axis=-1)
    D1_var = np.var(wp['aaa'].data, axis=-1)
    D2_var = np.var(wp['aad'].data, axis=-1)
    D3_var = np.var(wp['add'].data, axis=-1)
    features = np.concatenate((D1, D2, D3, D1_mean, D2_mean, D3_mean, D1_var, D2_var, D3_var), axis=-1)
    return features


def delta(data):
    """Band-pass filter to 4-8 Hz (sampling rate 160 Hz)."""
    b, a = signal.butter(8, [4 / 160 * 2, 8 / 160 * 2], btype='bandpass')
    filteredData = signal.filtfilt(b, a, data)
    return filteredData


def process_bci(data):
    """Keep the best score per subject, stopping after 20 subjects with score >= 0.74."""
    # session_condition = (data["session"]=="session_2")|(data["session"]=="session_3")|(data["session"]=="session_4")
    scores = []
    for i in range(1, 109):
        condition = (data["subject"] == str(i))
        print(data[condition]["score"])
        score = np.max(data[condition]["score"])
        if score >= 0.74:
            scores.append(score)
        if len(scores) == 20:
            break
    return scores


def reconstruct_dwt(data):
    """Decompose with the DWT and reconstruct the signal (db4, level 4)."""
    coeffs = wavedec(data, 'db4', level=4)
    newdata = waverec(coeffs, 'db4')
    return newdata


def psd_welch(data):
    """Welch power spectral density between 0.1 and 60 Hz."""
    psd, freqs = time_frequency.psd_array_welch(data, sfreq=160, fmin=0.1, fmax=60, n_fft=160, n_overlap=80)
    print(psd.shape)
    return psd


def wavelet_packet(x):
    """Reconstruct the signal from each level-4 wavelet-packet node and stack the results."""
    print(x.shape)
    mother_wavelet = 'db4'
    wp = pywt.WaveletPacket(data=x, wavelet=mother_wavelet, mode='symmetric', maxlevel=4)
    node_name_list = [node.path for node in wp.get_level(4, 'natural')]
    rec_results = []
    for i in node_name_list:
        new_wp = pywt.WaveletPacket(data=x, wavelet=mother_wavelet, mode='symmetric')
        new_wp[i] = wp[i].data
        x_i = new_wp.reconstruct(update=True)
        rec_results.append(x_i)
    output = np.array(rec_results)
    output = np.reshape(output, (x.shape[0], -1))
    print(output.shape)
    return output


def reshapeIN(data):
    print('expand')
    print(data.shape)
    # replace inf/NaN values with a small constant
    data[np.isinf(data) | np.isnan(data)] = 0.01
    # newdata = np.reshape(data, (-1, 3, data.shape[-1]))
    # newdata = np.reshape(data, (-1, 21, data.shape[-1]))
    newdata = np.reshape(data, (-1, 14, data.shape[-1]))
    return newdata


def reshapeDE(data):
    print("shrink")
    print(data.shape)
    # replace inf/NaN values with a small constant
    data[np.isinf(data) | np.isnan(data)] = 0.01
    newdata = np.reshape(data, (data.shape[0], -1))
    return newdata


def dp(data):
    """Band-pass the data into several bands and concatenate selected bands with the raw signal."""
    print(data.shape)
    b, a = signal.butter(8, [0.1 / 160 * 2, 4 / 160 * 2], btype='bandpass')
    delta_dataset = signal.filtfilt(b, a, data)
    b, a = signal.butter(8, [4 / 160 * 2, 7 / 160 * 2], btype='bandpass')
    theta_dataset = signal.filtfilt(b, a, data)
    b, a = signal.butter(8, [7 / 160 * 2, 18 / 160 * 2], btype='bandpass')
    u_dataset = signal.filtfilt(b, a, data)
    b, a = signal.butter(8, [16 / 160 * 2, 24 / 160 * 2], btype='bandpass')
    beta_Data = signal.filtfilt(b, a, data)
    b, a = signal.butter(8, [30 / 160 * 2, 50 / 160 * 2], btype='bandpass')
    gamma_Data = signal.filtfilt(b, a, data)
    newData = np.concatenate((data, u_dataset, beta_Data), axis=-1)
    print(newData.shape)
    return newData


def select_channel(data):
    """Keep a fixed subset of channels (the second assignment overrides the first)."""
    selected_channel = (8, 9, 10, 11, 12, 13,)
    selected_channel = (0, 2, 4, 6, 8, 12, 14, 20, 26, 33, 49, 51, 57, 63)
    newdata = data[:, selected_channel, :]
    return newdata


if __name__ == '__main__':
    # fonts that can render CJK characters in matplotlib figures
    matplotlib.rc("font", family='Microsoft YaHei')
    matplotlib.rcParams['font.sans-serif'] = ['SimHei']

    # feature sets for the different domains
    timedomain_features = {'variance', 'hjorth_mobility', 'hjorth_complexity', 'skewness', 'kurtosis',
                           'zero_crossings', 'higuchi_fd', 'rms', 'ptp_amp'}

    freqdomain_features = {('DE', differ_en), 'pow_freq_bands', 'energy_freq_bands'}
    freqdomain_params = {
        'pow_freq_bands__freq_bands': np.array([0.5, 4, 8, 15, 30, 45]),
        'pow_freq_bands__ratios': 'all',
        'energy_freq_bands__freq_bands': np.array([0.5, 4, 8, 14, 30, 45]),
    }
    timefredomain_features = {('dwt_coeff', dwt_coeff)}
    timefredomain_features = {('dwt_coeff', reconstruct_dwt)}  # overrides the previous line
    delta_signal = {('delta', delta)}

    dataset = PhysionetMI(imagined=True, executed=False)
    sub_list = [i for i in range(1, 109)]
    dataset.subject_list = sub_list
    paradigm = LeftRightImagery()
    print("metric", paradigm.scoring)

    acc_lda_scores = []
    acc_knn_scores = []
    acc_svm_scores = []
    evaluation = WithinSessionEvaluation(
        paradigm=paradigm,
        datasets=[dataset],
        overwrite=True,
        hdf5_path=None,
    )
    expand = FunctionTransformer(reshapeIN)
    shrink = FunctionTransformer(reshapeDE)
    downsample = FunctionTransformer(dp)  # dp actually band-filters and concatenates
    channel_selection = FunctionTransformer(select_channel)
    myfeature = {('psd', psd_welch)}
    fe = FeatureExtractor(160, selected_funcs=myfeature)

    pipeline = make_pipeline(channel_selection, downsample, CSP(n_components=8), LDA())
    results_t = evaluation.process({"csp+lda": pipeline})
    print(results_t)

    # fe = FeatureExtractor(160, selected_funcs=freqdomain_features, params=freqdomain_params)
    pipeline = make_pipeline(channel_selection, downsample, CSP(n_components=8), KNeighborsClassifier())
    results_k = evaluation.process({"csp+knn": pipeline})

    # fe = FeatureExtractor(160, selected_funcs=timefredomain_features)
    pipeline = make_pipeline(channel_selection, downsample, CSP(n_components=8), SVC())
    results_s = evaluation.process({"csp+svm": pipeline})
```


Model evaluation

Based on the literature, classification accuracy and the Kappa coefficient are the metrics most commonly used to evaluate BCI classification performance.
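Both metrics are available in scikit-learn; a minimal sketch with dummy labels (not results from the experiments above) is shown here.

```python
from sklearn.metrics import accuracy_score, cohen_kappa_score

y_true = [0, 0, 1, 1, 0, 1, 1, 0]   # ground-truth classes (dummy)
y_pred = [0, 1, 1, 1, 0, 1, 0, 0]   # classifier output (dummy)

print("accuracy:", accuracy_score(y_true, y_pred))   # fraction of correct predictions
print("kappa:", cohen_kappa_score(y_true, y_pred))   # agreement corrected for chance
```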

```python
# Continuation of the __main__ block above: aggregate the scores and plot the comparison.

# delta_signal = {('delta', delta)}
# fe = FeatureExtractor(250, selected_funcs=delta_signal)
# pipeline = make_pipeline(CSP(n_components=8), SVC())
# results_csp = evaluation.process({"CSP+lda": pipeline})

l1 = process_bci(results_t)
l2 = process_bci(results_k)
l3 = process_bci(results_s)
# l4 = process_bci(results_csp)
l1 = np.array(l1)
l2 = np.array(l2)
l3 = np.array(l3)
# l4 = np.array(l4)

# append the mean score of each pipeline
l1 = np.append(l1, np.mean(l1))
l2 = np.append(l2, np.mean(l2))
l3 = np.append(l3, np.mean(l3))
# l4 = np.append(l4, np.mean(l4))
barWidth = 0.3
r1 = np.arange(len(l1))
r2 = [x + barWidth for x in r1]
r3 = [x + barWidth for x in r2]
# r4 = [x + barWidth for x in r3]
plt.tight_layout()
plt.bar(r1, l1, width=barWidth, edgecolor='white', label='LDA')
plt.bar(r2, l2, width=barWidth, edgecolor='white', label='KNN')
plt.bar(r3, l3, width=barWidth, edgecolor='white', label='SVC')
# plt.bar(r4, l4, width=barWidth, edgecolor='white', label='CSP')
plt.xticks([r + barWidth for r in range(len(l1))], ['s' + str(i) for i in range(1, len(l1))] + ['mean'])
# y-axis label
plt.ylabel("Accuracy")
# legend
plt.legend(loc="upper center")
# title
plt.title("TCSSF feature extraction comparison")
# show the figure
plt.show()
print(l1)
print(l2)
print(l3)
# print(l4)
```

References

  1. Documentation overview — MNE 1.3.1 documentation

  2. ZitongLu1996/Python-EEG-Handbook: A Chinese handbook for EEG data analysis based on Python (github.com)

    Useful Python libraries

    1. mne
    2. mne-features: Extract features from MEG time series for classification — mne_features 0.2 documentation
    3. eeglib: Xiul109/eeglib — A library with tools for EEG analysis (github.com)
    4. pyeeg: pyeeg at master · forrestbao/pyeeg (github.com)
    5. scikit-learn

Related GitHub projects

  1. Feature-Extraction-EEG/Feature Extraction.ipynb at master · JoyRabha/Feature-Extraction-EEG (github.com) — a feature-extraction library
  2. https://github.com/dasdarin/eeg-features
  3. eeg_analysis_thesis/feature_extraction at master · mcguegi/eeg_analysis_thesis (github.com)
  4. Seizure_FE/FE.py at master · yangsh827/Seizure_FE (github.com)
  5. omerfbhatti/BCI-Project: Brain Computer Interface Project as part of “Programming for DSAI” course at AIT (github.com)
  6. EEGExtract/EEGExtract.py at main · sari-saba-sadiya/EEGExtract (github.com)
  7. vancleys/EEGFeatures: EEG Features to be extract from raw data. (github.com)
  8. EEG-ANALYSIS-FOR-MOTOR-IMAGERY-APPLICATION/CSP_features_Classification.ipynb at master · prashanth-prakash/EEG-ANALYSIS-FOR-MOTOR-IMAGERY-APPLICATION (github.com)
  9. MEDHA-TIWARI/Feature-extraction-technique-for-EEG-signals: feature extraction from eeg datasets (github.com)
  10. wmichalska/EEG-emotions: Application prepares data to learning process. Including preprocessing, cleaning, reformating, feature extraction using PyEEG library and learning using Sklearn tool. (github.com)
  11. jesus-333/FBCSP-Python: Python implemementation of the FBCSP algorithm (github.com)
  12. motor-imagery/README.md at master · mauricio-ms/motor-imagery · GitHub
  13. FBCSP/FBCSP.py at master · stupiddogger/FBCSP · GitHub FBCSP
  14. TNTLFreiburg/fbcsp (github.com) FB
  15. NeuroTechX/moabb: Mother of All BCI Benchmarks (github.com)

Datasets

  1. BCI Competition IV-1: http://www.bbci.de/competition/iv/#dataset1
  2. BCI Competition IV-2a: http://www.bbci.de/competition/iv/#dataset2a
  3. BCI Competition IV-2b: http://www.bbci.de/competition/iv/#dataset2b
  4. Motor Movement/Imagery Dataset: https://www.physionet.org/physiobank/database/eegmmidb/

Thanks for reading.
