Epilepsy

Epilepsy affects over 50 million individuals worldwide, making it one of the most common neurological disorders.

Manual detection of seizures consumes significant time and resources from experienced medical experts in the epilepsy unit.

How can we create a better system?

Our Solution

We combine graph neural networks (GNNs) with self-supervised learning (SSL), with the goal of fully automating seizure detection.

Why GNNs?

GNNs model graph structures, which represent a collection of objects (nodes) and their relationships (edges).

GNNs are perfect for modelling the brain, given its many complex interconnections, whether between individual neurons or entire brain regions. GNNs allow us to harness the interconnectedness between different sites where we collect our data (e.g., EEG electrodes).

This is demonstrated in recent work from our lab, Graph representations of iEEG data for seizure detection with graph neural networks, which shows a clear advantage over other deep learning methods such as convolutional neural networks (CNNs). For more details about GNNs, I’ll have several blog posts in the coming weeks.

The Data

We used the open-access OpenNeuro ds003029 dataset, featuring iEEG seizure recordings from 25 patients collected across four U.S. epilepsy centers. iEEG recording is essential in epilepsy care units for monitoring seizures and investigating resective surgery options. Our models may also help pinpoint the epileptogenic zone, serving as an ideal foundation for future clinical applications.

Self-Supervised Learning (SSL)

As mentioned in Self-Supervised Learning: An Introduction, model performance with supervised learning is limited by the availability of labelled data. We decided to combine SSL techniques to boost the current supervised GNN model used by our lab at the Krembil Brain Institute. Each SSL model we tested uses a GNN itself, so it can pass along its graph structure—more specifically a graph representation with node features and edge features—to the supervised model, as a “better starting point” for the input.

GNN Encoder

We used the following encoder for each SSL method:

Here $x$ denotes our input graph representation, the graph and its features. We then transform $x$ into a single vector $z = \text{Enc}(x)$, called the encoding.

For more details see below, or skip to the next section to see the models!

$x = (\textbf{X}, \textbf{E}, G)$ is the input graph representation, $G = (V,E)$ is the graph, $\textbf{X} = (\textbf{x}_1, \textbf{x}_2, \dots, \textbf{x}_{|V|})$ are the node features with $\textbf{x}_i \in \mathbb{R}^d$, $\textbf{E} = (\textbf{e}_{i,j})_{(i,j) \in E}$ are the edge features with $\textbf{e}_{i,j} \in \mathbb{R}^k$. This graph representation $x$ is fed into an Edge-Conditioned Convolution (ECC), given by the update:

$$\begin{aligned}\textbf{x}_i' = \sum_{j \in \mathcal{N}(i)} F_{\theta}(\textbf{e}_{i,j})\textbf{x}_{j} \end{aligned}$$

where $F_{\theta}: \mathbb{R}^k \to \mathbb{R}^{d' \times d}$ is a multilayer perceptron (MLP), which transforms each edge $\textbf{e}_{i,j}$ to a matrix that acts on each node feature in the neighbourhood of $\textbf{x}_i$, denoted $\mathcal{N}(i)$, including $\textbf{x}_i$ itself (self-loop). For our purposes, we will choose $d' > d$ to “expand” the node features initially. The new graph representation, $x' = (\textbf{X}', \textbf{E}, G)$, is then fed into a Graph Attention (GAT) layer, given by:

$$\begin{aligned}\textbf{x}_i'' = g\bigg( \sum_{j \in \mathcal{N}(i)} \alpha_{ij} W \textbf{x}_j'\bigg) \end{aligned}$$

where $\alpha_{ij}$ are the attention coefficients, which are learnable coefficients that essentially determine how important node $j$ is to node $i$, $W$ is a learnable weight matrix, and $g$ is some nonlinear activation function.

After applying ECC and GAT consecutively, we aggregate the node features using a global mean pooling layer, then flatten the result to give our encoding.
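To make the encoder concrete, here is a minimal NumPy sketch of the two message-passing steps and the mean-pool readout. This is illustrative only: the actual implementation uses PyTorch Geometric layers, and all function and variable names here (`ecc_layer`, `gat_layer`, `encode`, the toy single-head attention) are my own assumptions, not the lab's code.

```python
import numpy as np

def ecc_layer(X, E, edges, mlp):
    """Edge-Conditioned Convolution: x'_i = sum_{j in N(i)} F_theta(e_ij) x_j.
    X: (n, d) node features; E: dict (i, j) -> edge feature of shape (k,);
    edges: list of (i, j) pairs (self-loops included);
    mlp: F_theta, mapping a (k,) edge feature to a (d', d) matrix."""
    n, _ = X.shape
    d_out = mlp(next(iter(E.values()))).shape[0]
    X_new = np.zeros((n, d_out))
    for (i, j) in edges:
        X_new[i] += mlp(E[(i, j)]) @ X[j]
    return X_new

def gat_layer(X, edges, W, a, g=np.tanh):
    """Single-head graph attention: x''_i = g(sum_j alpha_ij W x'_j),
    with attention scores computed from the concatenated projected features."""
    n = X.shape[0]
    H = X @ W.T  # project every node feature with the learnable matrix W
    out = np.zeros_like(H)
    for i in range(n):
        nbrs = [j for (u, j) in edges if u == i]
        scores = np.array([np.tanh(a @ np.concatenate([H[i], H[j]]))
                           for j in nbrs])
        alpha = np.exp(scores) / np.exp(scores).sum()  # softmax over N(i)
        out[i] = g(sum(al * H[j] for al, j in zip(alpha, nbrs)))
    return out

def encode(X, E, edges, mlp, W, a):
    """ECC -> GAT -> global mean pool, giving the encoding z = Enc(x)."""
    X1 = ecc_layer(X, E, edges, mlp)
    X2 = gat_layer(X1, edges, W, a)
    return X2.mean(axis=0)  # global mean pooling over nodes
```

In practice each of these pieces corresponds to an off-the-shelf PyTorch Geometric layer (edge-conditioned convolution, graph attention, and global mean pooling), with the learnable parameters trained end to end.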

The Models

SSL Models

We selected four SSL methods:

  1. Relative Positioning (RP)

  2. Temporal Shuffling (TS)

  3. Contrastive Predictive Coding (CPC)

  4. Variance-Invariance-Covariance Regularization (VICReg)

Each model used the GNN encoder described above.

Relative Positioning

I’ll first give an overview of Relative Positioning (RP). Given two time windows in our EEG signal, we train our SSL model to distinguish when these two windows are close in time or far apart. In theory, if our SSL model is trained on this task, it should learn different representations of the data so that it can factor in these temporal differences, which is crucial for our application.

We’ll denote our windows of the signal as $x_t$ and $x_{t'}$, where $t$ and $t'$ represent the starting time indices. We then give a “pseudolabel” $y$ to the pair $(x_t, x_{t'})$, where $y = 1$ if the pair is close in time and $y = 0$ if the pair is far apart in time, with respect to some hyperparameters $\tau_{+}, \tau_{-}$:

$$y = \begin{cases} 1 &\text{if } |t - t'| \leq \tau_{+} \\ 0 &\text{if } |t - t'| > \tau_{-}\end{cases}$$
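The pseudolabelling rule above is easy to state in code. A small sketch (the function name is mine; pairs falling between the two thresholds get no label and are simply discarded from training):

```python
def rp_pseudolabel(t, t_prime, tau_plus, tau_minus):
    """Relative Positioning pseudolabel for a window pair.
    Returns 1 if the windows are close (|t - t'| <= tau+),
    0 if far apart (|t - t'| > tau-), and None otherwise."""
    gap = abs(t - t_prime)
    if gap <= tau_plus:
        return 1
    if gap > tau_minus:
        return 0
    return None  # ambiguous pair: not used for training
```

For example, with the thresholds used later in this post ($\tau_+ = 12$s, $\tau_- = 90$s), a pair of windows starting 10 s apart is labelled 1, one 100 s apart is labelled 0, and one 50 s apart is discarded.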

Here’s a diagram of the RP model:

The model first encodes $x_t$ and $x_{t'}$ using our GNN encoder, before applying a projection function (in this case, a 3-layer MLP), before contrasting with the function:

$$\text{Contr}(z_t, z_{t'}) = |z_t - z_{t'}| = c$$
where the absolute value is taken entrywise. Then $c$ is passed through plain old logistic regression to give our prediction $\hat{y}$, for which we compute our error with the Binary Cross Entropy (BCE) loss:

$$\mathcal{L}(y, \hat{y}) =- \Big[y \log \hat{y} + (1-y) \log(1 - \hat{y})\Big]$$
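Putting the contrastive head together, here is a small NumPy sketch of the step from the two encodings to the loss (function names and the explicit sigmoid weights are illustrative assumptions, not the lab's implementation):

```python
import numpy as np

def rp_head(z_t, z_tp, w, b):
    """Contrast two encodings and score the pair.
    c = |z_t - z_t'| (entrywise), then logistic regression:
    y_hat = sigmoid(w . c + b)."""
    c = np.abs(z_t - z_tp)
    return 1.0 / (1.0 + np.exp(-(w @ c + b)))

def bce_loss(y, y_hat, eps=1e-12):
    """Binary cross-entropy between pseudolabel y and prediction y_hat."""
    y_hat = np.clip(y_hat, eps, 1 - eps)  # guard against log(0)
    return -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
```

Note that identical encodings give $c = \mathbf{0}$, so the prediction collapses to $\sigma(b)$: the head can only separate pairs through entrywise differences in their encodings.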

Temporal Shuffling

The Temporal Shuffling (TS) method is very similar to RP. Instead of labelling graph pairs, we label graph triplets. Here we have three windows $x_t, x_{t'}, x_{t''}$, and our pretext task is to correctly predict whether the triplet is ordered in time or shuffled (i.e., unordered). Given hyperparameters $\tau_+$ and $\tau_-$, we require that the first and last windows of every triplet are close enough in time, that is, $|t - t''| \leq \tau_+$, and in our experiments we rule out redundant permutations of the triplet by requiring $t < t''$. The midpoint between these two time indices is given by:

$$M = t + \frac{t''-t}{2}$$

Now dependent on the value $t'$ (the middle window), we generate the pseudolabel:

$$y(x_t, x_{t'}, x_{t''}) = \begin{cases} 1 &\text{if }t'\in (t,t'') \text{ or } t' \in (t'', t)\\0 &\text{if } |t'-M| > \tau_-\end{cases}$$

Hence, $y = 1$ if $t < t' < t''$, meaning that the triplet is temporally ordered, and $y = 0$ if $t'$ is sufficiently far from the midpoint between $t$ and $t''$, i.e., it is temporally shuffled. Here’s an overview of the model:
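As with RP, the TS pseudolabel is a short piece of code. A sketch under the conventions above ($t < t''$ and $|t - t''| \leq \tau_+$ already enforced when sampling; the function name is mine):

```python
def ts_pseudolabel(t, t_prime, t_dprime, tau_plus, tau_minus):
    """Temporal Shuffling pseudolabel for a window triplet.
    Assumes t < t'' and |t - t''| <= tau+ (enforced at sampling time).
    Returns 1 if the middle window lies between the endpoints (ordered),
    0 if it is far from the midpoint M (shuffled), None otherwise."""
    assert t < t_dprime and abs(t - t_dprime) <= tau_plus
    M = t + (t_dprime - t) / 2  # midpoint of the endpoint windows
    if t < t_prime < t_dprime:
        return 1
    if abs(t_prime - M) > tau_minus:
        return 0
    return None  # neither ordered nor clearly shuffled: discard
```

Triplets that are neither ordered nor far enough from the midpoint get no pseudolabel and are dropped, mirroring the discarded middle band of pairs in RP.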

Implementation

We mainly used the PyTorch Geometric (PyG) library for our implementation. For more details, please see the GitHub repository! We have several notebooks written for understanding PyG and our specific models.

Preliminary Results

We trained both the RP and TS models with $\tau_+ = 12\text{s}$ and $\tau_- = 90\text{s}$ on 115,000 data points each, with a 70%/20%/10% split for training, validation, and testing. The GNN encoder was then transferred to the supervised pipeline and retrained, to compare against a baseline supervised (non-pretrained) model using a single ECC and GAT layer. We obtain the accuracies below, averaged over all 26 patients. So far we don’t see a major increase from using the pretrained encoder in RP or TS. This could be due to several reasons; here are a few hypotheses:

  • The hyperparameters $\tau_+$ and $\tau_-$ are not properly configured for our task.

  • The task of detecting whether a seizure occurs is not complicated, as indicated by the 95.78% accuracy on the base model, and therefore the model may not benefit from much more information given in the pretraining.

  • The variance of the preictal, ictal, and nonictal windows may “confuse” temporal-based SSL methods, as these states are hypothesized to contain different signal characteristics, and simplifying this into two classes may actually harm the performance of our SSL methods by their construction.

  • The encoder or projector layers are not optimally configured, e.g., we could try different graph layers besides ECC or GAT, or we could try different layer sizes of the projector MLP.

Seizure Detection (Binary)

| Model | Test Accuracy |
| --- | --- |
| Supervised | 95.78% |
| Supervised + RP | 95.65% |
| Supervised + TS | 96.13% |

I will update this with more information as the project progresses, with more methods, and on different tasks (e.g., multiclass classification of preictal, ictal, and postictal). Predicting before (preictal) or after (postictal) a seizure occurs is a much more interesting question, and a much more challenging task, given our previous results in the supervised pipeline. We’ll see how well our self-supervised learning methods fare on this task compared to the binary classification task, and I’ll likely be tweaking the hyperparameters $\tau_+$ and $\tau_-$ too. Hopefully, I’ll have some updates on the results of CPC and VICReg as well (and a new SSL model)! Until next time.
