Introduction to Canonical Correlation Analysis (CCA) in Python

CCA Plot: Scatter plot first pair of canonical covariate

Increasingly, we have multiple high-dimensional datasets from the same samples. Canonical Correlation Analysis, aka CCA, is great for scenarios where you have two high-dimensional datasets from the same samples, and it lets you learn from both datasets simultaneously. A classic example is audio and video datasets from the same individuals. One can also think of CCA as another dimensionality reduction technique like Principal Component Analysis (PCA). Unlike PCA, you deal with two datasets in CCA instead of one.

If you are interested in a bit of history, CCA was originally developed by the same Hotelling who developed PCA in the 1930s.

What is Canonical Correlation Analysis?

In this tutorial, we will see examples of how to perform CCA using the Palmer Penguins dataset. We will use scikit-learn to perform Canonical Correlation Analysis (CCA). We will not go into the math behind CCA; instead, we will see a practical example of doing CCA and understand the main intuition behind the results.

As we explained in the previous post on performing CCA with R, one can understand the idea behind CCA as follows:

Let us say there are one or more variables generating two high-dimensional datasets X and Y. Here, the datasets X and Y are observable, and we don’t know about the latent variable(s) behind them. Since both datasets are generated by the same latent variable(s), there will be a lot of shared or common variation in the two datasets. By doing CCA, we can identify the common variation: the canonical variates that are highly correlated to the unknown latent variable.

Typically, the two datasets might have other variation in addition to the variation due to the main latent variable. CCA helps us remove the dataset-specific variation or noise in the two datasets and get to the canonical variable that captures the hidden variable.

Palmer Penguin Dataset for Canonical Correlation Analysis

To get started, let us load Pandas, Matplotlib, numpy, and Seaborn.

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

Let us load the Palmer Penguins dataset and do a bit of cleanup by removing any rows with missing values.

link2data = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv"
df = pd.read_csv(link2data)
df = df.dropna()
df.head()

Since Canonical Correlation Analysis involves two high-dimensional datasets, we will split the Penguins data into two datasets X and Y. The dataset X has two variables corresponding to bill length and depth.

X = df[['bill_length_mm','bill_depth_mm']]
X.head()

bill_length_mm	bill_depth_mm
0	39.1	18.7
1	39.5	17.4
2	40.3	18.0
4	36.7	19.3
5	39.3	20.6

We also need to standardize the variables by subtracting the mean and dividing by the standard deviation.

X_mc = (X-X.mean())/(X.std())
X_mc.head()

	bill_length_mm	bill_depth_mm
0	-0.894695	0.779559
1	-0.821552	0.119404
2	-0.675264	0.424091
4	-1.333559	1.084246
5	-0.858123	1.744400
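Since we standardize both datasets the same way, it can be handy to wrap the step in a small helper. The sketch below is one possible way to do it; the function name `standardize` and the column names of the toy dataframe are hypothetical, and the toy values are just the first five rows of X shown above. Note that Pandas’ std() uses the sample standard deviation (ddof=1) by default.

```python
import numpy as np
import pandas as pd

def standardize(frame: pd.DataFrame) -> pd.DataFrame:
    """Center each column at 0 and scale it to unit sample standard deviation."""
    # DataFrame.std() uses ddof=1 (sample standard deviation) by default
    return (frame - frame.mean()) / frame.std()

# First five rows of X from above, used only to check the helper
toy = pd.DataFrame({"length_mm": [39.1, 39.5, 40.3, 36.7, 39.3],
                    "depth_mm": [18.7, 17.4, 18.0, 19.3, 20.6]})
toy_std = standardize(toy)

# After standardizing, every column should have mean ~0 and std ~1
print(toy_std.mean().round(6))
print(toy_std.std().round(6))
```

The standardized toy values differ from the table above because the mean and standard deviation here come from five rows only, not from the full dataset.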

The second dataset Y contains flipper length and body weight.

Y = df[['flipper_length_mm','body_mass_g']]
Y.head()

Let us standardize the Y data set as well.

Y_mc = (Y-Y.mean())/(Y.std())
Y_mc.head()

	flipper_length_mm	body_mass_g
0	-1.424608	-0.567621
1	-1.067867	-0.505525
2	-0.425733	-1.188572
4	-0.568429	-0.940192
5	-0.782474	-0.691811

Canonical Correlation Analysis with Scikit-learn in Python

Now we have two datasets from the same penguins. As we know, the features of a penguin differ strongly from one species to another. In our toy example, species is a latent variable that is common behind the two datasets X and Y.

To perform CCA in Python, we will use the CCA class from sklearn.cross_decomposition.

from sklearn.cross_decomposition import CCA

First, we instantiate a CCA object and use its fit() and transform() methods with the two standardized matrices to perform CCA.

ca = CCA()
ca.fit(X_mc, Y_mc)
X_c, Y_c = ca.transform(X_mc, Y_mc)

And our result is two matrices of canonical covariates, one for each dataset.

print(X_c.shape)
print(Y_c.shape)

(333, 2)
(333, 2)

Understanding the results from Canonical Correlation Analysis

Let us dig deeper to understand the results from canonical correlation analysis and get an intuition. First, we will take the pairs of canonical covariates and see how they are correlated with each other.

Let us combine the results from CCA with the metadata from the Penguins data into a Pandas dataframe.

cc_res = pd.DataFrame({"CCX_1":X_c[:, 0],
                       "CCY_1":Y_c[:, 0],
                       "CCX_2":X_c[:, 1],
                       "CCY_2":Y_c[:, 1],
                       "Species":df.species.tolist(),
                      "Island":df.island.tolist(),
                      "sex":df.sex.tolist()})
cc_res.head()

	CCX_1	CCY_1	CCX_2	CCY_2	Species	Island	sex
0	-1.186252	-1.408795	-0.010367	0.682866	Adelie	Torgersen	MALE
1	-0.709573	-1.053857	-0.456036	0.429879	Adelie	Torgersen	FEMALE
2	-0.790732	-0.393550	-0.130809	-0.839620	Adelie	Torgersen	FEMALE
3	-1.718663	-0.542888	-0.073623	-0.458571	Adelie	Torgersen	FEMALE
4	-1.772295	-0.763548	0.736248	-0.014204	Adelie	Torgersen	MALE

Let us check how correlated the first pair of canonical covariates are. We use NumPy’s corrcoef() function to compute the correlation, and we can see that the first pair of canonical covariates is highly correlated (about 0.79).

import numpy as np
np.corrcoef(X_c[:, 0], Y_c[:, 0])

array([[1.        , 0.78763151],
       [0.78763151, 1.        ]])

We can also compute the correlation between the second pair of covariates, and we see that it is much lower (about 0.09).

np.corrcoef(X_c[:, 1], Y_c[:, 1])

array([[1.        , 0.08638695],
       [0.08638695, 1.        ]])
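Instead of calling corrcoef() once per pair, we could collect the correlations of all pairs in one go. The helper below is a small sketch; the name `canonical_correlations` is hypothetical and the score matrices are tiny made-up arrays, used only to show the call.

```python
import numpy as np

def canonical_correlations(scores_x, scores_y):
    """Return the Pearson correlation of each matched pair of score columns."""
    scores_x = np.asarray(scores_x)
    scores_y = np.asarray(scores_y)
    return np.array([
        np.corrcoef(scores_x[:, k], scores_y[:, k])[0, 1]
        for k in range(scores_x.shape[1])
    ])

# Tiny made-up score matrices, for illustration only
demo_x = np.array([[1.0, 0.2], [2.0, -0.1], [3.0, 0.4], [4.0, -0.3]])
demo_y = np.array([[1.1, 0.5], [1.9, 0.1], [3.2, -0.2], [3.9, 0.3]])

demo_corrs = canonical_correlations(demo_x, demo_y)
print(demo_corrs)
```

With the results from this post, the call would be canonical_correlations(X_c, Y_c), which should reproduce the two correlations computed above.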

To further understand the relationship between the pairs of canonical covariates, let us make a scatter plot with the first pair of canonical covariates.

sns.set_context("talk", font_scale=1.2)
plt.figure(figsize=(10,8))
sns.scatterplot(x="CCX_1",
                y="CCY_1", 
                data=cc_res)
plt.title('Comp. 1, corr = %.2f' %
         np.corrcoef(X_c[:, 0], Y_c[:, 0])[0, 1])

We can see the high correlation between the first pair of canonical covariates.

Scatter plot of First Pair of Canonical Covariates

Interpreting Canonical Covariates with Heatmap

In this toy example, we know already that the species variable in the dataset is a latent variable or hidden variable. Let us see how the latent variable is correlated with the first pair of canonical covariates. We first make a boxplot between the latent variable and each of the first pair of canonical covariates.

plt.figure(figsize=(10,8))
sns.boxplot(x="Species",
                y="CCX_1", 
               data=cc_res)
sns.stripplot(x="Species",
                y="CCX_1", 
                 data=cc_res)

Boxplot of Canonical Correlate from X and Latent Variable

plt.figure(figsize=(10,8))
sns.boxplot(x="Species",
                y="CCY_1", 
                 data=cc_res)
sns.stripplot(x="Species",
                y="CCY_1", 
                 data=cc_res)

Boxplot of Canonical Correlate from Y and Latent Variable

We can also color the scatter plot of the first pair of canonical covariates by the species variable to see how the canonical covariates have captured the underlying variable behind our datasets.

plt.figure(figsize=(10,8))
sns.scatterplot(x="CCX_1",
                y="CCY_1", 
                hue="Species", data=cc_res)
plt.title('First Pair of Canonical Covariate, corr = %.2f' %
         np.corrcoef(X_c[:, 0], Y_c[:, 0])[0, 1])

Scatter plot of First Pair of Canonical Covariates colored by Latent Variable

From the correlations between the first pair of canonical covariates and species variable in the data set, we can see that our Canonical Correlation Analysis has captured the shared variation between the two datasets. In this example the shared variation or the latent variable behind the first pair of canonical covariates is the species variable.

By carefully looking at the association between the canonical covariates and a variable, here the latent variable, we can better understand the results from our CCA. Let us take a step further and make a heatmap of correlations between the canonical covariates from each dataset and our input dataset, including the metadata associated with it.

Let us create a dataframe with the original data and the canonical covariates from the first dataset X. To compute correlations, we convert the character variables into categorical variables and then into integer codes (0/1/2).

ccX_df = pd.DataFrame({"CCX_1":X_c[:, 0],
                       "CCX_2":X_c[:, 1],
                       "Species":df.species.astype('category').cat.codes,
                      "Island":df.island.astype('category').cat.codes,
                      "sex":df.sex.astype('category').cat.codes,
                       "bill_length":X_mc.bill_length_mm,
                      "bill_depth":X_mc.bill_depth_mm})
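To see what the conversion to codes does, here is a tiny example on a made-up Series of species names. Pandas sorts the categories, so for strings the codes follow alphabetical order. Keep in mind that this ordering is arbitrary for a variable like species or island, so a Pearson correlation with the codes is best read as a rough guide, not as a precise measure of association.

```python
import pandas as pd

# Made-up species labels, for illustration only
species = pd.Series(["Adelie", "Gentoo", "Chinstrap", "Adelie"])
as_cat = species.astype("category")

# Categories are sorted, so the codes follow alphabetical order here
mapping = dict(enumerate(as_cat.cat.categories))
codes = as_cat.cat.codes.tolist()

print(mapping)
print(codes)
```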


With Pandas’ corr() function we can compute correlation of all variables in the dataframe.

corr_X_df= ccX_df.corr(method='pearson') 
corr_X_df.head()

Let us make a heatmap with the lower triangular correlation matrix. To do that, we subset the original correlation matrix with NumPy’s tril() function.

plt.figure(figsize=(10,8))
# use the built-in bool: the np.bool alias was removed in NumPy 1.24
X_df_lt = corr_X_df.where(np.tril(np.ones(corr_X_df.shape)).astype(bool))
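If the masking step looks a bit cryptic, the following sketch shows it on a small made-up 3 x 3 correlation matrix. np.tril() keeps the lower triangle (including the diagonal) of a matrix of ones, and where() keeps the values at the True positions and puts NaN everywhere else, which Seaborn’s heatmap then leaves blank.

```python
import numpy as np
import pandas as pd

# Small made-up correlation matrix, for illustration only
corr = pd.DataFrame([[1.0, 0.3, -0.2],
                     [0.3, 1.0, 0.5],
                     [-0.2, 0.5, 1.0]],
                    index=list("abc"), columns=list("abc"))

# True on and below the diagonal, False above it
mask = np.tril(np.ones(corr.shape)).astype(bool)

# Values above the diagonal become NaN
lower = corr.where(mask)
print(lower)
```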

We can make the lower triangular correlation heatmap using Seaborn’s heatmap function.

sns.heatmap(X_df_lt,cmap="coolwarm",annot=True,fmt='.1g')
plt.tight_layout()
plt.savefig("Heatmap_Canonical_Correlates_from_X_and_data.jpg",
                    format='jpeg',
                    dpi=100)

The heatmap showing correlations of the canonical covariates from the dataset X reveals a lot of interesting details. We can see that, as expected, there is no correlation between the first and second canonical covariates from the dataset X. Note that the first canonical covariate is highly correlated with the two variables in the dataset X, positively with bill length and negatively with bill depth.

As we saw before, the first canonical covariate is highly correlated with the Species variable, the latent or hidden variable in this example. The first canonical covariate from X is also correlated with the other hidden variable, island, associated with the dataset. We can also see that the first canonical covariate is not correlated with the sex variable. However, the second canonical covariate is moderately correlated with the sex variable.

Correlation Heatmap of Canonical Correlates from Data X and dataset

Let us make a similar lower triangular correlation heatmap with the canonical covariates from dataset Y and the dataset Y itself, including the metadata associated with it.

# canonical covariates from Y together with the dataset Y and metadata
ccY_df = pd.DataFrame({"CCY_1":Y_c[:, 0],
                       "CCY_2":Y_c[:, 1],
                       "Species":df.species.astype('category').cat.codes,
                      "Island":df.island.astype('category').cat.codes,
                      "sex":df.sex.astype('category').cat.codes,
                       "flipper_length":Y_mc.flipper_length_mm,
                      "body_mass":Y_mc.body_mass_g})

# compute correlation with Pandas corr()
corr_Y_df= ccY_df.corr(method='pearson') 

# Get lower triangular correlation matrix
Y_df_lt = corr_Y_df.where(np.tril(np.ones(corr_Y_df.shape)).astype(bool))

# make a lower triangular correlation heatmap with Seaborn
plt.figure(figsize=(10,8))
sns.heatmap(Y_df_lt,cmap="coolwarm",annot=True,fmt='.1g')
plt.tight_layout()
plt.savefig("Heatmap_Canonical_Correlates_from_Y_and_data.jpg",
                    format='jpeg',
                    dpi=100)

We see a pattern similar to the one in the correlation map of canonical covariates from dataset X. One of the highlights is that, as we saw before, the first canonical covariate from dataset Y is highly correlated with the Species variable. And the second canonical covariate from dataset Y is correlated with the sex variable, suggesting that we can capture the effect of two distinct hidden variables using canonical correlation analysis.

Correlation Heatmap of Canonical Correlates from Data Y and dataset

Second pair of Canonical covariates Reveal a Second Latent Variable

The two heatmaps showing the correlations of canonical covariates and datasets show that sex is another variable that has an effect on both the datasets and CCA could capture that. To see the effect of sex, we can make a scatter plot with the second pair of canonical covariates and color it by sex variable.

plt.figure(figsize=(10,8))
sns.scatterplot(x="CCX_2",
                y="CCY_2", 
                hue="sex", data=cc_res)
plt.title('Second Pair of Canonical Covariate, corr = %.2f' %
         np.corrcoef(X_c[:, 1], Y_c[:, 1])[0, 1])

Scatter plot of Second pair of Canonical Covariates

Summary of CCA Example in Python

Summarising what we saw so far, Canonical Correlation Analysis is a great tool for understanding two high-dimensional datasets. Using the Penguins data as a toy example, the post showed how to perform CCA with scikit-learn in Python. We also saw how to interpret and understand the pairs of canonical covariates obtained by looking at the two datasets at the same time.

In this post we did not get to the math or algorithm behind CCA; it will be fun to look at that in a future post. Also, the Penguins dataset was perfect to illustrate CCA, and it will be a lot more fun to apply CCA to more complex and realistic high-dimensional datasets. Definitely great ideas for another few posts.