Example: Probabilistic PCA + MuE (FactorMuE)

View FactorMuE.py on GitHub

```python
# Copyright Contributors to the Pyro project.

"""
A probabilistic PCA model with a MuE observation, called a 'FactorMuE' model
[1]. This is a generative model of variable-length biological sequences (e.g.
proteins) which does not require preprocessing the data by building a
multiple sequence alignment. It can be used to infer a latent representation
of sequences and the principal components of sequence variation, while
accounting for alignment uncertainty.

An example dataset consisting of proteins similar to the human papillomavirus E6
protein, collected from a non-redundant sequence dataset using jackhmmer, can
be found at
https://github.com/debbiemarkslab/MuE/blob/master/models/examples/ve6_full.fasta

Example run:
python FactorMuE.py -f PATH/ve6_full.fasta --z-dim 2 -b 10 -M 174 -D 25
--indel-prior-bias 10. --anneal 5 -e 15 -lr 0.01 --z-prior Laplace
--jit --cuda
This should take about 8 minutes to run on a GPU. The latent space should show
multiple small clusters, and the perplexity should be around 4.

Reference:
[1] E. N. Weinstein, D. S. Marks (2021)
"A structured observation distribution for generative biological sequence
prediction and forecasting"
https://www.biorxiv.org/content/10.1101/2020.07.31.231381v2.full.pdf
"""

import argparse
import datetime
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

import pyro
from pyro.contrib.mue.dataloaders import BiosequenceDataset
from pyro.contrib.mue.models import FactorMuE
from pyro.optim import MultiStepLR

def generate_data(small_test, include_stop, device):
    """Build the miniature example dataset used for test runs.

    Three short template sequences over the alphabet "AB" are each repeated
    once (``small_test`` truthy) or ten times (otherwise), keeping the copies
    of each template contiguous and in template order.

    Args:
        small_test: If truthy, use one copy of each template instead of ten.
        include_stop: Forwarded to ``BiosequenceDataset`` as ``include_stop``.
        device: Forwarded to ``BiosequenceDataset`` as ``device``.

    Returns:
        The ``BiosequenceDataset`` constructed from the repeated sequences.
    """
    copies = 1 if small_test else 10
    templates = ("BABBA", "BAAB", "BABBB")
    # Outer loop over templates keeps all copies of one template together.
    seqs = [template for template in templates for _ in range(copies)]

    return BiosequenceDataset(
        seqs, "list", "AB", include_stop=include_stop, device=device
    )

def main(args):
    """Fit a FactorMuE model, evaluate it, and optionally plot and save results.

    Steps, in order: load (or generate) the dataset, optionally split it into
    train and held-out subsets, construct the model, fit it with
    ``model.fit_svi``, evaluate with ``model.evaluate``, embed the full
    dataset with ``model.embed``, then write plots and result files unless
    disabled.

    Args:
        args: Parsed command-line namespace. ``args.batch_size`` is clamped in
            place to at most ``dataset.data_size``. When results are saved,
            ``args.latent_seq_length`` and ``args.latent_alphabet`` are
            overwritten with the values reported by the model so that the
            written input file reflects what was actually used.

    Side effects:
        Seeds Pyro's RNG, prints train/test log probability and perplexity,
        creates matplotlib figures unless ``args.no_plots`` is set, and writes
        files into ``args.out_folder`` unless ``args.no_save`` is set.
    """
    # Load dataset. Data lives on the CPU when requested or when CUDA is off.
    keep_on_cpu = args.cpu_data or not args.cuda
    device = torch.device("cpu" if keep_on_cpu else "cuda")
    if args.test:
        dataset = generate_data(args.small, args.include_stop, device)
    else:
        dataset = BiosequenceDataset(
            args.file,
            "fasta",
            args.alphabet,
            include_stop=args.include_stop,
            device=device,
        )
    args.batch_size = min([dataset.data_size, args.batch_size])

    if args.split > 0.0:
        # Train test split.
        heldout_num = int(np.ceil(args.split * len(dataset)))
        train_num = len(dataset) - heldout_num
        # Specific data split seed, for comparability across models and
        # parameter initializations.
        pyro.set_rng_seed(args.rng_data_seed)
        indices = torch.randperm(len(dataset), device=device).tolist()
        # Slice the permutation directly: the first ``train_num`` indices are
        # the training set, the remainder is held out. This yields the same
        # subsets as accumulating the split lengths, without relying on the
        # private ``torch._utils._accumulate`` helper.
        dataset_train = torch.utils.data.Subset(dataset, indices[:train_num])
        dataset_test = torch.utils.data.Subset(dataset, indices[train_num:])
    else:
        dataset_train = dataset
        dataset_test = None

    # Training seed.
    pyro.set_rng_seed(args.rng_seed)

    # Construct model.
    model = FactorMuE(
        dataset.max_length,
        dataset.alphabet_length,
        args.z_dim,
        batch_size=args.batch_size,
        latent_seq_length=args.latent_seq_length,
        indel_factor_dependence=args.indel_factor,
        indel_prior_scale=args.indel_prior_scale,
        indel_prior_bias=args.indel_prior_bias,
        inverse_temp_prior=args.inverse_temp_prior,
        weights_prior_scale=args.weights_prior_scale,
        offset_prior_scale=args.offset_prior_scale,
        z_prior_distribution=args.z_prior,
        ARD_prior=args.ARD_prior,
        substitution_matrix=(not args.no_substitution_matrix),
        substitution_prior_scale=args.substitution_prior_scale,
        latent_alphabet_length=args.latent_alphabet,
        cuda=args.cuda,
        pin_memory=args.pin_mem,
    )

    # Infer with SVI.
    # The scheduler config must name the optimizer to wrap and supply the
    # milestone epochs; ``--milestones`` arrives as a JSON list string
    # (default "[]"), so it is decoded here.
    # NOTE(review): Adam is assumed as the wrapped optimizer, as in the
    # upstream Pyro example -- confirm.
    scheduler = MultiStepLR(
        {
            "optimizer": torch.optim.Adam,
            "optim_args": {"lr": args.learning_rate},
            "milestones": json.loads(args.milestones),
            "gamma": args.learning_gamma,
        }
    )
    n_epochs = args.n_epochs
    losses = model.fit_svi(
        dataset_train,
        n_epochs,
        args.anneal,
        args.batch_size,
        scheduler,
        args.jit,
    )

    # Evaluate.
    train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
        dataset_train, dataset_test, args.jit
    )
    print("train logp: {} perplex: {}".format(train_lp, train_perplex))
    print("test logp: {} perplex: {}".format(test_lp, test_perplex))

    # Get latent space embedding.
    z_locs, z_scales = model.embed(dataset)

    # Plot and save.
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    def out_path(name_template):
        """Return the output path for a file name template with a `{}` slot."""
        return os.path.join(args.out_folder, name_template.format(time_stamp))

    if not args.no_plots:
        # Training loss curve.
        plt.figure(figsize=(6, 6))
        plt.plot(losses)
        plt.xlabel("step")
        plt.ylabel("loss")
        if not args.no_save:
            plt.savefig(out_path("FactorMuE_plot.loss_{}.pdf"))

        # First two latent dimensions of the embedding.
        plt.figure(figsize=(6, 6))
        plt.scatter(z_locs[:, 0], z_locs[:, 1])
        plt.xlabel(r"$z_1$")
        plt.ylabel(r"$z_2$")
        if not args.no_save:
            plt.savefig(out_path("FactorMuE_plot.latent_{}.pdf"))

        if not args.indel_factor:
            # Plot indel parameters. See statearrangers.py for details on the
            # r and u parameters.
            plt.figure(figsize=(6, 6))
            insert = pyro.param("insert_q_mn").detach()
            # Normalize over the last axis (softmax in log space).
            insert_expect = torch.exp(insert - insert.logsumexp(-1, True))
            plt.plot(insert_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of insert")
            plt.legend([r"$r_0$", r"$r_1$", r"$r_2$"])
            if not args.no_save:
                plt.savefig(out_path("FactorMuE_plot.insert_prob_{}.pdf"))

            plt.figure(figsize=(6, 6))
            delete = pyro.param("delete_q_mn").detach()
            delete_expect = torch.exp(delete - delete.logsumexp(-1, True))
            plt.plot(delete_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of delete")
            plt.legend([r"$u_0$", r"$u_1$", r"$u_2$"])
            if not args.no_save:
                plt.savefig(out_path("FactorMuE_plot.delete_prob_{}.pdf"))

    if not args.no_save:
        # Learned parameters.
        pyro.get_param_store().save(out_path("FactorMuE_results.params_{}.out"))

        # Evaluation metrics as a one-row CSV.
        with open(out_path("FactorMuE_results.evaluation_{}.txt"), "w") as ow:
            ow.write("train_lp,test_lp,train_perplex,test_perplex\n")
            ow.write(
                "{},{},{},{}\n".format(train_lp, test_lp, train_perplex, test_perplex)
            )

        # Latent embedding locations and scales.
        np.savetxt(
            out_path("FactorMuE_results.embed_loc_{}.txt"),
            z_locs.cpu().numpy(),
        )
        np.savetxt(
            out_path("FactorMuE_results.embed_scale_{}.txt"),
            z_scales.cpu().numpy(),
        )

        # Record of the run configuration.
        with open(out_path("FactorMuE_results.input_{}.txt"), "w") as ow:
            ow.write("[args]\n")
            # Record the values the model actually used rather than the
            # (possibly None) command-line values.
            args.latent_seq_length = model.latent_seq_length
            args.latent_alphabet = model.latent_alphabet_length
            for name, value in vars(args).items():
                ow.write("{} = {}\n".format(name, value))
            ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
            ow.write("max_length = {}\n".format(dataset.max_length))

if __name__ == "__main__":
# Parse command line arguments.
parser = argparse.ArgumentParser(description="Factor MuE model.")
"--test",
action="store_true",
default=False,
help="Run with generated example dataset.",
)
"--small",
action="store_true",
default=False,
help="Run with small example dataset.",
)
"-f", "--file", default=None, type=str, help="Input file (fasta format)."
)
"-a",
"--alphabet",
default="amino-acid",
help="Alphabet (amino-acid OR dna OR ATGC ...).",
)
"-zdim", "--z-dim", default=2, type=int, help="z space dimension."
)
parser.add_argument("-b", "--batch-size", default=10, type=int, help="Batch size.")
"-M",
"--latent-seq-length",
default=None,
type=int,
help="Latent sequence length.",
)
"-idfac",
"--indel-factor",
default=False,
action="store_true",
help="Indel parameters depend on latent variable.",
)
"-zdist",
"--z-prior",
default="Normal",
help="Latent prior distribution (normal or Laplace).",
)
"-ard",
"--ARD-prior",
default=False,
action="store_true",
help="Use automatic relevance detection prior.",
)
"--no-substitution-matrix",
default=False,
action="store_true",
help="Do not use substitution matrix.",
)
"-D",
"--latent-alphabet",
default=None,
type=int,
help="Latent alphabet length.",
)
"--include-stop",
default=False,
action="store_true",
help="Include stop symbol at the end of each sequence.",
)
"--indel-prior-scale",
default=1.0,
type=float,
help=("Indel prior scale parameter " + "(when indel-factor=False)."),
)
"--indel-prior-bias",
default=10.0,
type=float,
help="Indel prior bias parameter.",
)
"--inverse-temp-prior",
default=100.0,
type=float,
help="Inverse temperature prior mean.",
)
"--weights-prior-scale",
default=1.0,
type=float,
help="Factor parameter prior scale.",
)
"--offset-prior-scale",
default=1.0,
type=float,
help="Offset parameter prior scale.",
)
"--substitution-prior-scale",
default=10.0,
type=float,
help="Substitution matrix prior scale.",
)
"-lr",
"--learning-rate",
default=0.001,
type=float,
)
"--milestones",
default="[]",
type=str,
help="Milestones for multistage learning rate.",
)
"--learning-gamma",
default=0.5,
type=float,
help="Gamma parameter for multistage learning rate.",
)
"-e", "--n-epochs", default=10, type=int, help="Number of epochs of training."
)
"--anneal",
default=0.0,
type=float,
help="Number of epochs to anneal beta over.",
)
"--no-plots", default=False, action="store_true", help="Make plots."
)
"--no-save",
default=False,
action="store_true",
help="Do not save plots and results.",
)
"-outf", "--out-folder", default=".", help="Folder to save plots."
)
"--split",
default=0.2,
type=float,
help=("Fraction of dataset to holdout for testing"),
)
"--jit", default=False, action="store_true", help="JIT compile the ELBO."
)
"--cpu-data",
default=False,
action="store_true",
help="Keep data on CPU (for large datasets).",
)