susiepca's Introduction

SuSiE-PCA

SuSiE PCA is a scalable Bayesian variable selection technique for sparse principal component analysis.

SuSiE PCA stands for the Sum of Single Effects model [1] for principal component analysis. We developed SuSiE PCA for efficient variable selection in PCA on high-dimensional, sparse data, and for quantifying the uncertainty of the features contributing to each latent component through posterior inclusion probabilities (PIPs). The model is implemented with the JAX library developed by Google, which enables fast training on CPU, GPU, or TPU.

If you enjoy or use our software, please consider citing its publication in iScience (2023):

Yuan, D. and Mancuso, N., 2023. SuSiE PCA: A scalable Bayesian variable selection technique for principal component analysis. iScience, 26(11). DOI: https://doi.org/10.1016/j.isci.2023.108181

Documentation | Installation | Example | Notes | References | Support

Model Description

We extend the Sum of Single Effects model (SuSiE) [1] to principal component analysis. Let $X_{N \times P}$ be the observed data, $Z_{N \times K}$ the latent factors, and $W_{K \times P}$ the factor loading matrix. The SuSiE PCA model is then given by:

$$X | Z,W \sim \mathcal{MN}_{N,P}(ZW, I_N, \sigma^2 I_P)$$

where $\mathcal{MN}_{N,P}$ is the matrix normal distribution with dimension $N \times P$, mean $ZW$, row-covariance $I_N$, and column-covariance $\sigma^2 I_P$. The column vectors of $Z$ follow a standard normal distribution. This setting is the same as Probabilistic PCA [2]. The distinguishing feature is that we integrate the SuSiE setting into each row vector $\mathbf{w}_k$ of the factor loading matrix $W$, so that each $\mathbf{w}_k$ contains at most $L$ non-zero effects. That is, $$\mathbf{w}_k = \sum_{l=1}^L \mathbf{w}_{kl} $$ $$\mathbf{w}_{kl} = w_{kl} \gamma_{kl}$$ $$w_{kl} \sim \mathcal{N}(0,\sigma^2_{0kl})$$ $$\gamma_{kl} | \pi \sim \text{Multi}(1,\pi) $$
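To make the generative story concrete, here is a minimal NumPy sketch that draws one data set from the model above. It is an illustration only, not the simulation code shipped with susiepca: the function name `sample_susie_pca` is hypothetical, $\pi$ is assumed uniform over features, and a single prior variance is shared across all $k$ and $l$.

```python
import numpy as np

def sample_susie_pca(n, p, k, l, sigma2=1.0, sigma2_0=1.0, seed=0):
    """Draw (Z, W, X) from the generative model (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    pi = np.full(p, 1.0 / p)  # assumption: uniform prior over feature positions
    W = np.zeros((k, p))
    for kk in range(k):
        for _ in range(l):
            idx = rng.choice(p, p=pi)  # gamma_kl ~ Multi(1, pi), kept as an index
            # w_kl ~ N(0, sigma2_0); single effects that share an index add up
            W[kk, idx] += rng.normal(0.0, np.sqrt(sigma2_0))
    Z = rng.standard_normal((n, k))  # latent factors
    noise = rng.normal(0.0, np.sqrt(sigma2), size=(n, p))
    X = Z @ W + noise  # X | Z, W ~ MN(ZW, I_N, sigma2 * I_P)
    return Z, W, X

Z, W, X = sample_susie_pca(n=150, p=200, k=4, l=40)
print(X.shape, (W != 0).sum(axis=1))
```

Because two single effects may land on the same feature, a row of `W` can have fewer than `l` non-zero entries, which is why the model statement says "at most".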

Notice that each row vector $\mathbf{w}_k$ is a sum of single-effect vectors $\mathbf{w}_{kl}$, each a length-$P$ vector that contains one non-zero effect $w_{kl}$ and zeros elsewhere. The coordinate of the non-zero effect is determined by $\gamma_{kl}$, which follows a multinomial distribution with parameter $\pi$. By construction, each factor inferred by SuSiE PCA has at most $L$ associated features from the original data. Moreover, we can quantify the strength of each association through posterior inclusion probabilities (PIPs). Suppose the posterior distribution of $\gamma_{kl}$ is $\text{Multi}(1,\boldsymbol{\alpha}_{kl})$; then the probability that feature $i$ contributes to the factor $\mathbf{w}_k$ is given by: $$\text{PIP}_{ki} = 1-\prod_{l=1}^L (1 - \alpha_{kli})$$ where $\alpha_{kli}$ is the $i$-th entry of $\boldsymbol{\alpha}_{kl}$.
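As a small worked example of the PIP formula, the sketch below uses made-up posterior weights for a single factor with $L = 2$ and $P = 4$. The array `alpha_k` is hypothetical and its layout (effects along rows) is chosen for readability; it is not meant to mirror how susiepca stores its parameters.

```python
import numpy as np

# Hypothetical posterior for one factor k: L = 2 single effects, P = 4 features.
# Each row is alpha_kl and sums to 1.
alpha_k = np.array([
    [0.90, 0.05, 0.03, 0.02],
    [0.10, 0.80, 0.05, 0.05],
])

# PIP_ki = 1 - prod_l (1 - alpha_kli): the probability that at least one
# of the L single effects selects feature i.
pip_k = 1.0 - np.prod(1.0 - alpha_k, axis=0)
print(pip_k)
```

For the first feature this gives $1 - (0.10)(0.90) = 0.91$: the feature is excluded only if both single effects miss it.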

Install SuSiE PCA

The source code for SuSiE PCA is written entirely in Python 3.8 with JAX (see the JAX installation guide). To get started quickly, clone this GitHub repository and install SuSiE PCA with the commands below. (PyPI installation will be supported soon.)

git clone https://github.com/mancusolab/susiepca.git
cd susiepca
pip install -e .

Get Started with Example

  1. Create a Python environment in the cloned repository, then import SuSiE PCA:
import susiepca as sp
  2. Generate a simulated data set as described in the Simulation section of our paper. $Z_{N \times K}$ is the simulated factor matrix, $W_{K \times P}$ is the simulated loading matrix, and $X_{N \times P}$ is the simulated data set with $N$ observations and $P$ features.
Z, W, X = sp.sim.generate_sim(seed=0, l_dim=40, n_dim=150, p_dim=200, z_dim=4, effect_size=1)
  3. Pass the simulated data set to SuSiE PCA with the number of components $K=4$ and the number of single effects per component $L=40$, or vary these two parameters to examine performance under model misspecification. By default the data are neither centered nor scaled, and the maximum number of iterations is 200. Here we use the principal components extracted by traditional PCA to initialize the mean of $Z$.
results = sp.infer.susie_pca(X, z_dim=4, l_dim=40, max_iter=200)

The returned "results" contain 5 different objects:

  • params: a dictionary that stores all the updated parameters from SuSiE PCA.
  • elbo_res: the value of evidence lower bound (ELBO) from the last iteration.
  • pve: a length-$K$ ndarray containing the percent of variance explained (PVE) by each component.
  • pip: a $K$ by $P$ ndarray containing the posterior inclusion probabilities (PIPs) of each feature for each factor.
  • W: the posterior mean of loading matrix which is also a $K$ by $P$ ndarray.
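One common way to use a $K$ by $P$ PIP matrix is to list, for each component, the features whose PIP exceeds a cutoff. The sketch below is illustrative: `top_features` is a hypothetical helper, the 0.9 threshold is an arbitrary choice, and the small `pip` array stands in for `results.pip`.

```python
import numpy as np

def top_features(pip, threshold=0.9):
    """Return, per component, the feature indices with PIP >= threshold."""
    pip = np.asarray(pip)
    return [np.flatnonzero(row >= threshold) for row in pip]

# Stand-in for results.pip with K = 2 components and P = 5 features.
pip = np.array([
    [0.99, 0.02, 0.95, 0.10, 0.01],
    [0.03, 0.97, 0.05, 0.04, 0.92],
])

for k, idx in enumerate(top_features(pip)):
    print(f"component {k}: features {idx.tolist()}")
```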
  4. To examine the model performance, one straightforward way is to draw and compare heatmaps of the true loading matrix and the estimated loading matrix using seaborn:
import seaborn as sns

# specify the palette for the heatmaps
div = sns.diverging_palette(250, 10, as_cmap=True)

# Heatmap of true loading matrix
sns.heatmap(W, cmap = div, fmt = ".2f",center = 0)

# Heatmap of estimated loading matrix
W_hat = results.W
sns.heatmap(W_hat, cmap = div, fmt = ".2f", center = 0)

# Heatmap of PIPs
pip = results.pip
sns.heatmap(pip, cmap = div, fmt = ".2f", center = 0)

To compute the Procrustes error of the estimated loading matrix, you need to install the procrustes package to solve the rotation problem (see the procrustes installation guide). Once the loading matrix is rotated back to its original orientation, you can compute the Procrustes error and inspect the heatmap as follows:

import procrustes
import numpy as np

# perform procrustes transformation
proc_trans_susie = procrustes.orthogonal(np.asarray(W_hat.T), np.asarray(W.T), scale=True)
print(f"The Procrustes error for the loading matrix is {proc_trans_susie.error}")

# Heatmap of transformed loading matrix
W_trans = proc_trans_susie.t.T @ W_hat
sns.heatmap(W_trans, cmap = div, fmt = ".2f", center = 0)
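For readers who want to see what the rotation step solves, the following is a minimal NumPy sketch of the classical orthogonal Procrustes problem: find an orthogonal $T$ minimizing $\lVert AT - B \rVert_F$. It is not the procrustes package's implementation; the helper name `orthogonal_procrustes` is hypothetical, and the package's scaling option and its exact definition of `error` are not reproduced here.

```python
import numpy as np

def orthogonal_procrustes(A, B):
    """Return the orthogonal T minimizing ||A @ T - B||_F and the squared residual."""
    # Closed-form solution: with A.T @ B = U S V^T, the optimum is T = U V^T.
    U, _, Vt = np.linalg.svd(A.T @ B)
    T = U @ Vt
    error = np.sum((A @ T - B) ** 2)
    return T, error

rng = np.random.default_rng(0)
B = rng.standard_normal((200, 4))                  # "true" loadings, shape P x K
Q, _ = np.linalg.qr(rng.standard_normal((4, 4)))   # a random orthogonal matrix
A = B @ Q.T                                        # rotated copy, so A @ Q equals B

T, err = orthogonal_procrustes(A, B)
print(np.allclose(T, Q), err)
```

In this noise-free setup the recovered `T` matches the rotation `Q` that was applied, and the residual is zero up to floating-point error.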

You can also calculate the relative root mean square error (RRMSE) to assess the model's prediction performance:

from susiepca import metrics

# compute the predicted data
X_hat = results.params.mu_z @ W_hat

# compute the RRMSE
rrmse_susie = metrics.mse(X, X_hat)
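As a point of reference, one common definition of a relative root mean square error is $\sqrt{\lVert X - \hat{X} \rVert_F^2 / \lVert X \rVert_F^2}$. The sketch below implements that definition in NumPy. This is an assumption for illustration: `relative_rmse` is a hypothetical name, and the value returned by `metrics.mse` may be defined differently.

```python
import numpy as np

def relative_rmse(X, X_hat):
    """sqrt(||X - X_hat||_F^2 / ||X||_F^2); assumed definition, for illustration."""
    X = np.asarray(X, dtype=float)
    X_hat = np.asarray(X_hat, dtype=float)
    return np.sqrt(np.sum((X - X_hat) ** 2) / np.sum(X ** 2))

X = np.array([[1.0, 2.0], [3.0, 4.0]])
print(relative_rmse(X, X))                 # perfect reconstruction
print(relative_rmse(X, np.zeros_like(X)))  # predicting all zeros
```

Under this definition a perfect reconstruction scores 0 and an all-zero prediction scores 1, which makes values comparable across data sets of different scale.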
  5. Finally, we also provide a function to compute $\rho$-level credible sets (CS). The cs returned by the function is composed of $L \times K$ credible sets, each of which contains a subset of variables that cumulatively explain at least $\rho$ of the posterior probability.
cs = sp.metrics.get_credset(results.params.alpha, rho=0.9)
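To illustrate what a single $\rho$-level credible set is, the sketch below builds one from a single posterior weight vector: sort features by weight, then keep the smallest prefix whose cumulative weight reaches $\rho$. This is a sketch of the general idea, not the code behind `get_credset`; the helper `credible_set` and the example weights are hypothetical.

```python
import numpy as np

def credible_set(alpha, rho=0.9):
    """Smallest set of features whose posterior weights sum to at least rho."""
    alpha = np.asarray(alpha, dtype=float)
    order = np.argsort(-alpha)       # feature indices, largest weight first
    cum = np.cumsum(alpha[order])    # cumulative posterior weight
    # first position where the cumulative weight reaches rho, inclusive
    n_keep = int(np.searchsorted(cum, rho)) + 1
    return order[:n_keep]

# Hypothetical alpha_kl for one (k, l) pair over P = 4 features.
alpha = np.array([0.05, 0.50, 0.15, 0.30])
print(credible_set(alpha, rho=0.9))
```

With these weights the three largest entries (0.50, 0.30, 0.15) are needed to pass 0.9, so the set contains features 1, 3, and 2.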

Notes

JAX uses 32-bit precision by default. To enable 64-bit precision, add the following code before calling susiepca:

import jax
jax.config.update("jax_enable_x64", True)

Similarly, the default computation device for JAX is set by environment variables (see here). To change this programmatically, add the following code before calling susiepca:

import jax
platform = "gpu" # "gpu", "cpu", or "tpu"
jax.config.update("jax_platform_name", platform)

References

[1] Wang, G., Sarkar, A., Carbonetto, P. and Stephens, M. (2020), A simple new approach to variable selection in regression, with application to genetic fine mapping. J. R. Stat. Soc. B, 82: 1273-1300. https://doi.org/10.1111/rssb.12388
[2] Tipping, M.E. and Bishop, C.M. (1999), Probabilistic Principal Component Analysis. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 61: 611-622. https://doi.org/10.1111/1467-9868.00196

Support

Please report any bugs or feature requests in the Issue Tracker. If you have any questions or comments please contact [email protected] and/or [email protected].


This project has been set up using PyScaffold 4.1.1. For details and usage information on PyScaffold see https://pyscaffold.org/.

susiepca's Issues

CredSet computation using vmap & jit

We should explore the possibility of reimplementing get_credset using jax.vmap and jax.jit. It should be possible to define a lower-level, jitted function that operates over the P features to compute a credible set, and then use vmap to expand that over both the K and L dimensions. Since users likely call get_credset only once, it doesn't make sense to jit the entire thing, but this way the shape is fixed over P and the function should be faster in principle.

I/O helpers

We should add some I/O helpers to streamline saving output to disk in a friendly manner for users.

For example, outputting the credible set dictionary to disk as a DataFrame, and similar constructs for posterior parameters.

Numerically Stable PIP computation

If L is large, computing PIPs can be numerically unstable. Let's switch to something like,

pip = 1 - jnp.exp(jnp.sum(jnp.log1p(-alpha), axis=-1))
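The NumPy sketch below shows one way the precision loss can appear in 32-bit arithmetic and how working in log space avoids it. The array layout (effects along axis 0) and the values are made up for illustration; the right axis in susiepca depends on how alpha is stored. Replacing `1 - exp(...)` with `-expm1(...)` is an additional variation not mentioned above, added because `exp` of a tiny negative number also rounds to 1 in float32.

```python
import numpy as np

L, P = 50, 3
# Hypothetical slice of alpha: every effect puts a tiny weight on these features.
alpha = np.full((L, P), 1e-10, dtype=np.float32)

# Naive product: 1 - 1e-10 rounds to exactly 1 in float32, so the PIP collapses to 0.
naive = 1 - np.prod(1 - alpha, axis=0)

# Log-space version: log1p keeps the small terms, expm1 keeps the small result.
stable = -np.expm1(np.sum(np.log1p(-alpha), axis=0))

print(naive)   # float32 zeros
print(stable)  # close to L * 1e-10
```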

Move print statements to logging

As the title implies, we should move from explicit printing when verbose = True to using a logger. This is a bit more flexible, and less likely to corrupt or interact with piping in UNIX systems.
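A minimal sketch of what this could look like with the standard library's logging module is shown below. The helper `get_logger`, the logger name, and the message format are all hypothetical choices, not existing susiepca code. The default StreamHandler writes to stderr, so progress messages stay out of anything piped from stdout.

```python
import logging

def get_logger(name="susiepca", verbose=True):
    """Return a logger whose level follows the verbose flag (hypothetical helper)."""
    logger = logging.getLogger(name)
    if not logger.handlers:  # avoid stacking duplicate handlers on repeated calls
        handler = logging.StreamHandler()  # writes to stderr by default
        handler.setFormatter(
            logging.Formatter("[%(asctime)s - %(levelname)s] %(message)s")
        )
        logger.addHandler(handler)
    # verbose=True shows per-iteration messages; otherwise only warnings and errors
    logger.setLevel(logging.INFO if verbose else logging.WARNING)
    return logger

log = get_logger(verbose=True)
log.info("Iter %d: ELBO = %.3f", 1, -1234.567)
```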

Documentation Issues

Catch all for documentation issues

  • Fix expansion of jnp.ndarray to jax._src.basearray.Array in class definitions
  • Remove display of __getnewargs__ in class definitions
  • Hide _cls in class definitions
  • Fix missing return types for functions in expanded description

Support sparse format matrices

It could be useful to additionally provide support for a sparse-representation of X using BCOO types in JAX. We might be able to use sparsify to get sparsity-handling for free on internal routines, and then provide an interface like susie_pca_sparse, or perhaps add a flag to susie_pca.

Explore `jax.numpy.linalg.svd` as replacement for `sklearn.decomposition.PCA`

We currently make use of sklearn.decomposition.PCA as an optional initialization of the mean parameters for the latent factor space. This is a helpful approach that performs well in practice, but depending on the quite large sklearn package for a single function may be overkill. JAX has an SVD implementation under jax.numpy.linalg.svd that we could explore as a replacement which would drop the dependency requirement on sklearn.
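A rough sketch of the SVD-based alternative is given below, written with NumPy for portability; `jax.numpy.linalg.svd` accepts the same `full_matrices` argument. The helper name `svd_init` is hypothetical. The scores agree with PCA scores only up to the sign of each column, and the centering step is an assumption that mirrors what PCA does internally.

```python
import numpy as np

def svd_init(X, z_dim):
    """PCA-style scores for the first z_dim components via a thin SVD."""
    Xc = X - X.mean(axis=0, keepdims=True)           # center each feature, as PCA does
    U, S, _ = np.linalg.svd(Xc, full_matrices=False)
    return U[:, :z_dim] * S[:z_dim]                  # scores = U_k * S_k, shape (N, z_dim)

rng = np.random.default_rng(0)
X = rng.standard_normal((150, 200))
Z0 = svd_init(X, z_dim=4)
print(Z0.shape)
```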

Center/Rescale

Let's add a user flag to automatically center and scale/standardize the input data before performing inference.
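One possible shape for such a preprocessing step is sketched below in NumPy. The function name `standardize` and its `center`/`scale` flags are hypothetical and do not describe an existing susiepca option. Constant columns are left unscaled to avoid division by zero.

```python
import numpy as np

def standardize(X, center=True, scale=True):
    """Column-wise centering and scaling (hypothetical helper)."""
    X = np.asarray(X, dtype=float)
    if center:
        X = X - X.mean(axis=0, keepdims=True)
    if scale:
        sd = X.std(axis=0, keepdims=True)
        X = X / np.where(sd > 0, sd, 1.0)  # constant columns keep a divisor of 1
    return X

rng = np.random.default_rng(0)
X = rng.normal(5.0, 3.0, size=(100, 4))
X_std = standardize(X)
print(X_std.mean(axis=0).round(6), X_std.std(axis=0).round(6))
```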

Add unit tests

We need to begin adding unit tests for sanity checks. Should be straightforward given the testing framework is ready to go.

AVX install on M1

Installing JAX via pip on M1 chips still has issues. The link suggests using mini-forge to install JAX for M1 chips, so we should add a line or two to the README to indicate that.
