Create and fit a model for denoising EOG artifacts

This example shows how to create an EOGDenoiser instance and fit it using a BaseRaw instance with EEG and eye-tracking channels.

Import the necessary packages

import mne
import matplotlib.pyplot as plt
from eoglearn.datasets import read_mne_eyetracking_raw
from eoglearn.models import EOGDenoiser

Load the data

raw = read_mne_eyetracking_raw()
Using default location ~/mne_data for eyelink...

/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/datasets/mne.py:47: RuntimeWarning: Setting non-standard config type: "MNE_DATASETS_EYELINK_PATH"
  data_path = mne.datasets.eyelink.data_path()
Attempting to create new mne-python configuration file:
/home/docs/.mne/mne-python.json
Download complete in 03s (107.0 MB)
Loading /home/docs/mne_data/MNE-eyelink-data/eeg-et/sub-01_task-plr_eyetrack.asc
Pixel coordinate data detected.Pass `scalings=dict(eyegaze=1e3)` when using plot method to make traces more legible.
Pupil-size area detected.
There are 2 recording blocks in this file. Times between blocks will be annotated with BAD_ACQ_SKIP.
Reading EGI MFF Header from /home/docs/mne_data/MNE-eyelink-data/eeg-et/sub-01_task-plr_eeg.mff...
    Reading events ...
    Assembling measurement info ...
    Excluding events {} ...
Reading 0 ... 190020  =      0.000 ...   190.020 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 3301 samples (3.301 s)

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.3s
16 events found on stim channel DIN
Event IDs: [2]
16 events found on stim channel DIN3
Event IDs: [3]
raw
General
Filename(s) sub-01_task-plr_eyetrack.asc
MNE object type RawEyelink
Measurement date 2023-06-29 at 20:07:21 UTC
Participant Unknown
Experimenter Unknown
Acquisition
Duration 00:02:23 (HH:MM:SS)
Sampling frequency 1000.00 Hz
Time points 142,900
Channels
EEG
misc
Stimulus
Eye-tracking (Gaze position)
Eye-tracking (Pupil size)
Head & sensor digitization 0 points
Filters
Highpass 0.00 Hz
Lowpass 500.00 Hz


Note

If you want eye tracking data in head-referenced-eye-angle (HREF) units, you can pass eyetrack_unit="href" to read_mne_eyetracking_raw().

Plot the data

raw.plot()
<MNEBrowseFigure size 800x800 with 4 Axes>

Warning

Currently, the EOGDenoiser expects the BaseRaw instance to be bandpass filtered between 1 and 30 Hz.

Create the model

eog_denoiser = EOGDenoiser(raw=raw, downsample=10)
eog_denoiser
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/envs/latest/lib/python3.11/site-packages/keras/src/layers/rnn/rnn.py:200: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(**kwargs)
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/envs/latest/lib/python3.11/site-packages/mne/event.py:502: RuntimeWarning: invalid value encountered in cast
  data = data.astype(np.int64)
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/models/model.py:144: RuntimeWarning: Trigger channel contains negative values, using absolute value. If data were acquired on a Neuromag system with STI016 active, consider using uint_cast=True to work around an acquisition bug
  eeg_data.resample(self.downsampled_sfreq)
Removing orphaned onset at the end of the file.
16 events found on stim channel DIN
Event IDs: [4]
2 events found on stim channel TSYN
Event IDs: [1]
16 events found on stim channel plro
Event IDs: [2]
16 events found on stim channel DIN3
Event IDs: [3]
16 events found on stim channel DIN2
Event IDs: [4]
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/models/model.py:144: RuntimeWarning: Trigger channel contains negative values, using absolute value. If data were acquired on a Neuromag system with STI016 active, consider using uint_cast=True to work around an acquisition bug
  eeg_data.resample(self.downsampled_sfreq)
Removing orphaned onset at the end of the file.
16 events found on stim channel DIN
Event IDs: [4]
2 events found on stim channel TSYN
Event IDs: [1]
16 events found on stim channel plro
Event IDs: [2]
16 events found on stim channel DIN3
Event IDs: [3]
16 events found on stim channel DIN2
Event IDs: [4]
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/models/model.py:144: RuntimeWarning: Some events are duplicated in your different stim channels. 9 events were ignored during deduplication.
  eeg_data.resample(self.downsampled_sfreq)
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/models/model.py:144: RuntimeWarning: Resampling of the stim channels caused event information to become unreliable. Consider finding events on the original data and passing the event matrix as a parameter.
  eeg_data.resample(self.downsampled_sfreq)

<eoglearn.models.model.EOGDenoiser object at 0x7fa5a0787190>

Fit the model

We will only use 10 epochs to speed up the example.

eog_denoiser.fit_model(epochs=10)
history = eog_denoiser.model.history
Epoch 1/10
113/113 - 5s - 42ms/step - loss: 0.0360 - val_loss: 0.0495
Epoch 2/10
113/113 - 5s - 44ms/step - loss: 0.0335 - val_loss: 0.0485
Epoch 3/10
113/113 - 3s - 30ms/step - loss: 0.0320 - val_loss: 0.0490
Epoch 4/10
113/113 - 3s - 30ms/step - loss: 0.0310 - val_loss: 0.0497
Epoch 5/10
113/113 - 3s - 30ms/step - loss: 0.0308 - val_loss: 0.0495
Epoch 6/10
113/113 - 3s - 30ms/step - loss: 0.0314 - val_loss: 0.0489
Epoch 7/10
113/113 - 3s - 31ms/step - loss: 0.0304 - val_loss: 0.0492
Epoch 8/10
113/113 - 3s - 31ms/step - loss: 0.0323 - val_loss: 0.0488
Epoch 9/10
113/113 - 3s - 30ms/step - loss: 0.0311 - val_loss: 0.0486
Epoch 10/10
113/113 - 3s - 30ms/step - loss: 0.0307 - val_loss: 0.0491

Display the training history

print(history.history["loss"])
print(history.history["val_loss"])
eog_denoiser.plot_loss()
[0.035970963537693024, 0.03347524628043175, 0.03204954043030739, 0.03100723773241043, 0.03076697513461113, 0.031354475766420364, 0.03044433891773224, 0.03230671212077141, 0.031070053577423096, 0.030741913244128227]
[0.049462851136922836, 0.0484994575381279, 0.04903633892536163, 0.04970748722553253, 0.04952564090490341, 0.04890536889433861, 0.049189724028110504, 0.0488036572933197, 0.048550933599472046, 0.049062252044677734]

<Figure size 640x480 with 1 Axes>
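
The two printed lists hold one loss value per training epoch. As a minimal sketch of how to read them, the snippet below copies the values from the output above and finds the epoch with the lowest training and validation loss. The helper `best_epoch` is a hypothetical name introduced here for illustration; it is not part of eoglearn or Keras.

```python
# Loss values copied from the output printed above (one entry per epoch).
loss = [
    0.035970963537693024, 0.03347524628043175, 0.03204954043030739,
    0.03100723773241043, 0.03076697513461113, 0.031354475766420364,
    0.03044433891773224, 0.03230671212077141, 0.031070053577423096,
    0.030741913244128227,
]
val_loss = [
    0.049462851136922836, 0.0484994575381279, 0.04903633892536163,
    0.04970748722553253, 0.04952564090490341, 0.04890536889433861,
    0.049189724028110504, 0.0488036572933197, 0.048550933599472046,
    0.049062252044677734,
]


def best_epoch(values):
    """Return (1-based epoch number, value) of the smallest entry.

    Hypothetical helper for this illustration only.
    """
    index = min(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]


best_train = best_epoch(loss)
best_val = best_epoch(val_loss)
print(f"lowest training loss:   epoch {best_train[0]} ({best_train[1]:.4f})")
print(f"lowest validation loss: epoch {best_val[0]} ({best_val[1]:.4f})")
```

In this run the training loss drifts down from about 0.036 to about 0.031, while the validation loss stays between roughly 0.0485 and 0.0497 without a clear trend, which is consistent with training for only 10 epochs.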

Plot a topomap of the predicted EOG artifact.

The plot below displays the predicted amount of EOG artifact for each EEG sensor. The output is as we would expect, with frontal sensors containing the most EOG artifact.

Percentage of EEG signal that is accounted for by Ocular Artifact
Denoising neural data, saving to ``denoised_neural_`` attribute.
Predicting EOG data, saving to ``predicted_eog_`` attribute.

1/5 ━━━━━━━━━━━━━━━━━━━━ 0s 226ms/step
4/5 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 
5/5 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step
5/5 ━━━━━━━━━━━━━━━━━━━━ 1s 74ms/step
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/models/model.py:340: RuntimeWarning: divide by zero encountered in divide
  snr = (noise / signal)[:-1]

<Figure size 640x480 with 2 Axes>

Todo

Add a plot of the predicted EOG artifact for each EEG sensor over time. Add plots of the denoised EEG data.

Compare ERP between the original and “EOG-denoised” signals

Let’s create an averaged evoked response to the flash stimuli for both the original data and the “EOG-denoised” data. We’ll focus on the frontal EEG channels, since these will contain the most EOG artifact in the original signal.

pred_raw = eog_denoiser.get_denoised_neural_raw()
events, event_id = mne.events_from_annotations(pred_raw, regexp="Flash")
pred_epochs = mne.Epochs(
    pred_raw, events=events, event_id=event_id, tmin=-0.3, tmax=3, preload=True
)

events, event_id = mne.events_from_annotations(eog_denoiser.raw, regexp="Flash")
original_epochs = mne.Epochs(
    eog_denoiser.raw, events=events, event_id=event_id, tmin=-0.3, tmax=3, preload=True
)

frontal = ["E19", "E11", "E4", "E12", "E5"]
pred_avg_frontal = pred_epochs.average().get_data(picks=frontal).mean(0)
original_avg_frontal = original_epochs.average().get_data(picks=frontal).mean(0)

ax = plt.subplot()
ax.plot(pred_epochs.times, pred_avg_frontal, label="predicted")
ax.plot(original_epochs.times, original_avg_frontal, label="original")
ax.set_xlim(-0.3, 1)
ax.legend()
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/envs/latest/lib/python3.11/site-packages/mne/event.py:502: RuntimeWarning: invalid value encountered in cast
  data = data.astype(np.int64)
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/models/model.py:144: RuntimeWarning: Trigger channel contains negative values, using absolute value. If data were acquired on a Neuromag system with STI016 active, consider using uint_cast=True to work around an acquisition bug
  eeg_data.resample(self.downsampled_sfreq)
Removing orphaned onset at the end of the file.
16 events found on stim channel DIN
Event IDs: [4]
2 events found on stim channel TSYN
Event IDs: [1]
16 events found on stim channel plro
Event IDs: [2]
16 events found on stim channel DIN3
Event IDs: [3]
16 events found on stim channel DIN2
Event IDs: [4]
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/models/model.py:144: RuntimeWarning: Trigger channel contains negative values, using absolute value. If data were acquired on a Neuromag system with STI016 active, consider using uint_cast=True to work around an acquisition bug
  eeg_data.resample(self.downsampled_sfreq)
Removing orphaned onset at the end of the file.
16 events found on stim channel DIN
Event IDs: [4]
2 events found on stim channel TSYN
Event IDs: [1]
16 events found on stim channel plro
Event IDs: [2]
16 events found on stim channel DIN3
Event IDs: [3]
16 events found on stim channel DIN2
Event IDs: [4]
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/models/model.py:144: RuntimeWarning: Some events are duplicated in your different stim channels. 9 events were ignored during deduplication.
  eeg_data.resample(self.downsampled_sfreq)
/home/docs/checkouts/readthedocs.org/user_builds/eoglearn/checkouts/latest/eoglearn/models/model.py:144: RuntimeWarning: Resampling of the stim channels caused event information to become unreliable. Consider finding events on the original data and passing the event matrix as a parameter.
  eeg_data.resample(self.downsampled_sfreq)
Creating RawArray with float64 data, n_channels=129, n_times=14200
    Range : 0 ... 14199 =      0.000 ...   141.990 secs
Ready.
Used Annotations descriptions: [np.str_('Flash')]
Not setting metadata
16 matching events found
Setting baseline interval to [-0.3, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 16 events and 331 original time points ...
0 bad epochs dropped
Used Annotations descriptions: [np.str_('Flash')]
Not setting metadata
16 matching events found
Setting baseline interval to [-0.3, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 16 events and 3301 original time points ...
0 bad epochs dropped

<matplotlib.legend.Legend object at 0x7fa5b07b7d10>
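
The log above reports 331 time points per epoch for the denoised data but 3301 for the original data, even though both use tmin=-0.3 and tmax=3. The following sketch reproduces those counts. It assumes that downsample=10 divides the 1000 Hz sampling frequency by 10, which matches the 100 Hz implied by the RawArray summary (14200 samples spanning 141.990 s). The function `n_epoch_samples` is a hypothetical helper, not an MNE or eoglearn function.

```python
def n_epoch_samples(tmin, tmax, sfreq):
    """Number of samples in an epoch spanning tmin..tmax (both inclusive).

    Hypothetical helper for this illustration only.
    """
    return round((tmax - tmin) * sfreq) + 1


original_sfreq = 1000.0  # Hz, from the raw summary shown earlier
downsample = 10          # factor passed to EOGDenoiser in this example

# Assumption: the downsampled rate is the original rate divided by the factor.
downsampled_sfreq = original_sfreq / downsample

n_original = n_epoch_samples(-0.3, 3, original_sfreq)
n_denoised = n_epoch_samples(-0.3, 3, downsampled_sfreq)
print(n_original)  # 3301, as in the log for the original epochs
print(n_denoised)  # 331, as in the log for the denoised epochs
```

Because the two sets of epochs have different sampling rates, the plotting code passes each one's own `times` array to `ax.plot` rather than sharing a single time axis.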

Total running time of the script: (0 minutes 50.269 seconds)
