Paper Breakdown: Channel Clustering for Time Series Forecasting
Is everything independent?
The word “independence” has been thrown around a lot recently in the time series literature. One of the most popular forecasting models, PatchTST, not only popularized patching for time series (similar to ViTs) but also challenged the notion of channel dependence (CD) by introducing channel independence (CI). In the CD framework, we assume that separate channels (or variables) are related to one another and thus should be modeled together. Within the CI framework, our model is applied independently to each channel, so we don’t explicitly force it to form relations between channels. This year at ICLR 2024, one paper even released a state-of-the-art model that is patch independent! So the question is, what’s the best framework? CD? CI? Patch independence? The preprint titled “From Similarity to Superiority: Channel Clustering for Time Series Forecasting” attempts to blend the best of both worlds (CD and CI) by introducing a Channel Clustering Module (CCM), which clusters similar channels together and then adaptively weights their output according to their cluster identity. Below, we’ll break down the math and briefly go over the main experiments near the end.
Model Breakdown
The authors separate their architecture into two components: (1) the temporal module, which is the traditional backbone you would normally use for forecasting; and (2) the cluster assigner, which assigns each channel to a cluster. The computations required for (1) and (2) can be done in parallel, after which we take the output of the backbone in (1) and weight it adaptively according to the cluster identity from (2), in what they call the cluster-aware feedforward. What they refer to as “normalization” in the diagram above is an instance normalization module such as RevIN, which normalizes each individual time series window (not the entire channel).
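To make the "per window, not per channel" point concrete, here is a minimal sketch of instance normalization in NumPy. It is a simplification of RevIN: the real module also has learnable affine parameters, which I leave out, and the function names are my own.

```python
import numpy as np

def instance_normalize(window, eps=1e-5):
    """Normalize one input window of shape (T, C) using its own statistics."""
    # Statistics come from this window only, not from the full training series.
    mean = window.mean(axis=0, keepdims=True)                 # (1, C)
    std = np.sqrt(window.var(axis=0, keepdims=True) + eps)    # (1, C)
    return (window - mean) / std, (mean, std)

def instance_denormalize(forecast, stats):
    """Map a forecast of shape (F, C) back to the original scale."""
    mean, std = stats
    return forecast * std + mean

rng = np.random.default_rng(0)
# Two channels with very different offsets and scales.
window = rng.normal(loc=[10.0, -3.0], scale=[5.0, 0.5], size=(96, 2))
normed, stats = instance_normalize(window)
restored = instance_denormalize(normed, stats)
```

The backbone only ever sees `normed`, and the stored statistics are used to put the forecast back on the original scale.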
Channel Similarity
Let $\boldsymbol{x} \in \mathbb{R}^{T \times C}$ be a multivariate time series with sequence length $T$ and $C$ channels, and denote $\boldsymbol{x}_i \in \mathbb{R}^{T}$ as the $i^{th}$ channel for each $i \in \{1, \dots, C\}$. First, we have to define a channel similarity function, $\text{sim}: \mathbb{R}^{T} \times \mathbb{R}^{T} \to \mathbb{R}$. The authors select the Gaussian kernel:
$$\text{sim}(\boldsymbol{x}_i, \boldsymbol{x}_j) = \exp\bigg(\frac{-\|\boldsymbol{x}_i - \boldsymbol{x}_j\|_2^2}{2\sigma^2}\bigg)$$
with hyperparameter $\sigma > 0$ controlling the kernel width. Note that we compute the similarities on the standardized time series to avoid the effects of anomalies and scale.
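As a quick illustration, the similarity matrix can be sketched in a few lines of NumPy. I am assuming here that "standardized" means a per-channel z-score over time; the function name and the toy data are mine, not the authors'.

```python
import numpy as np

def channel_similarity(x, sigma=1.0):
    """Gaussian-kernel similarity between the channels of x, shape (T, C)."""
    # Assumption: standardize each channel to zero mean and unit variance.
    x = (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-8)
    # Squared Euclidean distance between every pair of channels.
    diff = x.T[:, None, :] - x.T[None, :, :]     # (C, C, T)
    sq_dist = (diff ** 2).sum(axis=-1)           # (C, C)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))

rng = np.random.default_rng(0)
t = np.linspace(0, 4 * np.pi, 96)
# Channel 1 is a rescaled, shifted copy of channel 0; channel 2 is noise.
x = np.stack([np.sin(t), 5 * np.sin(t) + 3, rng.normal(size=96)], axis=1)
S = channel_similarity(x, sigma=5.0)
```

After standardization the first two channels are nearly identical, so their similarity is close to 1, while the noise channel scores much lower against both. This is exactly why the scale has to be removed before applying the kernel.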
Channel Clustering
For each $\boldsymbol{x}_i \in \mathbb{R}^T$, we embed it with an MLP to obtain $\boldsymbol{h}_i \in \mathbb{R}^{d}$. We then initialize a set of cluster embeddings $\{\boldsymbol{c}_1, \dots, \boldsymbol{c}_K\} \subseteq \mathbb{R}^d$, for a hyperparameter $K$, representing the number of clusters available. We’ll get to how we update these cluster embeddings later on. For each channel $i$ and cluster $k$, we model the (cluster) probability of $\boldsymbol{x}_i$ belonging to cluster $k$ as:
$$p_{i, k} = \text{Norm}\bigg( \frac{\boldsymbol{c}_k \cdot \boldsymbol{h}_i}{\|\boldsymbol{c}_k\|_2\|\boldsymbol{h}_i\|_2} \bigg)$$
The normalization function $\text{Norm}(\cdot)$ enforces the constraint $\sum_{k} p_{i,k} = 1$ for each channel $i$. Although the authors don’t specify exactly what this function is, my guess is that they apply a $\text{softmax}$ to the values $\tilde{p}_{i,1}, \dots, \tilde{p}_{i,K}$, where $\tilde{p}_{i,k} = \frac{\boldsymbol{c}_k \cdot \boldsymbol{h}_i}{\|\boldsymbol{c}_k\|_2\|\boldsymbol{h}_i\|_2}$ is the unnormalized value (the cosine similarity between $\boldsymbol{c}_k$ and $\boldsymbol{h}_i$).
Once every cluster probability is computed, we form the cluster membership matrix $\mathbf{M} \in \mathbb{R}^{C \times K}$, where each entry $\mathbf{M}_{i,k} \approx \text{Bernoulli}(p_{i,k})$ is obtained through the reparameterization trick with the Gumbel-Softmax. This is analogous to sampling from a Bernoulli distribution with probability $p_{i,k}$, with one crucial difference: the resulting values of $\mathbf{M}_{i,k}$ are not strictly $0$ or $1$. Instead, they are continuous values very close to $0$ or $1$, and this near-binary relaxation keeps the operation differentiable for gradient descent.
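Here is a rough sketch of the cluster assigner under two assumptions of mine: that $\text{Norm}$ is a softmax over the cosine similarities, and that the Bernoulli relaxation is the binary (two-class) form of the Gumbel-Softmax with a temperature `tau`. The function names and the temperature value are hypothetical, and the authors' implementation may differ.

```python
import numpy as np

def cluster_probabilities(H, C):
    """p[i, k] for channel embeddings H (C_ch, d) and cluster embeddings C (K, d)."""
    Hn = H / np.linalg.norm(H, axis=1, keepdims=True)
    Cn = C / np.linalg.norm(C, axis=1, keepdims=True)
    cos = Hn @ Cn.T                                    # cosine similarity, (C_ch, K)
    e = np.exp(cos - cos.max(axis=1, keepdims=True))   # assumed softmax as Norm
    return e / e.sum(axis=1, keepdims=True)

def relaxed_bernoulli(p, tau=0.1, rng=None):
    """Differentiable, near-binary sample for each Bernoulli(p[i, k])."""
    rng = np.random.default_rng() if rng is None else rng
    eps = 1e-10
    u = rng.uniform(eps, 1.0 - eps, size=p.shape)
    # The difference of two Gumbel variables is logistic noise.
    noise = np.log(u) - np.log1p(-u)
    logits = np.log(p + eps) - np.log1p(-p + eps)
    # A low temperature pushes the sigmoid output toward 0 or 1.
    return 1.0 / (1.0 + np.exp(-(logits + noise) / tau))

rng = np.random.default_rng(0)
H = rng.normal(size=(5, 8))   # 5 channels, d = 8
C = rng.normal(size=(3, 8))   # K = 3 clusters
P = cluster_probabilities(H, C)
M = relaxed_bernoulli(P, tau=0.1, rng=rng)
```

Each row of `P` sums to one, while `M` holds values in $[0, 1]$ that get closer to binary as `tau` shrinks. In a real training loop this would live in an autodiff framework so gradients can flow through `M`.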
Cluster-Aware Feedforward
Instead of using a typical CI or CD linear head, we assign a separate linear layer $h_{\theta_k}: \mathbb{R}^{m} \to \mathbb{R}^{F}$ for each cluster $k$, where $F$ is our forecast horizon. Denote $\boldsymbol{z}_i \in \mathbb{R}^{m}$ as the output of our “temporal module” (i.e., the backbone) for each channel input $\boldsymbol{x}_i$. The predicted forecast for each channel $i$ is then given by:
$$\hat{\boldsymbol{y}}_i = \sum_{k = 1}^K p_{i,k} h_{\theta_k}(\boldsymbol{z}_i)$$
Therefore, if the cluster probability $p_{i,k}$ is low, the layer $h_{\theta_k}$ has little effect on the output, and if $p_{i,k}$ is high, $h_{\theta_k}$ contributes more. What I like about this is that a cluster’s contribution does not have to be strictly “high” or “low”: the forecast can still blend several clusters, so the expressivity of the model is not constrained even though we are in some sense “compressing” the channel identity into a few clusters.
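The weighted sum over cluster heads is easy to sketch with `einsum`. The shapes and names below are my own choices for illustration; each head is a plain affine map with weight `W[k]` and bias `b[k]`.

```python
import numpy as np

def cluster_aware_forecast(Z, P, W, b):
    """Blend K linear heads per channel.

    Z: (C, m) backbone outputs, P: (C, K) cluster probabilities,
    W: (K, m, F) head weights, b: (K, F) head biases.
    """
    # Apply every cluster head to every channel: (C, K, F).
    per_cluster = np.einsum("cm,kmf->ckf", Z, W) + b[None, :, :]
    # Weight the K candidate forecasts of each channel by p[i, k].
    return np.einsum("ck,ckf->cf", P, per_cluster)

rng = np.random.default_rng(0)
Z = rng.normal(size=(4, 16))          # 4 channels, m = 16
W = rng.normal(size=(3, 16, 24))      # K = 3 heads, F = 24
b = rng.normal(size=(3, 24))
P = rng.dirichlet(np.ones(3), size=4)  # rows sum to one
Y = cluster_aware_forecast(Z, P, W, b)
```

Because every head is linear, blending the outputs is the same as first blending the head parameters with the weights $p_{i,k}$ and then applying a single per-channel layer, which is one way to read CCM as sitting between a fully shared (CI-style) head and a fully separate head per channel.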
Prototype Learning
This component of the network allows us to refine our cluster embeddings $\mathbf{C}$, so that they are more informative in the next iteration. We will place the cluster embeddings in the matrix $\mathbf{C} = [\boldsymbol{c}_1, \dots, \boldsymbol{c}_K]^T \in \mathbb{R}^{K \times d}$, and the channel embeddings in $\mathbf{H} = [\boldsymbol{h}_1, \dots, \boldsymbol{h}_C]^T \in \mathbb{R}^{C \times d}$. Define,
$$\begin{aligned} \mathbf{Q} &= \mathbf{W}_Q \mathbf{C}, \\ \mathbf{K} &= \mathbf{W}_K\mathbf{H}, \\ \mathbf{V} &= \mathbf{W}_V\mathbf{H}, \end{aligned}$$
as the query, key, and value matrices respectively, for learnable weight matrices $\mathbf{W}_Q, \mathbf{W}_V, \mathbf{W}_K$. We then apply the modified attention:
$$\hat{\mathbf{C}} = \text{Norm}\bigg( \exp\bigg(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}}\bigg) \odot \mathbf{M}^T\bigg)\mathbf{V}$$
where the elementwise product with $\mathbf{M}^T$ acts as a sparse attention mechanism, keeping only the attention coefficients of channels that have a high probability of belonging to each cluster. The prototype embedding $\hat{\mathbf{C}}$ then serves as the updated cluster embedding matrix $\mathbf{C}$.
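Below is a small sketch of this masked cross-attention. Two things are assumptions on my part: I apply the weight matrices on the right (`C @ W_q`) so that the shapes work with row-stacked embeddings, and I take $\text{Norm}$ to be a row normalization so that each cluster's attention weights sum to one. The function name is hypothetical.

```python
import numpy as np

def update_prototypes(C, H, M, W_q, W_k, W_v):
    """C: (K, d) clusters, H: (C_ch, d) channels, M: (C_ch, K) memberships."""
    d = C.shape[1]
    Q = C @ W_q                        # queries come from the clusters, (K, d)
    Kmat = H @ W_k                     # keys come from the channels, (C_ch, d)
    V = H @ W_v                        # values come from the channels, (C_ch, d)
    # Unnormalized attention, masked by (soft) cluster membership.
    scores = np.exp(Q @ Kmat.T / np.sqrt(d)) * M.T      # (K, C_ch)
    # Assumed Norm: each cluster's weights over channels sum to one.
    attn = scores / (scores.sum(axis=1, keepdims=True) + 1e-10)
    return attn @ V                    # updated prototypes, (K, d)

rng = np.random.default_rng(0)
d = 4
H = rng.normal(size=(3, d))
C = rng.normal(size=(2, d))
# Hard memberships for illustration: channels 0 and 1 in cluster 0, channel 2 in cluster 1.
M = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
I = np.eye(d)
C_hat = update_prototypes(C, H, M, I, I, I)
```

With identity projections and hard memberships, the second prototype collapses onto the embedding of the only channel assigned to it, and the first becomes a convex combination of the two channels in its cluster. That matches the intuition of a prototype as a summary of its members.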
Cluster Loss
Let $\mathbf{S} \in \mathbb{R}^{C \times C}$ be the channel similarity matrix with $\mathbf{S}_{i,j} = \text{sim}(\boldsymbol{x}_i, \boldsymbol{x}_j)$. We then define the cluster loss as:
$$\mathcal{L}_{C} = -\text{Tr}(\mathbf{M}^T \mathbf{S}\mathbf{M}) + \text{Tr}\big((\mathbf{I} - \mathbf{M}\mathbf{M}^T)\mathbf{S}\big)$$
This loss optimizes two objectives:
Channel Similarity Consistency: $\text{Tr}(\mathbf{M}^T \mathbf{S}\mathbf{M})$ maximizes the channel similarities within clusters, refining the embeddings $\boldsymbol{h}_i$ and $\boldsymbol{c}_k$.
Cluster Distancing: $\text{Tr}\big((\mathbf{I} - \mathbf{M}\mathbf{M}^T)\mathbf{S}\big)$ encourages clusters to be distant from each other so that our clusters do not overlap and are therefore more informative.
The total loss is given by:
$$\mathcal{L} = \mathcal{L}_F + \beta \mathcal{L}_C$$
where $\mathcal{L}_F$ is the traditional forecasting loss and $\beta \in \mathbb{R}^+$ is a hyperparameter.
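The cluster loss is a direct transcription of the formula. The toy similarity matrix and the two hard membership matrices below are made up for illustration; in practice $\mathbf{M}$ would be the relaxed, near-binary matrix from the cluster assigner.

```python
import numpy as np

def cluster_loss(M, S):
    """L_C = -Tr(M^T S M) + Tr((I - M M^T) S) for M: (C, K), S: (C, C)."""
    within = np.trace(M.T @ S @ M)
    between = np.trace((np.eye(M.shape[0]) - M @ M.T) @ S)
    return -within + between

# Toy similarities: channels {0, 1} are alike, and so are channels {2, 3}.
S = np.array([
    [1.0, 0.9, 0.1, 0.1],
    [0.9, 1.0, 0.1, 0.1],
    [0.1, 0.1, 1.0, 0.9],
    [0.1, 0.1, 0.9, 1.0],
])
M_good = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])  # groups similar channels
M_bad = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])   # mixes them up

loss_good = cluster_loss(M_good, S)
loss_bad = cluster_loss(M_bad, S)

# Total objective with a placeholder forecasting loss and a hypothetical beta.
beta = 0.3
forecast_loss = 0.5
total_loss = forecast_loss + beta * loss_good
```

The assignment that groups similar channels gets the lower loss. One small observation: by the cyclic property of the trace, $\text{Tr}(\mathbf{M}\mathbf{M}^T\mathbf{S}) = \text{Tr}(\mathbf{M}^T\mathbf{S}\mathbf{M})$, so the loss can also be written as $\text{Tr}(\mathbf{S}) - 2\,\text{Tr}(\mathbf{M}^T\mathbf{S}\mathbf{M})$, and since $\mathbf{S}$ does not depend on the model, both terms push in the same direction.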
Experiments
Long-Term Forecasting
The authors evaluate their method on top of several popular time series backbones, including TSMixer, DLinear, PatchTST, and TimesNet, for long-term forecasting. Their selected benchmarks are relatively standard for time series forecasting in deep learning, including the ETT, Exchange, Illness, Weather, Electricity, and Traffic datasets. I can’t say I’m the biggest fan of this benchmarking suite, but you will find it in many time series forecasting papers. For the vast majority of experiments, it looks like CCM does indeed improve performance across a number of different forecast horizons, as measured by Mean Squared Error (MSE) and Mean Absolute Error (MAE) on the test set.
Zero-Shot Forecasting
The authors also evaluate their method on zero-shot forecasting, in which the backbone and CCM are pretrained on one dataset (e.g., ETTh1) and then applied at inference to another dataset (e.g., ETTh2). The authors show that adding the frozen prototype embedding $\mathbf{\hat{C}}$ through CCM does improve zero-shot performance, as seen in the table below, and therefore these clusters may generalize to other contexts (perhaps a step in the right direction towards foundation models for time series). My only criticism is that I would’ve liked to see more out-of-distribution datasets, since ETTh1 and ETTm1, for example, are still heavily correlated due to them being sampled from the same context, just at different sampling rates. It might also be interesting in the future to see some fine-tuning experiments (frozen CCM but fine-tuning or replacing the backbone) for transfer learning. Overall, I think it’s a cool contribution from the paper, and I can see this opening many avenues within the zero-shot and foundation model areas of time series research.
I’m looking forward to the published paper and code release, and am interested to know what others think of this method. I think CCM is a nice solution to the CI vs. CD problem, and it doesn’t seem limited to just forecasting. I’ll definitely be trying this out in the near future, likely on different time series tasks. Thank you to Jialin Chen, the first author of the paper, for reading through the article and providing feedback; go check out her other work if you liked this method! If you found this blog helpful, I would very much appreciate it if you shared it with others. Thanks!
-
Original paper: https://arxiv.org/abs/2404.01340
Reversible Instance Normalization (RevIN): https://openreview.net/forum?id=cGDAkQo1C0p
PatchTST: https://arxiv.org/abs/2211.14730
DLinear: https://arxiv.org/abs/2205.13504
TSMixer: https://arxiv.org/abs/2303.06053
TimesNet: https://arxiv.org/abs/2210.02186
Gumbel-Softmax Reparameterization Trick: https://arxiv.org/abs/1611.01144