Install all required dependencies
%pip install git+https://github.com/BonnerLab/ccn-tutorial.git
Generalizing cross-validated PCA to pairs of systems using cross-decomposition
August 26, 2023
March 10, 2024
Here’s a link to this notebook on Google Colab.
Just as PCA identifies the principal directions of variance of a system, cross-decomposition identifies the principal directions of shared variance between two systems X and Y. Specifically, just as PCA computes the eigendecomposition of the auto-covariance, cross-decomposition computes the singular value decomposition of the cross-covariance:
\begin{align*} \text{cov}(X, Y) &= X^\top Y / (n - 1)\\ &= U \Sigma V^\top \end{align*}
Here, U and V contain the left and right singular vectors of the cross-covariance, \Sigma is the diagonal matrix of singular values, and n is the number of stimuli.
The cross-decomposition method we describe here is more specifically known as Partial Least Squares Singular Value Decomposition (PLS-SVD). We simplify it to “cross-decomposition” since we will be developing a cross-validated version of the typical estimators.
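Before turning to real data, here is a minimal NumPy sketch of the idea (all names, shapes, and noise levels are illustrative): two systems observe the same stimuli through different linear readouts, and the SVD of their cross-covariance recovers the handful of shared directions.

import numpy as np

rng_example = np.random.default_rng(seed=0)

n_stimuli = 500
shared_signal = rng_example.standard_normal((n_stimuli, 3))  # 3 latent dimensions shared by both systems

# Each system reads out the shared signal through its own random linear map, plus private noise.
x = shared_signal @ rng_example.standard_normal((3, 50)) + 0.1 * rng_example.standard_normal((n_stimuli, 50))
y = shared_signal @ rng_example.standard_normal((3, 80)) + 0.1 * rng_example.standard_normal((n_stimuli, 80))

# Cross-covariance between the centered systems, followed by its SVD.
x_centered = x - x.mean(axis=0)
y_centered = y - y.mean(axis=0)
cross_covariance = x_centered.T @ y_centered / (n_stimuli - 1)
u, singular_values, v_transpose = np.linalg.svd(cross_covariance, full_matrices=False)

print(singular_values[:5])  # roughly 3 large values, then a sharp drop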
import warnings
import numpy as np
import pandas as pd
import xarray as xr
import torch
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
from IPython.display import display
from utilities.brain import (
load_dataset,
average_data_across_repetitions,
load_stimuli,
)
from utilities.computation import svd, assign_logarithmic_bins
%matplotlib inline
sns.set_theme(
context="notebook",
style="white",
palette="deep",
rc={"legend.edgecolor": "None"},
)
set_matplotlib_formats("svg")
pd.set_option("display.max_rows", 5)
pd.set_option("display.max_columns", 10)
pd.set_option("display.precision", 3)
pd.set_option("display.show_dimensions", False)
xr.set_options(display_max_rows=3, display_expand_data=False)
warnings.filterwarnings("ignore")

random_state = 0  # assumed seed; referenced by the SVD and the train/test split below
rng = np.random.default_rng(seed=random_state)
class PLSSVD:
    def __init__(self) -> None:
        self.left_mean: np.ndarray
        self.right_mean: np.ndarray
        self.left_singular_vectors: np.ndarray
        self.singular_values: np.ndarray
        self.right_singular_vectors: np.ndarray

    def fit(self, /, x: np.ndarray, y: np.ndarray) -> None:
        # Center each system by subtracting its mean response across stimuli.
        self.left_mean = x.mean(axis=-2)
        self.right_mean = y.mean(axis=-2)
        x_centered = x - self.left_mean
        y_centered = y - self.right_mean

        n_stimuli = x.shape[-2]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Cross-covariance between the two systems.
        cross_covariance = (np.swapaxes(x_centered, -1, -2) @ y_centered) / (
            n_stimuli - 1
        )

        # Singular value decomposition of the cross-covariance.
        (
            self.left_singular_vectors,
            self.singular_values,
            self.right_singular_vectors,
        ) = svd(
            torch.from_numpy(cross_covariance).to(device),
            n_components=min([*x.shape, *y.shape]),
            truncated=True,
            seed=random_state,
        )
        self.left_singular_vectors = self.left_singular_vectors.cpu().numpy()
        self.singular_values = self.singular_values.cpu().numpy()
        self.right_singular_vectors = self.right_singular_vectors.cpu().numpy()

    def transform(self, /, z: np.ndarray, *, direction: str) -> np.ndarray:
        # Project new data onto the left or right singular vectors.
        match direction:
            case "left":
                return (z - self.left_mean) @ self.left_singular_vectors
            case "right":
                return (z - self.right_mean) @ self.right_singular_vectors
            case _:
                raise ValueError("direction must be 'left' or 'right'")
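As a quick sanity check of this class, we can fit it on two synthetic, noisy views of a shared signal and confirm that the paired scores covary along the leading components. This is only a usage sketch: the shapes, noise level, and variable names are arbitrary.

rng_demo = np.random.default_rng(seed=random_state)

shared_signal = rng_demo.standard_normal((200, 5))
x_demo = (
    shared_signal @ rng_demo.standard_normal((5, 40))
    + 0.1 * rng_demo.standard_normal((200, 40))
).astype(np.float32)
y_demo = (
    shared_signal @ rng_demo.standard_normal((5, 60))
    + 0.1 * rng_demo.standard_normal((200, 60))
).astype(np.float32)

pls_svd = PLSSVD()
pls_svd.fit(x_demo, y_demo)
x_scores = pls_svd.transform(x_demo, direction="left")
y_scores = pls_svd.transform(y_demo, direction="right")

# The leading pair of scores should be strongly correlated, reflecting the shared signal.
print(np.corrcoef(x_scores[:, 0], y_scores[:, 0])[0, 1])
# Only about 5 singular values should be appreciably large.
print(pls_svd.singular_values[:8])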
Just as we can use cross-validated PCA to estimate the variance shared across presentations of the same stimuli within a participant, we can use cross-decomposition to estimate the variance shared between the neural representations of the same stimuli across participants.
We have two data matrices X \in \mathbb{R}^{N \times P_X} and Y \in \mathbb{R}^{N \times P_Y} from two participants, containing neural responses to the same N stimuli. Note that the number of neurons or voxels in the two subjects (P_X and P_Y) can be different – we don’t need to assume any sort of anatomical alignment between brains.
In cross-validated PCA, we measured stimulus-specific variance through cross-trial generalization. Here, even without repeated presentations of the same stimuli, we can use an analogous cross-validation approach: instead of testing generalization across repeated presentations, we evaluate how reliably the shared variance between the two systems generalizes across stimuli.
Specifically, we can divide the images into two splits: a training split and a test split. We compute the singular vectors on the training split and evaluate the cross-validated singular values on the test split:
\begin{align*} \text{cov}(X_\text{train}, Y_\text{train}) &= X_\text{train}^\top Y_\text{train} / (n - 1)\\ &= U \Sigma V^\top \end{align*}
\begin{align*} \Sigma_\text{cross-validated} &= \text{cov}(X_\text{test} U, Y_\text{test} V)\\ &= \left( X_\text{test} U \right) ^\top \left( Y_\text{test} V \right) / (n - 1) \end{align*}
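The analysis below requires the responses of two participants to the same 700 stimuli, averaged across repetitions. A loading step consistent with the outputs shown next might look like this sketch; the subject indices and the exact helper call are assumptions based on the imports above and the dataset attributes.

subject_1 = average_data_across_repetitions(load_dataset(subject=0, roi="general"))  # first participant (assumed index)
subject_2 = average_data_across_repetitions(load_dataset(subject=1, roi="general"))  # second participant (assumed index)

display(subject_1)
display(subject_2)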
<xarray.DataArray 'fMRI betas' (presentation: 700, neuroid: 15724)> Size: 44MB
0.4915 0.2473 0.08592 0.05828 -0.1315 ... -0.2126 -0.6315 -0.5751 -0.5354
Coordinates: (3/4)
    x            (neuroid) uint8 16kB 12 12 12 12 12 12 12 ... 72 72 72 72 72 72
    y            (neuroid) uint8 16kB 21 22 22 22 22 22 23 ... 29 29 30 30 30 31
    ...           ...
    stimulus_id  (presentation) object 6kB 'image02950' ... 'image72948'
Dimensions without coordinates: presentation, neuroid
Attributes: (3/8)
    resolution:      1pt8mm
    preprocessing:   fithrf_GLMdenoise_RR
    ...              ...
    postprocessing:  averaged across first two repetitions
<xarray.DataArray 'fMRI betas' (presentation: 700, neuroid: 14278)> Size: 40MB
-0.8554 0.0399 0.09591 -0.4694 -0.4573 ... -1.052 -0.6467 -0.6164 -0.8053
Coordinates: (3/4)
    x            (neuroid) uint8 14kB 11 11 11 11 12 12 12 ... 71 71 71 71 71 71
    y            (neuroid) uint8 14kB 23 25 25 26 22 22 22 ... 28 29 29 30 30 30
    ...           ...
    stimulus_id  (presentation) object 6kB 'image02950' ... 'image72948'
Dimensions without coordinates: presentation, neuroid
Attributes: (3/8)
    resolution:      1pt8mm
    preprocessing:   fithrf_GLMdenoise_RR
    ...              ...
    postprocessing:  averaged across first two repetitions
def compute_cross_participant_spectrum(
    x: xr.DataArray,
    y: xr.DataArray,
    /,
    train_fraction: float = 7 / 8,
) -> np.ndarray:
    # Randomly assign a fraction of the stimuli to the training split.
    stimuli = x["stimulus_id"].values
    n_train = int(train_fraction * len(stimuli))
    train_stimuli = rng.permutation(stimuli)[:n_train]
    train_indices = np.isin(x["stimulus_id"].values, train_stimuli)

    x_train = x.isel({"presentation": train_indices})
    y_train = y.isel({"presentation": train_indices})
    x_test = x.isel({"presentation": ~train_indices})
    y_test = y.isel({"presentation": ~train_indices})

    # Fit the singular vectors on the training stimuli...
    scorer = PLSSVD()
    scorer.fit(x_train.values, y_train.values)

    # ...and evaluate the cross-validated singular values on the held-out stimuli.
    x_test_transformed = scorer.transform(x_test.values, direction="left")
    y_test_transformed = scorer.transform(y_test.values, direction="right")

    n_components = x_test_transformed.shape[-1]
    return np.diag(
        np.cov(
            x_test_transformed,
            y_test_transformed,
            rowvar=False,
        )[:n_components, n_components:]
    )
cross_participant_spectrum = compute_cross_participant_spectrum(subject_1, subject_2)
data = pd.DataFrame(
{
"cross-validated singular value": cross_participant_spectrum,
"rank": assign_logarithmic_bins(
1 + np.arange(len(cross_participant_spectrum)),
min_=1,
max_=10_000,
points_per_bin=5,
),
}
)
fig, ax = plt.subplots()
sns.lineplot(
ax=ax,
data=data.assign(arbitrary=0),
x="rank",
y="cross-validated singular value",
marker="o",
dashes=False,
hue="arbitrary",
palette=["#cf6016"],
ls="None",
err_style="bars",
estimator="mean",
errorbar="sd",
legend=False,
)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_aspect("equal", "box")
sns.despine(ax=ax, offset=20)
def compute_within_participant_spectrum(data: xr.DataArray) -> np.ndarray:
    # Both training views are the first repetition, so the cross-decomposition
    # reduces to a PCA of that repetition.
    x_train = data.isel({"presentation": data["rep_id"] == 0}).sortby("stimulus_id")
    y_train = data.isel({"presentation": data["rep_id"] == 0}).sortby("stimulus_id")

    # Generalization is evaluated across repetitions: repetition 0 on the left,
    # repetition 1 on the right.
    x_test = data.isel({"presentation": data["rep_id"] == 0}).sortby("stimulus_id")
    y_test = data.isel({"presentation": data["rep_id"] == 1}).sortby("stimulus_id")

    scorer = PLSSVD()
    scorer.fit(x_train.values, y_train.values)

    x_test_transformed = scorer.transform(x_test.values, direction="left")
    y_test_transformed = scorer.transform(y_test.values, direction="right")

    n_components = x_test_transformed.shape[-1]
    return np.diag(
        np.cov(
            x_test_transformed,
            y_test_transformed,
            rowvar=False,
        )[:n_components, n_components:]
    )
within_participant_spectrum = compute_within_participant_spectrum(
load_dataset(subject=0, roi="general")
)
data = pd.concat(
[
data.assign(comparison="cross-individual"),
pd.DataFrame(
{
"cross-validated singular value": within_participant_spectrum,
"rank": assign_logarithmic_bins(
1 + np.arange(len(within_participant_spectrum)),
min_=1,
max_=10_000,
points_per_bin=5,
),
}
).assign(comparison="within-individual"),
],
axis=0,
)
with sns.axes_style("whitegrid"):
fig, ax = plt.subplots()
sns.lineplot(
ax=ax,
data=data,
x="rank",
y="cross-validated singular value",
palette=["#514587", "#cf6016"],
hue="comparison",
hue_order=["within-individual", "cross-individual"],
style="comparison",
style_order=["within-individual", "cross-individual"],
markers=["o", "s"],
dashes=False,
ls="None",
err_style="bars",
estimator="mean",
errorbar="sd",
)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_aspect("equal", "box")
ax.grid(True, which="minor", c="whitesmoke")
ax.grid(True, which="major", c="lightgray")
for loc in ("left", "bottom", "top", "right"):
ax.spines[loc].set_visible(False)
The cross-validated cross-decomposition approach described here allows for many different levels of generalization: across trials, across stimuli, and across individuals.
In fact, we can combine several of these criteria to impose very strict tests of generalization: for example, we could estimate the spectrum of variance that generalizes simultaneously across trials, stimuli, and individuals, as sketched below.
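Such an analysis might look like the following sketch, assuming x and y are the un-averaged datasets of two different participants (e.g., load_dataset(subject=0, roi="general") and the corresponding dataset for a second participant), each containing rep_id and stimulus_id coordinates for the same stimuli with two repetitions. The function name and internal helper are hypothetical, not part of the tutorial utilities.

def compute_fully_generalizing_spectrum(
    x: xr.DataArray,
    y: xr.DataArray,
    /,
    train_fraction: float = 7 / 8,
) -> np.ndarray:
    # Generalization across stimuli: hold out a random subset of the stimuli.
    stimuli = rng.permutation(np.unique(x["stimulus_id"].values))
    train_stimuli = stimuli[: int(train_fraction * len(stimuli))]

    def select(data: xr.DataArray, *, rep: int, train: bool) -> xr.DataArray:
        in_train = np.isin(data["stimulus_id"].values, train_stimuli)
        mask = (data["rep_id"].values == rep) & (in_train if train else ~in_train)
        return data.isel({"presentation": mask}).sortby("stimulus_id")

    # x and y come from different participants (generalization across individuals);
    # fit only on repetition 0 of the training stimuli.
    scorer = PLSSVD()
    scorer.fit(
        select(x, rep=0, train=True).values,
        select(y, rep=0, train=True).values,
    )

    # Evaluate on repetition 1 of the held-out stimuli, so the resulting spectrum
    # reflects variance that generalizes across trials, stimuli, and individuals.
    x_test = scorer.transform(select(x, rep=1, train=False).values, direction="left")
    y_test = scorer.transform(select(y, rep=1, train=False).values, direction="right")

    n_components = x_test.shape[-1]
    return np.diag(
        np.cov(x_test, y_test, rowvar=False)[:n_components, n_components:]
    )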