import itertools
import time
import datetime
import torch
import torch.nn as tnn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import torch.optim as topti
import pandas as pd
import copy
from tqdm import tqdm
from scipy.stats import spearmanr
import mubind as mb
[docs]
class Mubind(tnn.Module):
"""
Implements the MUBIND model.
Args:
datatype (String): Type of the experimental data. "selex" and "pbm" are supported (case-insensitive).
Keyword Args:
n_rounds (int): Required for SELEX data unless target_dim is given: Number of rounds to be predicted.
init_random (bool): Use a random initialization for all parameters. Default: True
padding_const (double): Value for padding DNA-seqs. Default: 0.25
use_dinuc (bool): Use dinucleotide contributions (not fully implemented for all kinds of models). Default: True
enr_series (bool): Whether the data should be handled as enrichment series. Default: True
n_batches (int): Number of batches that will occur in the data. Default: 1
ignore_kernel (list[bool]): Whether a kernel should be ignored. Default: None.
kernels (List[int]): Size of the binding modes (0 indicates non-specific binding). Default: [0, 15]
n_kernels (int): Number of filters to be used (including non-specific binding, as a constant).
Default: 2 (ns-binding, and one filter)
n_proteins (int): Number of proteins in the dataset. Either n_proteins or n_batches may be used. Default: 1
bm_generator (torch.nn.Module): PyTorch module which has a weight matrix as output.
add_intercept (bool): Whether an intercept is used in addition to the predicted binding modes. Default: True
device (torch.device or str): Device used by the model. Default: "cuda:0" if CUDA is available, else "cpu".
"""
[docs]
def __init__(self, datatype, **kwargs):
    """Set up the layers of the MUBIND model.

    Args:
        datatype (str): "selex" or "pbm" (case-insensitive).
        **kwargs: model options. Missing ``kernels`` / ``n_kernels`` / ``target_dim`` /
            ``n_rounds`` entries are filled into this (local) dict before it is forwarded
            to the binding-mode layer, the activities layer and, for SELEX, the graph layer.

    Raises:
        AssertionError: for an unsupported ``datatype``; when ``kernels`` and
            ``n_kernels`` are both given (without ``bm_generator``) but disagree; for
            SELEX data without ``n_rounds`` or ``target_dim``.
    """
    super().__init__()
    self.device = kwargs.get('device')
    if self.device is None:
        # Use a GPU if available, as it should be faster.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("Using device: " + str(self.device))
    # flags selecting which regularization terms are added to the training loss
    self.optimize_exp_barrier = kwargs.get('optimize_exp_barrier', False)
    self.optimize_kernel_rel = kwargs.get('optimize_kernel_rel', False)
    self.optimize_sym_weight = kwargs.get('optimize_sym_weight', True)
    self.optimize_log_dynamic = kwargs.get('optimize_log_dynamic', False)
    self.optimize_prob_act = kwargs.get('optimize_prob_act', False)
    self.datatype = datatype.lower()
    assert self.datatype in ["selex", "pbm"]
    self.padding_const = kwargs.get("padding_const", 0.25)
    self.use_mono = True
    self.use_dinuc = kwargs.get("use_dinuc", True)
    # resolve kernels / n_kernels; the default is an intercept (width 0) plus one 15-wide kernel
    if "kernels" not in kwargs and "n_kernels" not in kwargs:
        kwargs["kernels"] = [0, 15]
        kwargs["n_kernels"] = len(kwargs["kernels"])
    elif "n_kernels" not in kwargs:
        kwargs["n_kernels"] = len(kwargs["kernels"])
    elif "kernels" not in kwargs and "bm_generator" not in kwargs:
        kwargs["kernels"] = [0] + [15] * (kwargs["n_kernels"] - 1)
    elif "bm_generator" not in kwargs:
        assert len(kwargs["kernels"]) == kwargs["n_kernels"]
    # stays None when only n_kernels and a bm_generator were given
    self.kernels = kwargs.get("kernels")
    # target_dim: presumably the number of predicted columns per sequence — confirm
    if self.datatype == "pbm":
        kwargs["target_dim"] = kwargs.get('target_dim', 1)
    elif self.datatype == "selex":
        # NOTE(review): n_rounds -> target_dim uses target_dim = n_rounds, while the
        # reverse uses n_rounds = target_dim - 1; confirm which convention is intended.
        if "n_rounds" in kwargs:
            kwargs["target_dim"] = kwargs["n_rounds"]
        elif "target_dim" in kwargs:
            kwargs["n_rounds"] = kwargs["target_dim"] - 1
        else:
            print("n_rounds must be provided.")
            assert False
    # assert not ("n_batches" in kwargs and "n_proteins" in kwargs)
    # only keep one padding equals to the length of the max kernel
    # (pads the last dimension on both sides with padding_const; 12 when widths are unknown)
    if self.kernels is None:
        self.padding = tnn.ConstantPad2d((12, 12, 0, 0), self.padding_const)
    else:
        self.padding = tnn.ConstantPad2d((max(self.kernels) - 1, max(self.kernels) - 1, 0, 0), self.padding_const)
    # a bm_generator produces per-protein binding modes; otherwise use a plain BindingLayer
    if "bm_generator" in kwargs and kwargs["bm_generator"] is not None:
        self.binding_modes = BindingModesPerProtein(**kwargs)
    else:
        self.binding_modes = BindingLayer(**kwargs)
    self.activities = ActivitiesLayer(**kwargs)
    # only SELEX models own a graph module; forward() passes SELEX binding scores through it
    if self.datatype == "selex":
        self.graph_module = GraphLayer(**kwargs)
    # optional kernel-relationship matrix, consumed by loss_kernel_rel
    self.kernel_rel = None
    if kwargs.get('kernel_sim') is not None:
        self.kernel_rel = torch.tensor(kwargs.get('kernel_sim'))
    self.n_kernels = kwargs['n_kernels']
    # training bookkeeping, updated by optimize_simple
    self.best_model_state = None
    self.best_loss = None
    self.loss_history = []
    self.loss_history_sym_weights = []
    self.loss_history_log_dynamic = []
    self.r2_history = []
    self.loss_color = []
    self.total_time = 0
@staticmethod
def make_model(train, n_kernels, criterion, init_random=True, **kwargs):
    """Create a Mubind model whose configuration matches a training dataloader.

    Args:
        train: a DataLoader, or (SELEX only) a list of DataLoaders.
        n_kernels (int): number of kernels, including the intercept at index 0.
        criterion: loss module attached to the model; when None, ``mb.tl.PoissonLoss``
            is used for SELEX data and ``mb.tl.MSELoss`` for PBM / genomics data.
        init_random (bool): forwarded to the model constructor.
        **kwargs: forwarded to ``Mubind``. ``verbose`` enables progress prints;
            ``n_batches`` / ``n_rounds`` override the SELEX values read from the data;
            ``device`` selects where the returned model lives when it names a CUDA device.

    Returns:
        Mubind: the new model, with ``criterion`` attached.
    """
    # verbose print declaration
    if kwargs.get('verbose'):
        def vprint(*args, **kwargs):
            print(*args, **kwargs)
    else:
        vprint = lambda *a, **k: None  # do-nothing function

    if (isinstance(train, list) and 'SelexDataset' in str(type(train[0].dataset))) or \
            ('SelexDataset' in str(type(train.dataset))):
        if criterion is None:
            criterion = mb.tl.PoissonLoss()
        if not isinstance(train, list):
            n_rounds = train.dataset.n_rounds
            n_batches = train.dataset.n_batches
            enr_series = train.dataset.enr_series
        else:
            n_rounds = max([max(t.dataset.n_rounds) for t in train])
            n_batches = len(train)
            enr_series = True
        # explicit keyword arguments override the values read from the data; they are
        # removed from kwargs because they are passed explicitly below
        n_batches = kwargs.pop('n_batches', n_batches)
        n_rounds = kwargs.pop('n_rounds', n_rounds)
        use_mono = kwargs.get('use_mono', True)
        use_dinuc = kwargs.get('use_dinuc', True)
        dinuc_mode = kwargs.get('dinuc_mode', 'local')
        vprint("# rounds", set(n_rounds) if not isinstance(n_rounds, int) else n_rounds)
        vprint('# use_mono', use_mono)
        vprint('# use_dinuc', use_dinuc)
        vprint('# dinuc_mode', dinuc_mode)
        vprint("# batches", n_batches)
        vprint("# kernels", n_kernels)
        vprint("# initial w", kwargs.get('w', 20))
        vprint("# enr_series", enr_series)
        vprint('# opt kernel shift', kwargs.get('opt_kernel_shift', False))
        vprint('# opt kernel length', kwargs.get('opt_kernel_length', False))
        vprint("# custom kernels", kwargs.get('kernels'))
        kwargs['kernels'] = kwargs.get('kernels', [0] + [kwargs.get('w', 20)] * (n_kernels - 1))
        model = mb.models.Mubind(
            datatype="selex",
            n_rounds=n_rounds,
            n_batches=n_batches,
            init_random=init_random,
            enr_series=enr_series,
            **kwargs,
        )
    elif isinstance(train.dataset, (mb.datasets.PBMDataset, mb.datasets.GenomicsDataset)):
        if criterion is None:
            criterion = mb.tl.MSELoss()
        if isinstance(train.dataset, mb.datasets.PBMDataset):
            n_proteins = train.dataset.n_proteins
        else:
            n_proteins = kwargs.get('n_proteins', train.dataset.n_cells)
        vprint("# proteins", n_proteins)
        if kwargs.get('joint_learning', False) or n_proteins == 1:
            kwargs['kernels'] = kwargs.get('kernels', [0] + [kwargs.get('w', 20)] * (n_kernels - 1))
            model = mb.models.Mubind(
                datatype="pbm",
                init_random=init_random,
                n_batches=n_proteins,
                **kwargs,
            )
        else:
            # n_proteins is passed explicitly below; a copy left in kwargs would raise
            # "got multiple values for keyword argument 'n_proteins'"
            kwargs.pop('n_proteins', None)
            bm_generator = mb.models.BMCollection(n_proteins=n_proteins, n_kernels=n_kernels,
                                                  init_random=init_random)
            model = mb.models.Mubind(
                datatype="pbm",
                init_random=init_random,
                n_proteins=n_proteins,
                bm_generator=bm_generator,
                n_kernels=n_kernels,
                **kwargs,
            )
    elif isinstance(train.dataset, mb.datasets.ResiduePBMDataset):
        model = mb.models.Mubind(
            datatype="pbm",
            init_random=init_random,
            bm_generator=mb.models.BMPrediction(num_classes=1, input_size=21, hidden_size=2, num_layers=1,
                                                seq_length=train.dataset.get_max_residue_length()),
            **kwargs,
        )
    else:
        assert False  # not implemented yet

    # set criterion
    model.criterion = criterion
    device = kwargs.get('device')
    if 'cuda' in str(device):
        # move to the requested device (e.g. 'cuda:1'), not just the default GPU,
        # so the parameters live where optimize_simple sends the inputs (self.device)
        return model.to(device)
    return model
def forward(self, mono, **kwargs):
    """Predict outputs for a batch of encoded sequences.

    Args:
        mono (torch.Tensor): mononucleotide encoding of the batch; it is padded with
            ``self.padding`` before use (presumably one-hot, (batch, 4, length) — confirm).
        **kwargs: optional precomputed ``mono_rev``, ``di``, ``di_rev``; the flags
            ``return_binding_scores`` / ``return_binding_per_mode``; any other arguments
            (e.g. ``batch``, ``countsum``, ``protein_id``) are forwarded to the
            binding-mode, activity and graph layers.

    Returns:
        For PBM data or when ``return_binding_scores`` is set, the activity-weighted
        binding scores; when ``return_binding_per_mode`` is set, the per-mode binding;
        otherwise (SELEX) the output of the graph module.
    """
    # Always take mono_rev out of kwargs: it is passed explicitly to the binding layer,
    # so a leftover key (e.g. an explicit mono_rev=None) would be a duplicate argument.
    mono_rev = kwargs.pop("mono_rev", None)
    di = kwargs.get("di", None)
    di_rev = kwargs.get("di_rev", None)

    mono = self.padding(mono)
    if mono_rev is None:
        mono_rev = mb.tl.mono2revmono(mono)
    else:
        mono_rev = self.padding(mono_rev)

    # prepare the dinucleotide objects if we need them
    if self.use_dinuc:
        if di is None:
            di = mb.tl.mono2dinuc(mono)
        if di_rev is None:
            di_rev = mb.tl.mono2dinuc(mono_rev)
        if self.binding_modes.use_conv1d:
            di = torch.unsqueeze(di, 1)
            di_rev = torch.unsqueeze(di_rev, 1)
        kwargs["di"] = di
        kwargs["di_rev"] = di_rev

    # mono receives its extra dimension only after the dinucleotides were derived from it
    if not self.binding_modes.use_conv1d:
        mono = torch.unsqueeze(mono, 1)
        mono_rev = torch.unsqueeze(mono_rev, 1)

    # binding_per_mode: matrix of size [batchsize, number of binding modes]
    binding_per_mode = self.binding_modes(mono=mono, mono_rev=mono_rev, **kwargs)  # sequences x filters (Se, F)
    binding_scores = self.activities(binding_per_mode, **kwargs)  # (Se, F) * (F, Sa) = (Se, Sa)

    return_binding_scores = kwargs.get('return_binding_scores', False)
    return_binding_per_mode = kwargs.get('return_binding_per_mode', False)
    if self.datatype == "pbm" or return_binding_scores:
        return binding_scores
    elif return_binding_per_mode:
        return binding_per_mode
    elif self.datatype == "selex":
        return self.graph_module(binding_scores, **kwargs)
    else:
        return None  # this line should never be called
def set_seed(self, seed, index, max_value=0, min_value=-1):
    """Initialise kernel ``index`` from ``seed``.

    Seeding is only supported when the binding modes are a ``BindingLayer``; for any
    other binding-mode module a message is printed and an AssertionError is raised.
    """
    binding_modes = self.binding_modes
    if isinstance(binding_modes, BindingLayer):
        binding_modes.set_seed(seed, index, max_value, min_value)
        return
    print("Setting a seed is not possible for that kind of model.")
    assert False
def modify_kernel(self, index=None, shift=0, expand_left=0, expand_right=0, device=None):
    """Forward a kernel shift / expansion request to the binding-mode module."""
    request = (index, shift, expand_left, expand_right, device)
    self.binding_modes.modify_kernel(*request)
def set_kernel_weights(self, weight, index):
    """Replace the mononucleotide weights of kernel ``index``.

    Args:
        weight (torch.Tensor): new weights; must have the shape of the current weights.
            A plain tensor is wrapped in ``torch.nn.Parameter`` so it can replace the
            registered parameter.
        index (int): position of the kernel in ``binding_modes.conv_mono``.

    Raises:
        AssertionError: if the shapes differ.
    """
    # the kernels are owned by the binding-mode module; Mubind itself has no conv_mono
    kernel = self.binding_modes.conv_mono[index]
    assert weight.shape == kernel.weight.shape
    if not isinstance(weight, tnn.Parameter):
        # nn.Module refuses to assign a plain tensor over a registered parameter
        weight = tnn.Parameter(weight)
    kernel.weight = weight
def update_grad(self, index, value):
    """Toggle gradient computation for kernel ``index`` in the binding modes and then the activities."""
    for layer in (self.binding_modes, self.activities):
        layer.update_grad(index, value)
def update_grad_activities(self, index, value):
    """Toggle gradient computation for kernel ``index`` in the activities layer only."""
    activities = self.activities
    activities.update_grad(index, value)
def update_grad_etas(self, value):
    """Toggle gradient computation for the etas of the graph module (present on SELEX models)."""
    graph = self.graph_module
    graph.update_grad_etas(value)
def set_ignore_kernel(self, ignore_kernel):
    """Pass the per-kernel ignore flags on to the activities layer."""
    activities = self.activities
    activities.set_ignore_kernel(ignore_kernel)
def get_ignore_kernel(self):
    """Return the per-kernel ignore flags held by the activities layer."""
    getter = self.activities.get_ignore_kernel
    return getter()
def get_kernel_width(self, index):
    """Return the width of kernel ``index`` as reported by the binding-mode module."""
    binding_modes = self.binding_modes
    width = binding_modes.get_kernel_width(index)
    return width
def get_kernel_weights(self, index, **kwargs):
    """Return the weights of kernel ``index``; extra keyword options go to the binding-mode module."""
    getter = self.binding_modes.get_kernel_weights
    return getter(index, **kwargs)
def get_log_activities(self):
    """Return the log activities stored in the activities layer."""
    activities = self.activities
    log_activities = activities.get_log_activities()
    return log_activities
def get_log_etas(self):
    """Return the log etas of the graph module.

    Raises:
        AssertionError: for non-SELEX models, which have no graph module.
    """
    assert self.datatype == "selex"
    graph = self.graph_module
    return graph.get_log_etas()
def dirichlet_regularization(self):
    """Return the Dirichlet regularization term computed by the binding-mode module."""
    regularize = self.binding_modes.dirichlet_regularization
    return regularize()
def weight_distances_min_k(self, min_k=5, exp_delta=4):
    """Return ``exp(exp_delta - d_min)`` for the closest pair of non-intercept kernels.

    For every pair of kernels in ``binding_modes.conv_mono[1:]`` and every window length
    ``k`` from ``min_k`` up to the width of the narrower kernel (inclusive), every
    length-``k`` window of one kernel is compared with every window of the other, both
    as-is and flipped (see below). A distance is the sum of squared differences divided
    by ``k``. Distances are detached, so the result carries no gradient.

    Args:
        min_k (int): shortest window length to compare.
        exp_delta (float): offset inside the exponential.

    Raises:
        AssertionError: if nothing could be compared (fewer than two kernels, or every
            kernel narrower than ``min_k``).
    """
    kernels = self.binding_modes.conv_mono[1:]
    d = []
    for a, b in itertools.combinations(kernels, r=2):
        a = a.weight
        b = b.weight
        min_w = min(a.shape[-1], b.shape[-1])
        # min_w + 1: windows spanning the full narrower kernel are compared too
        for k in range(min_k, min_w + 1):
            for i in range(0, a.shape[-1] - k + 1):
                ai = a[:, :, :, i: i + k]
                for j in range(0, b.shape[-1] - k + 1):
                    bi = b[:, :, :, j: j + k]
                    # flip positions and reverse the nucleotide axis — presumably the
                    # reverse complement, assuming an ACGT ordering (confirm)
                    bi_rev = torch.flip(bi, [3])[:, :, [3, 2, 1, 0], :]
                    d.append(((bi - ai) ** 2).sum().cpu().detach() / bi.shape[-1])
                    d.append(((bi_rev - ai) ** 2).sum().cpu().detach() / bi.shape[-1])
    if len(d) == 0:
        print(kernels)
        assert False
    return torch.exp(exp_delta - min(d))
def loss_kernel_rel(self, log=False):
    """
    Return a loss associated to the similarity of weights that are assumed to be similar.

    For every pair ``(ai, bi)`` with ``ai < bi`` of non-intercept kernels
    (``binding_modes.conv_mono[1:]``) whose entry ``kernel_rel[ai, bi]`` equals 1, the
    norm (``torch.norm``) of the difference of their weights is added. Returns 0 when no
    relationship matrix is set or ``optimize_kernel_rel`` is disabled.

    Args:
        log (bool): print every contributing pair and its distance.
    """
    if self.kernel_rel is None or not self.optimize_kernel_rel:
        return 0
    kernels = self.binding_modes.conv_mono[1:]
    dist_loss = 0
    # combinations() yields the pairs with ai < bi in row-major order
    for (ai, a), (bi, b) in itertools.combinations(enumerate(kernels), 2):
        if self.kernel_rel[ai, bi] != 1:
            continue
        d = torch.norm(a.weight - b.weight)
        if log:
            print(ai, bi, d)
        dist_loss += d
    return dist_loss
def print_weights(self):
    """Print the kernel weights, log activities and, when available, the log etas.

    Resets torch's print options to the default profile with a 500-character line width.
    Entries of ``conv_mono`` / ``conv_di`` without a ``weight`` attribute are skipped.
    """
    torch.set_printoptions(profile="default")  # reset
    torch.set_printoptions(linewidth=500)
    print('\nmono')
    for b in self.binding_modes.conv_mono:
        if hasattr(b, 'weight'):
            print(b.weight)
    print('\ndinuc')
    for b in self.binding_modes.conv_di:
        if hasattr(b, 'weight'):
            print(b.weight)
    print('\nactivities')
    print(self.activities.get_log_activities())
    # only SELEX models own a graph module; PBM models have no etas to print
    if hasattr(self, 'graph_module'):
        print('\netas')
        print(self.graph_module.log_etas)
def loss_exp_barrier(self, exp_max):
    """
    Exponential barrier on the absolute column sums of the non-intercept kernels.

    For every mononucleotide kernel (``binding_modes.conv_mono[1:]``) the weights are
    summed over axis 2, and each position contributes ``exp(|sum| - exp_max)``, so the
    term only grows large once a magnitude approaches ``exp_max``. When dinucleotides are
    used, the dinucleotide kernels add a term as well ('local': same barrier; 'full': see
    the note below), and both terms are summed.
    """
    mono = sum([torch.exp(b.weight.sum(axis=2).abs() - exp_max).sum()
                for b in self.binding_modes.conv_mono[1:]])
    di = None
    if self.use_dinuc and self.binding_modes.dinuc_mode == 'local':
        di = sum([torch.exp(b.weight.sum(axis=2).abs() - exp_max).sum()
                  for b in self.binding_modes.conv_di[1:]])
    elif self.use_dinuc and self.binding_modes.dinuc_mode == 'full':
        # NOTE(review): in 'full' mode no exponential / exp_max is applied, unlike the
        # other terms; confirm whether this is intended.
        di = sum([b2.weight.sum(axis=2).abs().sum()
                  for b in self.binding_modes.conv_di[1:] for b2 in b])
    # include the dinucleotide term whenever one was computed (either mode)
    if di is None:
        return mono
    return mono + di
def loss_log_dynamic(self):
    """Penalty on differences of ``graph_module.log_dynamic`` between related connections.

    ``graph_module.conn_sparse`` is a sparse COO matrix whose stored entries are indexed
    by position ``p`` (0..nnz-1); ``log_dynamic[p]`` belongs to entry ``p``. For every
    pair of entries that share an index value, the pair is kept when: they share a row
    or a column, both are off-diagonal, and it is not the case that they share a row
    while having different columns. Each kept pair contributes
    ``(log_dynamic[p] - log_dynamic[q]) ** 2 * conn[p] * conn[q]``.

    Returns:
        0 when there is no graph module (PBM models), no ``conn_sparse`` or no stored
        entries; otherwise a scalar tensor.
    """
    graph_module = getattr(self, 'graph_module', None)
    if graph_module is None or not hasattr(graph_module, 'conn_sparse'):
        return 0
    conn = graph_module.conn_sparse
    log_dynamic = graph_module.log_dynamic
    idx = conn.indices()  # (2, nnz): rows in idx[0], columns in idx[1]
    conn_vals = conn.values()
    pos = torch.arange(idx.size(1), device=self.device)

    # prepare combinations of positions based on common indexes
    all_combinations = []
    for u_idx in idx.unique():
        # at least one common index has to be present in the position retrieved
        sub_pos = pos[(idx[0] == u_idx) | (idx[1] == u_idx)]
        all_combinations.append(torch.combinations(sub_pos, r=2))
    if len(all_combinations) == 0:
        return 0
    all_pos = torch.cat(all_combinations)

    # one row per pair: [row_a, col_a, row_b, col_b]. idx[:, all_pos] has shape
    # (2, n_pairs, 2), so the coordinate axis must be moved last before flattening.
    pairs = idx[:, all_pos].permute(1, 2, 0).reshape(all_pos.shape[0], 4)
    mask1 = (pairs[:, 0] == pairs[:, 2]) | (pairs[:, 1] == pairs[:, 3])
    mask2 = (pairs[:, 0] != pairs[:, 1]) & (pairs[:, 2] != pairs[:, 3])
    mask3 = ~((pairs[:, 0] == pairs[:, 2]) & (pairs[:, 1] != pairs[:, 3]))
    all_pos = all_pos[mask1 & mask2 & mask3]

    a = log_dynamic[all_pos[:, 0]]
    b = log_dynamic[all_pos[:, 1]]
    w_err = (a - b) ** 2
    conn_weight = conn_vals[all_pos[:, 0]] * conn_vals[all_pos[:, 1]]
    score = w_err * conn_weight
    # NOTE(review): idx.shape[0] is always 2 for a 2-D COO matrix; confirm whether the
    # normalisation was meant to use the number of entries or pairs instead.
    return score.sum() / idx.shape[0]
def exp_barrier(self, exp_max=40):
    """Symmetric exponential barrier over all parameters.

    Returns ``sum(exp(p - exp_max) + exp(-p - exp_max))`` over every element of every
    parameter, which grows quickly once |p| approaches ``exp_max``.
    """
    return sum(
        (torch.exp(param - exp_max) + torch.exp(-param - exp_max)).sum()
        for param in self.parameters()
    )
def loss_kernel_symmetrical_weights(self):
    """
    This loss calculates the squared sum of columns per position, and it is useful to detect
    strong positive/negative biases per position or in the whole object.

    For every non-intercept mononucleotide kernel the weights are summed over axis 2 and
    the squares of these sums are added up. When dinucleotides are used, the same is done
    for the dinucleotide kernels ('local': one module per kernel; 'full': an iterable of
    modules per kernel) and both terms are summed.
    """
    mono_sym_weight = sum([(b.weight.sum(axis=2) ** 2).sum() for b in self.binding_modes.conv_mono[1:]])
    di_sym_weight = None
    if self.use_dinuc and self.binding_modes.dinuc_mode == 'local':
        di_sym_weight = sum([(b.weight.sum(axis=2) ** 2).sum() for b in self.binding_modes.conv_di[1:]])
    elif self.use_dinuc and self.binding_modes.dinuc_mode == 'full':
        di_sym_weight = sum([(b2.weight.sum(axis=2) ** 2).sum()
                             for b in self.binding_modes.conv_di[1:] for b2 in b])
    # include the dinucleotide term whenever one was computed (either mode)
    if di_sym_weight is None:
        return mono_sym_weight
    return mono_sym_weight + di_sym_weight
def loss_prob_act(self):
    """Return ``sum(sigmoid(exp(x)))`` over ``[1, *binding_modes.prob_act]``."""
    leading_one = torch.ones(1, device=self.device)
    logits = torch.cat((leading_one, self.binding_modes.prob_act))
    return torch.sigmoid(torch.exp(logits)).sum()
# if early_stopping is positive, training is stopped if over the length of early_stopping no improvement happened or
# num_epochs is reached.
def optimize_simple(self,
                    dataloader,
                    optimiser,
                    # reconstruction_crit,
                    num_epochs=15,
                    early_stopping=-1,
                    dirichlet_regularization=0,
                    exp_max=40,  # offset of the exponential barrier (used when self.optimize_exp_barrier)
                    log_each=-1,
                    verbose=0,
                    r2_per_epoch=False,
                    **kwargs,
                    ):
    """Train the model on ``dataloader`` with ``optimiser``.

    Args:
        dataloader: a DataLoader or a list of DataLoaders.
        optimiser: a torch optimiser; LBFGS is driven through a closure.
        num_epochs (int): maximum number of epochs.
        early_stopping (int): when positive, stop once ``early_stopping`` epochs passed
            after the best epoch, or when the loss improved by less than
            ``kwargs['rel_chg_early_stop']`` percent (default 1e-5) versus the previous epoch.
        dirichlet_regularization (float): weight of ``self.dirichlet_regularization()``.
        exp_max (float): offset passed to ``loss_exp_barrier``.
        log_each (int): print progress every ``log_each`` epochs (-1 disables).
        verbose (int): non-zero enables progress output.
        r2_per_epoch (bool): compute R2 with ``mb.tl.scores`` whenever progress is printed.
        **kwargs: ``use_tqdm`` (default True), ``print_weights`` (default False),
            ``rel_chg_early_stop`` (default 1e-5).

    Side effects:
        Updates the parameters, ``best_model_state``, ``best_loss``, ``total_time``
        (and ``r2_final`` when verbose), and extends ``loss_history``,
        ``loss_history_sym_weights``, ``loss_history_log_dynamic`` and ``r2_history``.
    """
    r2_history = []
    loss_history = []
    loss_history_sym_weights = []
    loss_history_log_dynamic = []
    best_loss = None
    best_epoch = -1
    if verbose != 0:
        print(
            "optimizer: ",
            str(type(optimiser)).split('.')[-1].split('\'>')[0],
            ", criterion:",
            str(type(self.criterion)).split('.')[-1].split('\'>')[0],
            "\nepochs:",
            num_epochs,
            "\nearly_stopping:",
            early_stopping,
        )
    for f in ["lr", "weight_decay"]:
        if f in optimiser.param_groups[0]:
            if verbose != 0:
                print("%s=" % f, optimiser.param_groups[0][f], end=", ")
    if verbose != 0:
        print("dir weight=", dirichlet_regularization)

    is_lbfgs = "LBFGS" in str(optimiser)
    # if dataloader is a list of dataloaders, we have to iterate through those
    dataloader_queries = dataloader if isinstance(dataloader, list) else [dataloader]
    store_rev = dataloader_queries[0].dataset.store_rev
    t0 = time.time()

    # the total number of trials (used to report the time per 1k trials)
    if isinstance(dataloader, list) and hasattr(dataloader[0].dataset, 'signal'):
        n_trials = sum([d.dataset.signal.shape[0] for d in dataloader])
    elif isinstance(dataloader, list) and hasattr(dataloader[0].dataset, 'rounds'):
        n_trials = sum([d.dataset.rounds.shape[0] for d in dataloader])
    else:
        n_trials = sum(
            [d.dataset.rounds.shape[0] if hasattr(d.dataset, 'rounds') else d.dataset.signal.shape[0]
             for d in [dataloader]])

    use_tqdm = kwargs.get('use_tqdm', True)
    rel_chg_early_stop = kwargs.get('rel_chg_early_stop', 1e-5)
    # defined up front so the final summary also works when num_epochs == 0
    epoch = 0
    loss_final = float('nan')
    for epoch in tqdm(range(num_epochs)) if use_tqdm else range(num_epochs):
        self.train()
        running_loss = 0
        running_loss_sym_weights = 0
        running_loss_log_dynamic = 0
        for next_dataloader in dataloader_queries:
            for batch in next_dataloader:
                # Get a batch and potentially send it to GPU memory.
                mononuc = batch["mononuc"].to(self.device)
                b = batch["batch"].to(self.device) if "batch" in batch else None
                rounds = batch["rounds"].to(self.device) if "rounds" in batch else None
                if next_dataloader.dataset.use_sparse:
                    rounds = rounds.squeeze(1)
                n_rounds = batch["n_rounds"].to(self.device) if "n_rounds" in batch else None
                countsum = batch["countsum"].to(self.device) if "countsum" in batch else None
                residues = batch["residues"].to(self.device) if "residues" in batch else None
                protein_id = batch["protein_id"].to(self.device) if "protein_id" in batch else None
                inputs = {"mono": mononuc, "batch": b, "countsum": countsum}
                if store_rev:
                    inputs["mono_rev"] = batch["mononuc_rev"].to(self.device)
                if residues is not None:
                    inputs["residues"] = residues
                if protein_id is not None:
                    inputs["protein_id"] = protein_id
                # if not selex, do not scale by overall signal
                inputs['scale_countsum'] = self.datatype == 'selex'

                loss_sym_weights = None
                loss_log_dynamic = None
                if is_lbfgs:
                    def closure():
                        optimiser.zero_grad()
                        outputs = self.forward(**inputs)
                        if dirichlet_regularization == 0:
                            dir_weight = 0
                        else:
                            dir_weight = dirichlet_regularization * self.dirichlet_regularization()
                        closure_loss = self.criterion(outputs, rounds) + dir_weight
                        closure_loss, _, _ = self._add_regularization(closure_loss, exp_max)
                        # backpropagate only after every enabled regularizer was added, so the
                        # gradient LBFGS follows matches the loss value it is given
                        closure_loss.backward()
                        return closure_loss
                    loss = optimiser.step(closure)  # Step to minimise the loss according to the gradient.
                else:
                    # PyTorch accumulates gradients, so they must be reset before each step.
                    optimiser.zero_grad(set_to_none=None)
                    outputs = self.forward(**inputs)  # Forward pass through the network.
                    if dirichlet_regularization == 0:
                        dir_weight = 0
                    else:
                        dir_weight = dirichlet_regularization * self.dirichlet_regularization()
                    if isinstance(dataloader, list):
                        # with a list of dataloaders the output width is given by rounds
                        loss = self.criterion(outputs[:, :rounds.shape[1]], rounds)
                    elif n_rounds is not None and len(set(dataloader.dataset.n_rounds)) != 1:
                        # per sequence, keep only the round columns below its own n_rounds
                        mask = torch.zeros((n_rounds.shape[0], outputs.shape[1]), dtype=torch.bool,
                                           device=self.device)
                        for mi in range(mask.shape[1]):
                            mask[:, mi] = ~(n_rounds - 1 < mi)
                        loss = self.criterion(outputs[mask], rounds[mask])
                    else:
                        loss = self.criterion(outputs, rounds)
                    loss += dir_weight
                    loss, loss_sym_weights, loss_log_dynamic = self._add_regularization(loss, exp_max)
                    loss.backward()  # Calculate gradients.
                    optimiser.step()

                running_loss += loss.item()
                # float() detaches the auxiliary terms; summing the tensors themselves would
                # keep extending an autograd history across the whole epoch
                running_loss_sym_weights += float(loss_sym_weights) if loss_sym_weights is not None else 0
                running_loss_log_dynamic += float(loss_log_dynamic) if loss_log_dynamic is not None else 0

        # NOTE(review): for a list of dataloaders len(dataloader) counts loaders, not batches
        loss_final = running_loss / len(dataloader)
        loss_final_sym_weights = running_loss_sym_weights / len(dataloader)
        loss_final_log_dynamic = running_loss_log_dynamic / len(dataloader)

        if log_each != -1 and epoch > 0 and (epoch % log_each == 0):
            if verbose != 0:
                r2_epoch = None
                if r2_per_epoch:
                    r2_epoch = mb.tl.scores(self, dataloader)['r2_counts']
                total_time = time.time() - t0
                time_epoch_1k = (total_time / max(epoch, 1) / n_trials * 1e3)
                print(
                    "Epoch: %2d, Loss: %.3f, %s" % (epoch + 1, loss_final,
                                                   'R2: %.3f, ' % r2_epoch if r2_epoch is not None else ''),
                    "best epoch: %i, " % best_epoch,
                    "secs per epoch: %.3f s, " % ((time.time() - t0) / max(epoch, 1)),
                    "secs epoch*1k trials: %.3fs" % time_epoch_1k,
                    "curr time:", datetime.datetime.now(),
                )
            if kwargs.get('print_weights', False):
                self.print_weights()

        if best_loss is None or loss_final < best_loss:
            best_loss = loss_final
            best_epoch = epoch
            self.best_model_state = copy.deepcopy(self.state_dict())
            self.best_loss = best_loss

        loss_history.append(float(loss_final))
        loss_history_sym_weights.append(float(loss_final_sym_weights))
        loss_history_log_dynamic.append(float(loss_final_log_dynamic))

        # relative improvement (in percent) versus the previous epoch
        early_stop_rel_chg = False
        if len(loss_history) >= 2 and loss_history[-1] < loss_history[-2]:
            rel_chg = (loss_history[-2] - loss_history[-1]) / loss_history[-1] * 100
            if rel_chg < rel_chg_early_stop:
                early_stop_rel_chg = True

        if early_stopping > 0 and (epoch >= best_epoch + early_stopping or early_stop_rel_chg):
            if verbose != 0:
                r2_epoch = None
                if r2_per_epoch:
                    r2_epoch = mb.tl.scores(self, dataloader)['r2_counts']
                total_time = time.time() - t0
                time_epoch_1k = (total_time / max(epoch, 1) / n_trials * 1e3)
                print(
                    "Epoch: %2d, Loss: %.3f, %s" % (epoch + 1, loss_final,
                                                   'R2: %.3f, ' % r2_epoch if r2_epoch is not None else ''),
                    "best epoch: %i, " % best_epoch,
                    "secs per epoch: %.3fs, " % ((time.time() - t0) / max(epoch, 1)),
                    "secs epoch*1k trials: %.3fs," % time_epoch_1k,
                    "curr time:", datetime.datetime.now(),
                )
                print("early stop!")
            break

    total_time = time.time() - t0
    self.total_time += total_time
    if verbose:
        r2_epoch = None
        if r2_per_epoch:
            r2_epoch = mb.tl.scores(self, dataloader)['r2_counts']
        self.r2_final = r2_epoch
        print('Current time:', datetime.datetime.now())
        print('\tLoss: %.3f %s' % (loss_final, ', R2: %.3f' % r2_epoch if r2_epoch is not None else ''))
        print('\tTraining time (model/function): (%.3fs / %.3fs)' % (self.total_time, total_time))
        print("\t\tper epoch (model/function): (%.3fs/ %.3fs)" %
              ((self.total_time / max(epoch, 1)), (total_time / max(epoch, 1))))
        print('\t\tper 1k samples: %.3fs' % (total_time / max(epoch, 1) / n_trials * 1e3))
    self.loss_history += loss_history
    self.loss_history_sym_weights += loss_history_sym_weights
    self.loss_history_log_dynamic += loss_history_log_dynamic
    self.r2_history += r2_history

def _add_regularization(self, loss, exp_max):
    """Add every enabled regularization term to ``loss``.

    The kernel-relationship, exponential-barrier, symmetric-weight and log-dynamic terms
    are always computed (the latter two are logged by optimize_simple) but only added
    when the matching ``optimize_*`` flag is set; the prob-act term is computed only
    when enabled.

    Returns:
        tuple: ``(loss, loss_sym_weights, loss_log_dynamic)``.
    """
    loss_kernel_rel = self.loss_kernel_rel()
    loss_neg_weights = self.loss_exp_barrier(exp_max=exp_max)
    loss_sym_weights = self.loss_kernel_symmetrical_weights()
    loss_log_dynamic = self.loss_log_dynamic()
    if self.optimize_kernel_rel:
        loss += loss_kernel_rel
    if self.optimize_exp_barrier:
        loss += loss_neg_weights
    if self.optimize_sym_weight:
        loss += loss_sym_weights
    if self.optimize_log_dynamic:
        loss += loss_log_dynamic
    if self.optimize_prob_act:
        loss += self.loss_prob_act()
    return loss, loss_sym_weights, loss_log_dynamic
def corr_etas_libsizes(self, train):
    """Spearman correlation between the log etas and the library sizes of ``train``.

    Library sizes are the sums over axis 0 of ``train.dataset.rounds`` (a numpy array,
    torch tensor or scipy sparse matrix); both sides are flattened to 1-D numpy arrays.

    Returns:
        tuple: (label string, ``scipy.stats.spearmanr`` result).
    """
    # Convert unconditionally: self.device may be a torch.device, which never compares
    # equal to the string 'cpu', so the previous device check picked a branch by accident.
    etas = self.get_log_etas().detach().cpu().numpy().flatten()
    lib_sizes = train.dataset.rounds.sum(axis=0)
    if torch.is_tensor(lib_sizes):
        lib_sizes = lib_sizes.detach().cpu().numpy()
    # np.asarray also turns the np.matrix returned by scipy sparse sums into an array
    lib_sizes = np.asarray(lib_sizes).flatten()
    return 'etas corr with lib_sizes (before refinement)', spearmanr(etas, lib_sizes)
def optimize_iterative(self,
train,
# min_w=10,
max_w=20,
n_epochs=100, # int or list
early_stopping=15, # int or list
log_each=10,
opt_kernel_shift=True,
opt_kernel_length=True,
opt_one_step=False,
expand_length_max=3,
expand_length_step=1,
show_logo=False,
optimiser=None,
seed=None,
init_random=False,
joint_learning=False,
ignore_kernel=False,
n_unfreeze_kernels=1, # amount of kernels to freeze, unfreeze at the time
lr=0.01,
weight_decay=0.001,
stop_at_kernel=None,
dirichlet_regularization=0,
verbose=2,
exp_max=-1,
shift_max=2,
shift_step=1,
r2_per_epoch=False,
skip_kernels=None,
log_next_r2=True,
**kwargs,
):
# color for visualization of history
colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33", "#a65628"]
# here he add a parameter to keep the r2 log by next parms
self.best_r2_by_new_filter = []
# verbose print declaration
if verbose:
def vprint(*args, **kwargs):
print(*args, **kwargs)
else:
vprint = lambda *a, **k: None # do-nothing function
print('verbose=%i' % verbose)
if not isinstance(opt_kernel_shift, list):
opt_kernel_shift = [0] + [opt_kernel_shift] * (self.n_kernels - 1)
if not isinstance(opt_kernel_length, list):
opt_kernel_length = [0] + [opt_kernel_length] * (self.n_kernels - 1)
# prepare model
# this sets up the seed at the first position
if seed is not None:
# this sets up the seed at the first position
for i, s, min_w, max_w in seed:
if s is not None:
print(i, s)
self.set_seed(s, i, min=min_w, max=max_w)
self = self.to(self.device)
# step 1) freeze everything before the current binding mode
for i in range(0, self.n_kernels, n_unfreeze_kernels):
if skip_kernels is not None and i in skip_kernels:
continue
vprint('current kernel', i)
# print(self.binding_modes)
vprint("\n### next filter to optimize %i %s" % (i, '(intercept)' if i == 0 else ''))
vprint("\nFREEZING KERNELS")
for feat_i in ['mono', 'dinuc']:
if i == 0 and feat_i == 'dinuc':
vprint('optimization of dinuc is not valid for the intercept (filter=0). Skip...')
continue
vprint('optimizing feature type', feat_i)
if i != 0:
if feat_i == 'dinuc' and not self.use_dinuc:
vprint('the optimization of dinucleotide features is skipped...')
continue
elif feat_i == 'mono' and not self.use_mono:
vprint('the optimization of mononucleotide features is skipped...')
continue
# block kernels that we do not require to optimize
unfreeze_kernel_ids = set(range(i, i + n_unfreeze_kernels))
vprint('next kernels %i-%i, n=%i' % (min(unfreeze_kernel_ids), max(unfreeze_kernel_ids), len(unfreeze_kernel_ids)))
for ki in range(self.n_kernels):
mask_mono = (ki in unfreeze_kernel_ids) and (feat_i == 'mono')
mask_dinuc = (ki == unfreeze_kernel_ids) and (feat_i == 'dinuc')
if opt_one_step: # skip freezing
if skip_kernels is None or (i in skip_kernels):
mask_mono = False
mask_dinuc = False
if verbose != 0:
if mask_mono or mask_dinuc:
vprint("setting grad status of kernel (mono, dinuc) at %i to (%i, %i)" % (
ki, mask_mono, mask_dinuc))
if hasattr(self.binding_modes, 'update_grad_mono'):
self.binding_modes.update_grad_mono(ki, mask_mono)
if self.use_dinuc:
self.binding_modes.update_grad_di(ki, mask_dinuc)
# activities are frozen during intercept optimization
self.update_grad_activities(ki, i != 0)
if opt_one_step:
self.update_grad_activities(ki, True)
# self.update_grad_etas(i != 0)
if show_logo:
vprint("before filter optimization.")
mb.pl.activities(self, train)
mb.pl.logo(self,
title=False,
xticks=False,
rowspan_dinuc=0,
rowspan_mono=1,
n_rows=5,
n_cols=12,
stop_at=10) # n_cols=len(reduced_groups))
# mb.pl.logo_mono(self)
# mb.pl.conv_mono(model, flip=False, log=False)
# mb.pl.logo_di(self, mode='triangle')
next_lr = lr if not isinstance(lr, list) else lr[i]
next_weight_decay = weight_decay if not isinstance(weight_decay, list) else weight_decay[i]
next_early_stopping = early_stopping if not isinstance(early_stopping, list) else early_stopping[i]
next_optimiser = (
topti.Adam(self.parameters(), lr=next_lr, weight_decay=next_weight_decay)
if optimiser is None
else optimiser(self.parameters(), lr=next_lr)
)
# mask kernels to avoid using weights from further steps into early ones.
if ignore_kernel:
self.set_ignore_kernel(np.array([0 for i in range(i + 1)] + [1 for kernel_i in range(i + 1, self.n_kernels)]))
if verbose != 0:
print("filters mask", self.get_ignore_kernel())
self.optimize_simple(
train,
next_optimiser,
num_epochs=n_epochs[i] if isinstance(n_epochs, list) else n_epochs,
early_stopping=next_early_stopping,
log_each=log_each,
dirichlet_regularization=dirichlet_regularization,
exp_max=exp_max,
verbose=verbose,
r2_per_epoch=r2_per_epoch,
i=i,
**kwargs,
)
# vprint('grad')
# vprint(model.binding_modes.conv_mono[1].weight.grad)
# vprint(model.binding_modes.conv_di[1].weight.grad)
# vprint('')
self.loss_color += list(np.repeat(colors[i % len(colors)], len(self.loss_history) - len(self.loss_color)))
# probably here load the state of the best epoch and save
self.load_state_dict(self.best_model_state)
# store model parameters and fit for later visualization
# necessary?
# self = copy.deepcopy(self)
# optimizer for left / right flanks
best_loss = self.best_loss
if show_logo:
print("\n##After filter opt / before shift optim.")
mb.pl.activities(self, train)
mb.pl.logo(self,
title=False,
xticks=False,
rowspan_dinuc=0,
rowspan_mono=1,
n_rows=5,
n_cols=12,
stop_at=10) # n_cols=len(reduced_groups))
# mb.pl.conv_mono(self)
# # mb.pl.conv_mono(model, flip=True, log=False)
# mb.pl.conv_di(self, mode='triangle')
mb.pl.loss(self)
# print(model_by_k[k_parms].loss_color)
#######
# optimize the flanks through +1/-1 shifts
#######
if (opt_kernel_shift[i] or opt_kernel_length[i]) and i != 0:
self = self.optimize_width_and_length(train,
expand_length_max,
expand_length_step,
shift_max,
shift_step,
i,
feat_i=feat_i,
colors=colors,
verbose=verbose,
lr=next_lr,
weight_decay=next_weight_decay,
optimiser=optimiser,
log_each=log_each,
exp_max=exp_max,
dirichlet_regularization=dirichlet_regularization,
early_stopping=next_early_stopping, criterion=self.criterion,
show_logo=show_logo,
n_kernels=self.n_kernels,
max_w=max_w,
r2_per_epoch=r2_per_epoch,
num_epochs=n_epochs[i] if isinstance(n_epochs, list) else n_epochs,
**kwargs)
if show_logo:
vprint("after shift optimz model")
mb.pl.activities(self, train)
mb.pl.logo(self,
title=False,
xticks=False,
rowspan_dinuc=0,
rowspan_mono=1,
n_rows=5,
n_cols=12,
stop_at=10) # n_cols=len(reduced_groups))
# mb.pl.conv_mono(self)
# # mb.pl.conv_mono(model, log=False)
# mb.pl.conv_di(self, mode='triangle')
mb.pl.loss(self)
print("")
vprint(self.corr_etas_libsizes(train))
# the first kernel does not require an additional fit.
if i == 0:
# option: the log etas are highly correlated after the intercept fit, and thus can be frozen with intercept during training
self.update_grad_etas(False)
continue
vprint("\n\nfinal refinement step (after shift)...unfreezing all layers")
for ki in range(self.n_kernels):
# vprint("kernel grad (%i) = %i \n" % (ki, True), sep=", ", end="")
# print('skip...')
# continue
self.update_grad(ki, ki == i)
# vprint("")
# define the optimizer for final refinement of the model
next_optimiser = (
topti.Adam(self.parameters(), lr=next_lr, weight_decay=next_weight_decay)
if optimiser is None
else optimiser(self.parameters(), lr=next_lr)
)
# mask kernels to avoid using weights from further steps into early ones.
if ignore_kernel:
self.set_ignore_kernel(np.array([0 for i in range(i + 1)] +
[1 for i in range(i + 1, self.n_kernels)]))
# vprint("filters mask", self.get_ignore_kernel())
# final refinement of weights
self.optimize_simple(
train,
next_optimiser,
num_epochs=n_epochs[i] if isinstance(n_epochs, list) else n_epochs,
early_stopping=next_early_stopping,
log_each=log_each,
dirichlet_regularization=dirichlet_regularization,
verbose=verbose,
r2_per_epoch=r2_per_epoch,
)
# load the best model after the final refinement
self.loss_color += list(np.repeat(colors[i % len(colors)], len(self.loss_history) - len(self.loss_color)))
self.load_state_dict(self.best_model_state)
if stop_at_kernel is not None and stop_at_kernel == i:
break
if show_logo:
vprint("\n##final motif signal (after final refinement)")
mb.pl.activities(self, train)
mb.pl.logo(self,
title=False,
xticks=False,
rowspan_dinuc=0,
rowspan_mono=1,
n_rows=5,
n_cols=12,
stop_at=10) # n_cols=len(reduced_groups))
# mb.pl.conv_mono(self)
# mb.pl.conv_di(self, mode='triangle')
# mb.pl.conv_mono(model, flip=True, log=False)
vprint('best loss', '%.3f' % self.best_loss)
# calculate the current r2 and keep a log of it
if log_next_r2:
next_r2 = mb.tl.scores(self, train)['r2_counts']
self.best_r2_by_new_filter.append(next_r2)
vprint('last five r2 values, by sequential filter optimization:',
['%.3f' % v for v in self.best_r2_by_new_filter[-5:]])
vprint(self.corr_etas_libsizes(train))
# print('simple epoch done...')
# assert False
vprint('\noptimization finished:')
vprint(f'total time: {self.total_time}s')
vprint("Time per epoch (total): %.3f s" %
(self.total_time / max(n_epochs if not isinstance(n_epochs, list) else sum(n_epochs), 1)))
return self, self.best_loss
def optimize_width_and_length(self, train, expand_length_max, expand_length_step, shift_max, shift_step, i,
colors=None, verbose=False, lr=0.01, weight_decay=0.001, optimiser=None, log_each=10,
exp_max=40,
num_epochs_shift_factor=1,
dirichlet_regularization=0, early_stopping=15, criterion=None, show_logo=False,
feat_i=None,
n_kernels=4, w=15, max_w=20, num_epochs=100, loss_thr_pct=0.005, **kwargs, ):
"""
A variation of the main optimization routine that attempts expanding the filter of the model at position i, and refines
the weights and loss in order to find a better convergence.
"""
# verbose print declaration
if verbose:
def vprint(*args, **kwargs):
print(*args, **kwargs)
else:
vprint = lambda *a, **k: None # do-nothing function
n_attempts = 0 # to keep a log of overall attempts
opt_expand_left = range(0, expand_length_max, expand_length_step)
opt_expand_right = range(0, expand_length_max, expand_length_step)
opt_shift = list(range(-shift_max, shift_max + 1, shift_step))
for opt_option_text, opt_option_next in zip(
["WIDTH", "SHIFT"], [[opt_expand_left, opt_expand_right, [0]], [[0], [0], opt_shift]]
):
next_loss = None
loss_diff_pct = 0
best_loss = self.best_loss
while next_loss is None or (next_loss < best_loss and loss_diff_pct > loss_thr_pct):
n_attempts += 1
vprint("\n%s OPTIMIZATION (%s)..." % (opt_option_text, "first" if next_loss is None else "again"),
end="")
vprint("")
curr_w = self.get_kernel_width(i)
if curr_w >= max_w:
if opt_option_text == 'WIDTH':
print("Reached maximum w. Stop...")
break
self = copy.deepcopy(self)
best_loss = self.best_loss
next_color = colors[-(1 if n_attempts % 2 == 0 else -2)]
all_options = []
options = [
[expand_left, expand_right, shift]
for expand_left in opt_option_next[0]
for expand_right in opt_option_next[1]
for shift in opt_option_next[2]
]
if opt_option_text == 'SHIFT' and False: # include shifts to center weights
m = torch.tensor(self.get_kernel_weights(i))
# print(m)
m[m < 0] = 0
m = m.reshape(m.shape[-2:])
col_pos_means = m.mean(axis=0).cpu()
w = int(m.shape[-1] / 2)
# print(w)
col_means = []
for j in range(m.shape[-1]):
a, b = max(j - w, 0), min(j + w, m.shape[-1])
ci = m[:, a:b].mean()
col_means.append(j)
pos_max = torch.argmax(torch.tensor(col_means))
# mb.pl.conv(model)
shift_center = (w - pos_max).cpu()
# print(shift)
# print('adding option', shift_center)
options = [[0, 0, -shift_center], [0, 0, shift_center]] + options
# assert False
vprint('options to try', options)
for expand_left, expand_right, shift in options:
# if abs(expand_left) + abs(expand_right) + abs(shift) == 0:
# continue
# if abs(shift) > 0: # skip shift for now.
# continue
if curr_w + expand_left + expand_right > max_w:
continue
# print(expand_left, expand_right, shift)
# assert False
vprint("next expand left: %i, next expand right: %i, shift: %i"
% (expand_left, expand_right, shift))
model_shift = copy.deepcopy(self)
model_shift.loss_history = []
model_shift.r2_history = []
model_shift.loss_color = []
model_shift.optimize_modified_kernel(
train,
kernel_i=i,
shift=shift,
device=self.device,
expand_left=expand_left,
expand_right=expand_right,
num_epochs=num_epochs if opt_option_text == 'WIDTH' else num_epochs * num_epochs_shift_factor,
early_stopping=early_stopping,
log_each=log_each if opt_option_text == 'WIDTH' else log_each * num_epochs_shift_factor,
# log_each,
update_grad_i=i,
feat_i=feat_i,
lr=lr,
weight_decay=weight_decay,
optimiser=optimiser,
criterion=criterion,
dirichlet_regularization=dirichlet_regularization,
exp_max=exp_max,
verbose=verbose,
**kwargs,
)
vprint('')
model_shift.loss_color += list(np.repeat(next_color, len(model_shift.loss_history)))
# print('history left', len(model_left.loss_history))
weight_mono_i = model_shift.binding_modes.conv_mono[i].weight
pos_w_sum = float(weight_mono_i[weight_mono_i > 0].sum())
loss_diff_pct = (best_loss - model_shift.best_loss) / best_loss * 100
r2 = mb.tl.scores(model_shift, train)['r2_counts']
all_options.append([expand_left, expand_right, shift, model_shift,
pos_w_sum, weight_mono_i.shape[-1], loss_diff_pct, model_shift.best_loss, r2])
# print('\n')
# vprint("after opt.")
if show_logo:
mb.pl.logo(self,
title=False,
xticks=False,
rowspan_dinuc=0,
rowspan_mono=1,
n_rows=5,
n_cols=12,
stop_at=10) # n_cols=len(reduced_groups))
# mb.pl.conv_mono(model_shift)
# mb.pl.conv_di(model_shift, mode='triangle')
# for shift, model_shift, loss in all_shifts:
# print('shift=%i' % shift, 'loss=%.4f' % loss)
weight_ref_mono_i = model_shift.binding_modes.conv_mono[i].weight
pos_w_ref_mono_i_sum = float(weight_ref_mono_i[weight_ref_mono_i > 0].sum())
best_r2 = mb.tl.scores(self, train)['r2_counts']
best = sorted(
all_options + [[0, 0, 0, self,
pos_w_ref_mono_i_sum, weight_ref_mono_i.shape[-1], 0, self.best_loss, best_r2]],
key=lambda x: x[-1],
)
if verbose != 0:
print("filter rearrangments (sorted by observed r2)")
best_df = pd.DataFrame(best, columns=["expand.left", "expand.right", "shift", "model",
'pos_w_sum', 'width', "loss_diff_pct", "loss", 'r2'],
)
best_df['last_loss'] = best_loss
best_df = best_df.sort_values('loss')
vprint(best_df[[c for c in best_df if c != 'model']])
# print('\n history len')
next_expand_left, next_expand_right, next_position, next_model, next_pos_w, w, \
loss_diff_pct, next_loss, next_r2 = best_df.values[0][:-1]
print(next_expand_left, next_expand_right, next_position, next_pos_w, w,
loss_diff_pct, next_loss, next_r2)
if verbose != 0:
print("action (expand left, expand right, shift): (%i, %i, %i)\n" %
(next_expand_left, next_expand_right, next_position))
if loss_diff_pct >= loss_thr_pct:
next_model.loss_history = self.loss_history + next_model.loss_history
next_model.r2_history = self.r2_history + next_model.r2_history
next_model.loss_color = self.loss_color + next_model.loss_color
self = copy.deepcopy(next_model)
if next_expand_left == 0 and next_expand_right == 0 and next_position == 0 and opt_option_text == 'SHIFT':
print('This was the last iteration. Done with filter shift optimization...')
break
return self
def optimize_modified_kernel(self,
                             train,
                             shift=0,
                             expand_left=0,
                             expand_right=0,
                             device=None,
                             num_epochs=500,
                             early_stopping=15,
                             log_each=-1,
                             feat_i='mono',
                             update_grad_i=None,
                             use_dinuc=False,
                             kernel_i=None,
                             lr=0.01,
                             weight_decay=0.001,
                             optimiser=None,
                             dirichlet_regularization=0,
                             exp_max=40,
                             verbose=0,
                             r2_per_epoch=False,
                             **kwargs,
                             ):
    """
    Reshape filter ``kernel_i`` (shift / expand) and retrain it.

    Only filter ``update_grad_i`` stays trainable, and only for the feature type ``feat_i``
    ('mono' or 'dinuc'). A new optimizer is built because the modified weights are new parameters.

    Args:
        train: DataLoader with the training data.
        shift, expand_left, expand_right (int): Passed to ``self.modify_kernel``; the expansions must be >= 0.
        device: Device forwarded to ``self.modify_kernel``.
        optimiser: Optimizer class called as ``optimiser(params, lr=lr)``; Adam with ``weight_decay`` when None.
        use_dinuc: Unused; ``self.use_dinuc`` decides whether dinucleotide gradients are updated.
        **kwargs: Accepted and ignored.

    Returns:
        self, after training.
    """
    assert expand_left >= 0 and expand_right >= 0
    self.modify_kernel(kernel_i, shift, expand_left, expand_right, device)

    # freeze everything except the requested feature type of filter `update_grad_i`
    modes = self.binding_modes
    for kernel_idx in range(len(modes)):
        is_target = kernel_idx == update_grad_i
        modes.update_grad_mono(kernel_idx, is_target and feat_i == 'mono')
        if self.use_dinuc:
            modes.update_grad_di(kernel_idx, is_target and feat_i == 'dinuc')

    # the parameters were replaced, so the optimiser has to be built again
    if optimiser is None:
        fresh_optimiser = topti.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        fresh_optimiser = optimiser(self.parameters(), lr=lr)

    self.optimize_simple(
        train,
        fresh_optimiser,
        num_epochs=num_epochs,
        early_stopping=early_stopping,
        log_each=log_each,
        dirichlet_regularization=dirichlet_regularization,
        exp_max=exp_max,
        verbose=verbose,
        r2_per_epoch=r2_per_epoch,
    )
    return self
class BindingLayer(tnn.Module):
    """
    Implements binding modes (also non-specific binding) for one protein.

    Keyword Args:
        kernels (List[int]): Size of the binding modes (0 indicates non-specific binding). Default: [0, 15]
        init_random (bool): Use a random initialization for all parameters. Default: True
        use_dinuc (bool): Use dinucleotide contributions (not fully implemented for all kind of models). Default: True
        use_conv1d (bool): Use ``Conv1d`` filters (4 / 16 input channels) instead of ``Conv2d`` filters
            (one channel with 4 / 16 rows). Default: True
        dinuc_mode (str): 'local' (one dinucleotide filter per binding mode) or 'full' (one filter per distance
            between the two positions of a dinucleotide; experimental). Default: 'local'
        p_dropout (float): If given, dropout probability used to randomly switch off non-intercept binding modes.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.use_conv1d = kwargs.get("use_conv1d", True)
        self.kernels = kwargs.get("kernels", [0, 15])
        self.init_random = kwargs.get("init_random", True)
        self.use_dinuc = kwargs.get("use_dinuc", True)
        self.dinuc_mode = kwargs.get("dinuc_mode", 'local')  # local or full
        self.conv_mono = tnn.ModuleList()
        self.conv_di = tnn.ModuleList()
        self.ones = None  # cached ones tensor for the intercept output (avoids re-allocating it per batch)
        for k in self.kernels:
            if k == 0:
                # non-specific binding: no filters, handled as an intercept in forward
                self.conv_mono.append(None)
                self.conv_di.append(None)
                continue
            self.conv_mono.append(self._new_mono_conv(k))
            # the conv_di layers are skipped during forward unless use_dinuc is True
            self.conv_di.append(self._new_di_conv(k))
        # regularization parameter (one per binding mode, minus the intercept); not read by forward
        self.prob_act = tnn.Parameter(torch.ones(max(len(self.conv_mono) - 1, 0), dtype=torch.float32))
        if kwargs.get('p_dropout', False):
            self.p_dropout = kwargs.get('p_dropout')
            self.dropout = tnn.Dropout(p=self.p_dropout)
        else:
            self.dropout = None

    def _new_mono_conv(self, k):
        """Create and initialize the mononucleotide filter of width ``k``."""
        if self.use_conv1d:
            conv = tnn.Conv1d(in_channels=4, out_channels=1, kernel_size=k, bias=False)
        else:  # conv2d
            conv = tnn.Conv2d(1, 1, kernel_size=(4, k), padding=(0, 0), bias=False)
        if not self.init_random:
            conv.weight.data.uniform_(0, 0)
        else:
            conv.weight.data.uniform_(.01, .05)  # problem with fitting dinucleotides if (0, 0)
        return conv

    def _new_di_conv(self, k):
        """Create the dinucleotide filter(s) of a binding mode of width ``k`` (None for an unknown dinuc_mode)."""
        if self.dinuc_mode == 'local':
            if self.use_conv1d:
                # NOTE(review): Conv1d uses width k while Conv2d uses k - 1; presumably this matches how `di`
                # is encoded for each variant - verify.
                conv = tnn.Conv1d(in_channels=16, out_channels=1, kernel_size=k, bias=False)
            else:
                # the number of contiguous dinucleotides for k positions is k - 1
                conv = tnn.Conv2d(1, 1, kernel_size=(16, k - 1), padding=(0, 0), bias=False)
            if not self.init_random:
                conv.weight.data.uniform_(-.01, .01)  # problem with fitting dinucleotides if (0, 0)
            return conv
        if self.dinuc_mode == 'full':
            # one Conv2d per distance d = 1 .. k-1 between the two positions; filter d - 1 has width k - d
            return tnn.ModuleList([tnn.Conv2d(1, 1, kernel_size=(16, k - d)) for d in range(1, k)])
        # unknown mode: keep conv_di index-aligned with conv_mono
        return None

    def _intercept(self, mono):
        """Ones vector with one entry per sample (the non-specific binding term)."""
        n_samples = mono.shape[0]
        if self.ones is None or self.ones.shape[0] < n_samples:
            self.ones = torch.ones(n_samples, device=mono.device)  # torch.ones is much faster
        return self.ones[:n_samples].to(mono.device)

    def _full_dinuc_outputs(self, i, mono):
        """Dinucleotide feature maps of binding mode ``i`` when ``dinuc_mode == 'full'``."""
        # nucleotide index per position; the nucleotide axis is the second-to-last one for both
        # (batch, 4, L) Conv1d inputs and (batch, 1, 4, L) Conv2d inputs
        nucleotides = torch.argmax(mono, dim=-2).reshape(mono.shape[0], -1)  # (batch, L)
        outputs = []
        # filter `delta - 1` covers dinucleotides whose positions are `delta` apart
        for delta, conv in enumerate(self.conv_di[i], start=1):
            pair_code = nucleotides[:, :-delta] * 4 + nucleotides[:, delta:]  # (batch, L - delta)
            one_hot = F.one_hot(pair_code, num_classes=16).float()  # (batch, L - delta, 16)
            one_hot = one_hot.permute(0, 2, 1).unsqueeze(1)  # (batch, 1, 16, L - delta)
            outputs.append(conv(one_hot))
        return outputs

    def forward(self, mono, mono_rev, di=None, di_rev=None, **kwargs):
        """
        Compute the binding signal of every binding mode.

        Args:
            mono, mono_rev: one-hot sequences (forward / reverse) in the layout expected by the mono filters
                (4 input channels for Conv1d, one channel with 4 rows for Conv2d).
            di, di_rev: dinucleotide encodings, used only when ``use_dinuc`` and ``dinuc_mode == 'local'``.

        Returns:
            Tensor ``(batch, n_binding_modes)``. Intercept modes (size 0) give a column of ones; the other
            columns are the sum of ``exp`` over every filter position of every feature map.
        """
        bm_pred = []
        for i, k in enumerate(self.kernels):
            if k == 0:
                # intercept (will be scaled by the activity of the non-specific binding)
                bm_pred.append(self._intercept(mono))
                continue
            conv_mono = self.conv_mono[i]
            feature_maps = [conv_mono(mono), conv_mono(mono_rev)]
            if self.use_dinuc and self.dinuc_mode == 'local':
                feature_maps += [self.conv_di[i](di), self.conv_di[i](di_rev)]
            elif self.use_dinuc and self.dinuc_mode == 'full':
                feature_maps += self._full_dinuc_outputs(i, mono)
            # each map has a single output channel, so flattening every map before concatenating gives the
            # same layout as concatenating along the position axis, and also works for mixed 3-D / 4-D maps
            scores = torch.cat([fm.reshape(fm.shape[0], -1) for fm in feature_maps], dim=1)
            # this particular step can generate out of bounds due to the exponential cost
            bm_pred.append(torch.exp(scores).sum(dim=1))
        out = torch.stack(bm_pred).T
        if self.dropout is None:
            return out
        # randomly switch off non-intercept modes (tnn.Dropout is the identity in eval mode);
        # the first column is always kept
        keep = self.dropout(torch.ones(out.shape[-1] - 1, device=mono.device)) > 0
        keep = torch.cat((torch.ones(1, dtype=torch.bool, device=mono.device), keep))
        return out * keep

    def set_seed(self, seed, index, max, min):
        """
        Initialize the mononucleotide filter ``index`` from a DNA sequence.

        The sequence is centered in the filter. At each seed position the row of the matching nucleotide
        (A, C, G, T) gets ``max`` and the other rows ``min``; any other character gives zeros. Filter positions
        outside the seed are filled with ``max``. Works for Conv1d and Conv2d filters (the width is read from
        the current weight, so it also respects earlier ``modify_kernel`` calls).
        """
        conv = self.conv_mono[index]
        width = conv.weight.shape[-1]
        assert len(seed) <= width
        offset = (width - len(seed)) // 2
        rows = {"A": 0, "C": 1, "G": 2, "T": 3}
        seed_params = torch.full((4, width), max, dtype=torch.float32)
        for pos, nucleotide in enumerate(seed):
            col = offset + pos
            row = rows.get(nucleotide)
            if row is None:
                seed_params[:, col] = 0
            else:
                seed_params[:, col] = min
                seed_params[row, col] = max
        # (4, width) -> (1, 4, width) for Conv1d or (1, 1, 4, width) for Conv2d
        new_weight = seed_params.reshape(conv.weight.shape).to(conv.weight.device)
        conv.weight = tnn.Parameter(new_weight)

    def update_grad_mono(self, index, value):
        """Set ``requires_grad`` of mono filter ``index`` (clearing its gradient when freezing)."""
        if self.conv_mono[index] is not None:
            self.conv_mono[index].weight.requires_grad = value
            if not value:
                self.conv_mono[index].weight.grad = None

    def update_grad_di(self, index, value):
        """Set ``requires_grad`` of the dinucleotide filter(s) ``index``; out-of-range indices are ignored."""
        if index >= len(self.conv_di) or self.conv_di[index] is None:
            return
        convs = self.conv_di[index]
        for conv in (convs if isinstance(convs, tnn.ModuleList) else [convs]):
            conv.weight.requires_grad = value
            if not value:
                conv.weight.grad = None

    def update_grad(self, index, value):
        self.update_grad_mono(index, value)
        if self.use_dinuc:
            self.update_grad_di(index, value)

    @staticmethod
    def _zero_columns(weight, n, device):
        """Zeros shaped like ``weight`` except for ``n`` entries along the last (position) axis."""
        return torch.zeros(*weight.shape[:-1], n, device=device, dtype=weight.dtype)

    def _shifted_and_expanded(self, weight, shift, expand_left, expand_right, device):
        """New Parameter: ``weight`` shifted by ``shift`` and zero-padded along its last axis."""
        device = weight.device if device is None else device
        with torch.no_grad():
            w = weight
            if shift >= 1:
                # drop the first `shift` positions and pad zeros on the right
                w = torch.cat([w[..., shift:], self._zero_columns(w, shift, device)], dim=-1)
            elif shift <= -1:
                # pad zeros on the left and drop the last `-shift` positions
                w = torch.cat([self._zero_columns(w, -shift, device), w[..., :shift]], dim=-1)
            # adding more positions left and right
            if expand_left > 0:
                w = torch.cat([self._zero_columns(w, expand_left, device), w], dim=-1)
            if expand_right > 0:
                w = torch.cat([w, self._zero_columns(w, expand_right, device)], dim=-1)
        return tnn.Parameter(w)

    def modify_kernel(self, index=None, shift=0, expand_left=0, expand_right=0, device=None):
        """
        Shift and/or widen the filters of binding mode ``index`` (every mode when None).

        A positive ``shift`` drops the first positions and appends zeros, a negative one prepends zeros and
        drops the last positions; ``expand_left`` / ``expand_right`` add zero-initialized positions. Mono and
        'local' dinucleotide filters are updated for both Conv1d and Conv2d; 'full' dinucleotide filters are
        left untouched. New zeros are created on ``device`` (the weight's own device when None).
        """
        if shift == 0 and expand_left <= 0 and expand_right <= 0:
            return  # nothing to change
        mono_widths = {}  # width of every modified mono filter before the update
        for i, conv in enumerate(self.conv_mono):
            if conv is None or (index is not None and index != i):
                continue
            before_w = conv.weight.shape[-1]
            mono_widths[i] = before_w
            conv.weight = self._shifted_and_expanded(conv.weight, shift, expand_left, expand_right, device)
            expected_w = before_w + max(expand_left, 0) + max(expand_right, 0)
            assert conv.weight.shape[-1] == expected_w, (conv.weight.shape[-1], expected_w)
        for i, conv in enumerate(self.conv_di):
            if conv is None or (index is not None and index != i):
                continue
            if self.dinuc_mode != 'local':
                # NOTE(review): 'full' dinucleotide filters are not resized when the mono filter changes
                continue
            before_w = conv.weight.shape[-1]
            conv.weight = self._shifted_and_expanded(conv.weight, shift, expand_left, expand_right, device)
            if i in mono_widths:
                # the width offset between the dinucleotide and mono filters must be preserved
                diff_after = conv.weight.shape[-1] - self.conv_mono[i].weight.shape[-1]
                assert diff_after == before_w - mono_widths[i]

    def dirichlet_regularization(self):
        """
        Negative sum of the log-softmax of every trainable filter over its (di)nucleotide axis.

        The (di)nucleotide axis is the second-to-last axis of both Conv1d ``(1, C, k)`` and Conv2d
        ``(1, 1, C, k)`` weights, so ``dim=-2`` (the previous ``dim=2`` was the position axis for Conv1d).
        """
        def neg_log_softmax_sum(weight):
            return -torch.sum(weight - torch.logsumexp(weight, dim=-2, keepdim=True))

        out = 0
        for m in self.conv_mono:
            if m is not None and m.weight.requires_grad:
                out = out + neg_log_softmax_sum(m.weight)
        if self.use_dinuc:
            for d in self.conv_di:
                if d is None:
                    continue
                for conv in (d if isinstance(d, tnn.ModuleList) else [d]):
                    if conv.weight.requires_grad:
                        out = out + neg_log_softmax_sum(conv.weight)
        return out

    def get_kernel_width(self, index):
        return self.conv_mono[index].weight.shape[-1] if self.conv_mono[index] is not None else 0

    def get_kernel_weights(self, index, dinucleotide=False):
        values = self.conv_mono if not dinucleotide else self.conv_di
        return values[index].weight if len(values) >= (index + 1) and values[index] is not None else None

    def __len__(self):
        return len(self.conv_mono)
class BindingModesPerProtein(tnn.Module):
    """
    Implements binding modes (also non-specific binding) for multiple proteins in the same batch.

    Args:
        bm_generator (torch.nn.Module): PyTorch module which has a weight matrix as output
    Keyword Args:
        add_intercept (bool): Whether an intercept is used in addition to the predicted binding modes. Default: True
            (the legacy key ``intercept`` is accepted as well).
    """

    def __init__(self, bm_generator, **kwargs):
        super().__init__()
        self.generator = bm_generator
        # the documented key is `add_intercept`; `intercept` was the key previously read, keep honoring it
        self.use_intercept = kwargs.get("add_intercept", kwargs.get("intercept", True))

    def forward(self, mono, mono_rev, di=None, di_rev=None, **kwargs):
        """
        Compute the binding signal of each generated binding mode, per sample.

        ``self.generator(**kwargs)`` must return a list of weight tensors. Each weight's first axis is used as
        the number of groups of a grouped ``conv2d`` over the batch-transposed input, so every sample is scanned
        with its own filter. Returns a ``(batch, n_modes)`` tensor (plus a leading ones column with the intercept).
        """
        weights = self.generator(**kwargs)  # weights needs to be a list
        batch_size = mono.shape[0]
        bm_pred = []
        if self.use_intercept:
            bm_pred.append(torch.ones(batch_size, device=mono.device))
        # move the batch axis into the channel position once (loop invariant)
        mono_t = torch.transpose(mono, 0, 1)
        mono_rev_t = torch.transpose(mono_rev, 0, 1)
        for w in weights:
            w = torch.unsqueeze(w, 1)
            temp = torch.cat(
                (
                    F.conv2d(mono_t, w, groups=w.shape[0]),
                    F.conv2d(mono_rev_t, w, groups=w.shape[0]),
                ),
                dim=3,
            )
            # transpose back so that the batch axis comes first again
            temp = torch.transpose(temp, 0, 1)
            temp = torch.exp(temp)
            temp = temp.reshape(temp.shape[0], -1)
            bm_pred.append(torch.sum(temp, dim=1))
        return torch.stack(bm_pred).T

    def update_grad(self, index, value):
        self.generator.update_grad(index, value)

    def modify_kernel(self, index=None, shift=0, expand_left=0, expand_right=0, device=None):
        self.generator.modify_kernel(index, shift, expand_left, expand_right, device)

    def dirichlet_regularization(self):
        return 0  # could be implemented in the generators and then changed here

    def get_kernel_width(self, index):
        return self.generator.get_kernel_width(index)

    def get_kernel_weights(self, index):
        # previously delegated to get_kernel_width, returning a width instead of the weights
        return self.generator.get_kernel_weights(index)

    def __len__(self):
        return len(self.generator)
class ActivitiesLayer(tnn.Module):
"""
Implements activities with batch effects.
Args:
target_dim: Second dimension of the output of forward
Keyword Args:
n_batches (int): Number of batches that will occur in the data. Default: 1
n_proteins (int): Number of proteins in the dataset. Either n_proteins or n_batches may be used. Default: 1
ignore_kernel (list[bool]): Whether a kernel should be ignored. Default: None.
"""
def __init__(self, n_kernels, target_dim, **kwargs):
super().__init__()
self.n_kernels = n_kernels
# due to having multiple batches in some rounds, the max n_rounds is stored
self.target_dim = max(target_dim) if not isinstance(target_dim, int) else target_dim
self.n_batches = kwargs.get("n_batches", 1) if "n_batches" in kwargs else kwargs.get("n_proteins", 1)
self.ignore_kernel = kwargs.get("ignore_kernel", None)
if self.n_batches != 1:
self.log_activities = tnn.ParameterList()
for i in range(n_kernels):
self.log_activities.append(
tnn.Parameter(torch.zeros([self.n_batches, self.target_dim], dtype=torch.float32))
)
else: # one batch - simple case and faster
self.log_activities = tnn.Parameter(torch.zeros([n_kernels, self.target_dim], dtype=torch.float32))
def forward(self, binding_per_mode, **kwargs):
batch = kwargs.get("batch", None)
if batch is None:
batch = kwargs.get("protein_id", None)
if batch is None:
batch = torch.zeros([binding_per_mode.shape[0]], device=binding_per_mode.device)
# print(scores.shape)
# print(torch.stack(list(self.log_activities), dim=1).shape)
# this is to compare old/new implementation of relevant low-level operations
scores = None
option = 1
if option == 1:
if self.n_batches != 1:
a = torch.exp(torch.stack(list(self.log_activities), dim=1))
# print(b.shape, a.shape, batch.shape)
# print(b.type(), a.type(), batch.type())
result = torch.matmul(b, a[batch, :, :])
scores = result.squeeze(1)
else:
b = binding_per_mode
a = torch.exp(self.log_activities)
# print(b.shape, a.shape, batch.shape)
# print(b.type(), a.type(), batch.type())
result = torch.matmul(b, a)
return result
# print(a)
# print('b')
# print(b)
else:
scores = torch.zeros([binding_per_mode.shape[0], self.target_dim], device=binding_per_mode.device)
for i in range(self.n_batches):
a = torch.exp(torch.stack(list(self.log_activities), dim=1)[i, :, :])
batch_mask = batch == i
b = binding_per_mode[batch_mask]
if self.ignore_kernel is not None:
mask = self.ignore_kernel != 1 # == False
scores[batch_mask] = torch.matmul(b[:, mask], a[mask, :])
else:
scores[batch_mask] = torch.matmul(b, a)
return scores
def update_grad(self, index, value):
    """
    Enable or disable gradient tracking of the activities; disabling also drops the stored gradient.

    Args:
        index (int): Kernel whose activities are toggled (ignored when n_batches == 1, where the single
            activity parameter is toggled as a whole).
        value (bool): New ``requires_grad`` flag.
    """
    target = self.log_activities if self.n_batches == 1 else self.log_activities[index]
    target.requires_grad = value
    if value:
        return
    target.grad = None
def set_ignore_kernel(self, ignore_kernel):
    """Store the kernel mask; entries equal to 1 are skipped by the per-batch (option 2) scoring loop."""
    setattr(self, "ignore_kernel", ignore_kernel)
def get_ignore_kernel(self):
    """Return the kernel mask set via ``set_ignore_kernel`` (or the constructor), possibly None."""
    mask = self.ignore_kernel
    return mask
def get_log_activities(self):
    """
    Return the log-activities as one tensor.

    With a single batch the stored parameter is returned unchanged; otherwise the per-kernel parameters are
    stacked along dim 1.
    """
    if self.n_batches == 1:
        return self.log_activities
    per_kernel = list(self.log_activities)
    return torch.stack(per_kernel, dim=1)
class GraphLayer(tnn.Module):
    """
    Implements the layer that calculates associations between samples and readouts.

    Args:
        n_rounds (int or list[int]): Second dimension of the output of forward. When a list is given, its
            maximum is used.
    Keyword Args:
        enr_series (bool): Whether the data should be handled as enrichment series (cumulative product over
            the second axis). Default: True
        n_batches (int): Number of batches; each batch gets its own row of log-etas. Default: 1
        prepare_knn (bool): If truthy, build the sample graph with ``prepare_knn``, which reads ``adata``,
            ``device``, ``velocity_graph`` and ``knn_free_weights`` from the same keyword arguments.
    """

    def __init__(self, n_rounds, **kwargs):
        super().__init__()
        self.n_rounds = max(n_rounds) if not isinstance(n_rounds, int) else n_rounds
        self.enr_series = kwargs.get("enr_series", True)
        self.n_batches = kwargs.get("n_batches", 1)
        self.log_etas = tnn.Parameter(torch.zeros([self.n_batches, self.n_rounds]))
        if kwargs.get('prepare_knn'):
            self.prepare_knn(**kwargs)
            print('setting up log dynamic')
        self.use_hadamard = False

    @staticmethod
    def _to_dense(matrix):
        """Return ``matrix`` as a dense numpy array (sparse inputs are densified with ``toarray()``)."""
        # ``.A`` is not available on every input (plain numpy arrays, recent SciPy sparse types);
        # ``toarray()`` is the documented densification method.
        return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)

    def prepare_knn(self, **kwargs):
        '''
        Build the graph used during the assay-assay relatedness step.

        A custom graph (e.g. velocity-inferred) can be provided via ``velocity_graph``; otherwise
        ``adata.obsp['connectivities']`` is used, after dropping the variables (columns of ``adata.X``) whose
        counts are zero in every sample.

        Keyword Args:
            adata: AnnData-like object; ``X`` and ``obsp['connectivities']`` are only read when no
                ``velocity_graph`` is given.
            device: Device for the sparse graph. Default: "cuda" when available, otherwise "cpu".
            velocity_graph: Optional square (sparse or dense) matrix used instead of the kNN connectivities.
            knn_free_weights (bool): Learn a separate scaling for the opposite edge direction.
        '''
        adata = kwargs.get('adata')
        device = kwargs.get('device')
        if device is None:
            # fall back to CPU on machines without CUDA instead of unconditionally calling .cuda()
            device = "cuda" if torch.cuda.is_available() else "cpu"

        vel_graph = kwargs.get('velocity_graph')
        if vel_graph is not None:
            print(vel_graph.shape)
            conn = self._to_dense(vel_graph)
        else:
            # variables with zero counts across all samples are excluded before reading the connectivities
            counts = self._to_dense(adata.X.T)
            zero_counts = pd.DataFrame(counts).sum(axis=1) == 0
            conn = self._to_dense(adata[:, ~zero_counts].obsp['connectivities'])

        self.conn_sparse = torch.tensor(conn).to_sparse()
        if str(device) != 'cpu':
            self.conn_sparse = self.conn_sparse.to(device)

        n_edges = self.conn_sparse.indices().shape[1]
        # do not convert to cuda here, otherwise the optimization of these weights will not happen
        self.log_dynamic = tnn.Parameter(torch.rand(n_edges))
        # the opposite direction will always be rescaled into a negative sign; with knn_free_weights this
        # learned factor controls its magnitude
        self.knn_free_weights = kwargs.get('knn_free_weights')
        if self.knn_free_weights:
            self.log_dynamic_scaling = tnn.Parameter(torch.rand(n_edges))

    def _graph_scores(self, binding_scores):
        """Propagate binding scores over the sample graph using the learned (dynamic) edge weights."""
        # b -> (p, c), G -> (c, c), D_all -> (c, c)
        b = binding_scores
        b_T = torch.transpose(b, 0, 1)
        G = self.conn_sparse
        edge_index = G.indices()
        D = self.log_dynamic
        D_tril = torch.sparse_coo_tensor(edge_index, D, G.shape)
        # the opposite edge direction carries the negative sign, optionally rescaled by free weights
        if self.knn_free_weights:
            reverse = -D * torch.exp(self.log_dynamic_scaling)
        else:
            reverse = -D
        D_triu = torch.sparse_coo_tensor(edge_index, reverse, G.shape)
        D_all = D_tril + torch.transpose(D_triu, 0, 1)

        if self.use_hadamard:
            # scores explained by neighboring-based weights
            return b @ torch.exp(D_all.to_dense())

        neighbor_scores = torch.sparse.mm(G, b_T)  # (c, p)
        dynamic_out = torch.sparse.mm(torch.exp(torch.transpose(D_all, 0, 1).to_dense()), neighbor_scores).T
        static_out = torch.sparse.mm(torch.exp(D_all.to_dense()), b_T).T
        return static_out + dynamic_out

    def forward(self, binding_scores, countsum, **kwargs):
        """
        Turn binding scores into per-sample readouts.

        Args:
            binding_scores (torch.Tensor): One row per sample; the second axis must broadcast against the
                n_rounds axis of ``log_etas``.
            countsum (torch.Tensor): One total per sample, used to rescale the row-normalized output.
        Keyword Args:
            batch (torch.Tensor): Batch index per sample (only used when n_batches != 1). Default: all zeros.
            use_conn (bool): Use the sample graph when it exists and enr_series is False. Default: True
            scale_countsum (bool): Normalize each row to sum to one and multiply by countsum. Set to False for
                fluorescence data (e.g. PBM), whose values must not be squeezed into [0, 1]. Default: True
        Returns:
            torch.Tensor: Scaled readouts, one row per sample.
        """
        batch = kwargs.get("batch", None)
        one_batch = self.n_batches == 1
        if batch is None and not one_batch:
            # integer dtype: used below to index the per-batch etas
            batch = torch.zeros([binding_scores.shape[0]], dtype=torch.long, device=binding_scores.device)

        if self.enr_series:
            out = torch.cumprod(binding_scores, dim=1)  # cum product between rounds 0 and N
        elif hasattr(self, 'conn_sparse') and kwargs.get('use_conn', True):
            # in this particular step, we multiply by the dynamic or static scores
            out = self._graph_scores(binding_scores)
        else:
            out = binding_scores

        etas = torch.exp(self.log_etas)
        if one_batch:
            out = torch.mul(out, etas)
        else:  # several batches: one eta row per sample
            out = out * etas[batch.long(), :]

        if not kwargs.get('scale_countsum', True):
            return out
        results = out / out.sum(dim=-1).unsqueeze(-1)
        return results * countsum.unsqueeze(1)

    def get_log_etas(self):
        return self.log_etas

    def update_grad_etas(self, value):
        """Enable/disable gradient tracking of the log-etas (the stored gradient is kept)."""
        self.log_etas.requires_grad = value
def _weight_distances(mono, min_k=5):
    """
    Return the smallest windowed distance between any two binding-mode weights.

    For every pair of modules in ``mono`` and every window width ``k`` in ``[min_k, min_w)`` (``min_w`` being
    the narrower of the two weights), every width-``k`` window of the first weight is compared with every
    width-``k`` window of the second one, both directly and against its reverse complement (positions flipped
    and channels reordered as ``[3, 2, 1, 0]``). Each distance is the summed squared difference divided by ``k``.

    Args:
        mono (Iterable): Modules exposing a 4-D ``weight`` tensor whose last axis is the position axis.
        min_k (int): Smallest window width to compare. Default: 5
    Returns:
        torch.Tensor: The smallest distance found (0-d tensor).
    Raises:
        ValueError: If no comparison was possible (fewer than two modules, or no width in ``[min_k, min_w)``).
    """
    d = []
    for module_a, module_b in itertools.combinations(mono, r=2):
        a = module_a.weight
        b = module_b.weight
        min_w = min(a.shape[-1], b.shape[-1])
        for k in range(min_k, min_w):
            for i in range(a.shape[-1] - k + 1):
                ai = a[:, :, :, i: i + k]
                for j in range(b.shape[-1] - k + 1):
                    bi = b[:, :, :, j: j + k]
                    # reverse complement: flip positions and swap channels 0<->3, 1<->2
                    bi_rev = torch.flip(bi, [3])[:, :, [3, 2, 1, 0], :]
                    d.append(((bi - ai) ** 2).sum() / k)
                    d.append(((bi_rev - ai) ** 2).sum() / k)
    if not d:
        raise ValueError("No weight windows to compare: need two weights wider than min_k=%d" % min_k)
    return min(d)
class MubindFlexibleWeights(tnn.Module):
    """
    Mubind-style model whose binding-mode weights are passed to ``forward`` instead of being stored as parameters.

    Each sample is scored by a constant non-specific term plus the summed exponentiated convolution of its own
    kernel over the forward and reverse sequence; both terms are combined with per-batch activities.

    Args:
        n_rounds (int): Number of rounds; scores have n_rounds + 1 columns.
        n_batches (int): Number of batches, each with its own activities and etas.
        use_dinuc (bool), rho, gamma: Stored on the instance; not read by ``forward``.
        max_w (int): Width of the widest kernel; sequences are padded by max_w - 1 on both sides. Default: 15
        ignore_kernel: Optional mask; entries equal to 1 are excluded from the activity product. Default: None
        init_random (bool): Accepted for interface compatibility; not used.
        enr_series (bool): Treat rounds as a sequential enrichment series (cumulative product). Default: True
        padding_const (float): Value used for padding the sequences. Default: 0.25
        datatype (str): 'selex' or 'pbm' (case-insensitive). Default: 'selex'
    """

    def __init__(
        self,
        n_rounds,
        n_batches,
        use_dinuc=False,
        max_w=15,
        ignore_kernel=None,
        rho=1,
        gamma=0,
        init_random=True,
        enr_series=True,
        padding_const=0.25,
        datatype='selex',  # must be 'selex' or 'pbm' (case-insensitive)
    ):
        super().__init__()
        self.datatype = datatype.lower()
        assert self.datatype in ['selex', 'pbm']
        self.use_dinuc = use_dinuc
        self.n_rounds = n_rounds
        self.n_batches = n_batches
        self.rho = rho
        self.gamma = gamma
        # only keep one padding, equal to the length of the max kernel
        self.padding = tnn.ConstantPad2d((max_w - 1, max_w - 1, 0, 0), padding_const)
        self.log_activities = tnn.ParameterList()
        self.enr_series = enr_series
        self.log_etas = tnn.Parameter(torch.zeros([n_batches, n_rounds + 1]))
        self.ignore_kernel = ignore_kernel
        # two activity sets: non-specific binding and the single flexible kernel
        for _ in range(2):
            self.log_activities.append(tnn.Parameter(torch.zeros([n_batches, n_rounds + 1], dtype=torch.float32)))

    def forward(self, x):
        """
        Args:
            x (tuple): ``(mono, mono_rev, batch, countsum, weight)``. ``weight`` holds one kernel per sample
                (it is used with ``groups=weight.shape[0]`` after moving samples to the channel axis).
        Returns:
            torch.Tensor: For 'pbm' the raw scores (n_samples, n_rounds + 1); for 'selex' the eta-scaled
            predictions normalized per sample and multiplied by ``countsum``.
        """
        mono, mono_rev, batch, countsum, weight = x

        mono = torch.unsqueeze(self.padding(mono), 1)
        mono_rev = torch.unsqueeze(self.padding(mono_rev), 1)

        # constant non-specific binding term, created directly on the input's device
        # (the legacy torch.Tensor(data, device=...) constructor rejects non-CPU devices)
        non_specific = torch.ones(mono.shape[0], device=mono.device)

        # transposing batch dim and channels, so every sample is convolved with its own kernel
        mono = torch.transpose(mono, 0, 1)
        mono_rev = torch.transpose(mono_rev, 0, 1)
        conv = torch.cat(
            (
                F.conv2d(mono, weight, groups=weight.shape[0]),
                F.conv2d(mono_rev, weight, groups=weight.shape[0]),
            ),
            dim=3,
        )
        conv = torch.exp(torch.transpose(conv, 0, 1))  # transposing back
        specific = torch.sum(conv.view(conv.shape[0], -1), dim=1)
        x = torch.stack([non_specific, specific]).T

        # (n_batches, 2, n_rounds + 1)
        activities = torch.exp(torch.stack(list(self.log_activities), dim=1))
        scores = torch.zeros([x.shape[0], self.n_rounds + 1], device=mono.device)
        for i in range(self.n_batches):
            a = activities[i, :, :]
            in_batch = batch == i
            if self.ignore_kernel is not None:
                mask = self.ignore_kernel != 1  # == False
                scores[in_batch] = torch.matmul(x[in_batch][:, mask], a[mask, :])
            else:
                scores[in_batch] = torch.matmul(x[in_batch], a)

        if self.datatype == "pbm":
            return scores

        # sequential enrichment or independent samples
        if self.enr_series:
            predictions = [scores[:, 0]]
            for i in range(1, self.n_rounds + 1):
                predictions.append(predictions[-1] * scores[:, i])
            out = torch.stack(predictions).T
        else:
            out = scores

        for i in range(self.n_batches):
            eta = torch.exp(self.log_etas[i, :])
            out[batch == i] = out[batch == i] * eta

        results = out.T / torch.sum(out, dim=1)
        return (results * countsum).T
# This class can be used to store binding modes for several proteins
class BMCollection(tnn.Module):
    """
    Implements binding modes for multiple proteins at once. Should be used as a generator in combination with
    BindingModesPerProtein.

    Args:
        n_proteins (int): Number of proteins; every non-intercept kernel gets one Conv2d per protein.
    Keyword Args:
        kernels (List[int]): Size of the binding modes (0 indicates non-specific binding, and will be accomplished
            by setting add_intercept to True). Default: [0, 15]
        n_kernels (int): Number of kernels; must match len(kernels) when both are given. When only n_kernels is
            given, kernels defaults to [0] + [15] * (n_kernels - 1).
        init_random (bool): Use a random initialization for all parameters (otherwise zeros). Default: True
    """

    def __init__(self, n_proteins, **kwargs):
        super().__init__()
        self.n_proteins = n_proteins
        if "kernels" not in kwargs and "n_kernels" not in kwargs:
            kwargs["kernels"] = [0, 15]
            kwargs["n_kernels"] = len(kwargs["kernels"])
        elif "n_kernels" not in kwargs:
            kwargs["n_kernels"] = len(kwargs["kernels"])
        elif "kernels" not in kwargs:
            kwargs["kernels"] = [0] + [15] * (kwargs["n_kernels"] - 1)
        else:
            assert len(kwargs["kernels"]) == kwargs["n_kernels"]
        # NOTE: the list is stored as given (not copied); modify_kernel updates the widths in place
        self.kernels = kwargs.get("kernels")
        self.stored_indizes = {}
        self.init_random = kwargs.get("init_random", True)
        # one ModuleList of per-protein convolutions per kernel; None marks the non-specific (width 0) kernel
        self.conv_mono_list = tnn.ModuleList()
        for k in self.kernels:
            if k != 0:
                conv_mono = tnn.ModuleList()
                for _ in range(self.n_proteins):
                    next_mono = tnn.Conv2d(1, 1, kernel_size=(4, k), padding=(0, 0), bias=False)
                    if not self.init_random:
                        next_mono.weight.data.uniform_(0, 0)
                    conv_mono.append(next_mono)
                self.conv_mono_list.append(conv_mono)
            else:
                self.conv_mono_list.append(None)

    def forward(self, protein_id, **kwargs):
        """
        Return, for every non-intercept kernel, the weights of the requested proteins.

        Args:
            protein_id (Iterable[int]): Protein indices, one per output row.
        Returns:
            list[torch.Tensor]: One tensor of shape (len(protein_id), 4, k) per non-intercept kernel.
        """
        output = []
        for idx in range(len(self.kernels)):
            if self.conv_mono_list[idx] is not None:
                # each weight is (1, 1, 4, k); the two squeezes drop the singleton conv channel axes
                kernel = torch.stack([self.conv_mono_list[idx][i].weight for i in protein_id])
                output.append(torch.squeeze(torch.squeeze(kernel, 1), 1))
        return output

    def update_grad(self, index, value):
        """Toggle gradient tracking of kernel ``index`` for all proteins; disabling also clears the gradient."""
        if self.conv_mono_list[index] is not None:
            for i in range(self.n_proteins):
                self.conv_mono_list[index][i].weight.requires_grad = value
                if not value:
                    self.conv_mono_list[index][i].weight.grad = None

    def modify_kernel(self, index=None, shift=0, expand_left=0, expand_right=0, device=None):
        """
        Shift and/or widen kernels in place (for all proteins).

        Args:
            index (int): Kernel to modify; None modifies every non-intercept kernel. Default: None
            shift (int): A positive value drops ``shift`` leading positions and appends zeros on the right; a
                negative value prepends zeros and drops trailing positions.
            expand_left (int): Number of zero positions added on the left. Default: 0
            expand_right (int): Number of zero positions added on the right. Default: 0
            device: Device of the zero padding. Default (None): the device of the weight being modified.
        Raises:
            AssertionError: If a resulting width differs from the old width plus the expansions (e.g. when
                abs(shift) exceeds the kernel width).
        """
        for idx in range(len(self.kernels)):
            convs = self.conv_mono_list[idx]
            if convs is None or (index is not None and index != idx):
                continue
            if expand_left > 0:
                self.kernels[idx] += expand_left
            if expand_right > 0:
                self.kernels[idx] += expand_right
            for m in convs:
                before_w = m.weight.shape[-1]
                # pad on the weight's own device/dtype so torch.cat does not mix devices
                pad_device = device if device is not None else m.weight.device
                pad_dtype = m.weight.dtype
                if shift >= 1:
                    m.weight = torch.nn.Parameter(
                        torch.cat(
                            [
                                m.weight[:, :, :, shift:],
                                torch.zeros(1, 1, 4, shift, device=pad_device, dtype=pad_dtype),
                            ],
                            dim=3,
                        )
                    )
                elif shift <= -1:
                    m.weight = torch.nn.Parameter(
                        torch.cat(
                            [
                                torch.zeros(1, 1, 4, -shift, device=pad_device, dtype=pad_dtype),
                                m.weight[:, :, :, :shift],
                            ],
                            dim=3,
                        )
                    )
                # adding more positions left and right
                if expand_left > 0:
                    m.weight = torch.nn.Parameter(
                        torch.cat(
                            [torch.zeros(1, 1, 4, expand_left, device=pad_device, dtype=pad_dtype), m.weight],
                            dim=3,
                        )
                    )
                if expand_right > 0:
                    m.weight = torch.nn.Parameter(
                        torch.cat(
                            [m.weight, torch.zeros(1, 1, 4, expand_right, device=pad_device, dtype=pad_dtype)],
                            dim=3,
                        )
                    )
                after_w = m.weight.shape[-1]
                assert after_w == before_w + expand_left + expand_right, (
                    "unexpected kernel width %d (expected %d)" % (after_w, before_w + expand_left + expand_right)
                )

    def get_kernel_width(self, index):
        return self.conv_mono_list[index][0].weight.shape[-1] if self.conv_mono_list[index] is not None else 0

    def get_kernel_weights(self, index):
        return self.conv_mono_list[index][0].weight if self.conv_mono_list[index] is not None else None

    def __len__(self):
        return len(self.kernels)
# This class could be used as a bm_generator
class BMPrediction(tnn.Module):
    """
    LSTM-based binding-mode generator: predicts one 4 x 15 weight matrix per input sequence.

    Args:
        num_classes (int): Stored on the instance; not used by ``forward``.
        input_size (int): Feature size of every residue (last axis of the input).
        hidden_size (int): Hidden size of the LSTM.
        num_layers (int): Number of stacked LSTM layers.
        seq_length (int): Stored on the instance; not used by ``forward``.
    """

    def __init__(self, num_classes, input_size, hidden_size, num_layers, seq_length):
        super().__init__()
        self.num_classes = num_classes  # number of classes
        self.num_layers = num_layers  # number of layers
        self.input_size = input_size  # input size
        self.hidden_size = hidden_size  # hidden state
        self.seq_length = seq_length  # sequence length
        self.lstm = tnn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = tnn.Linear(hidden_size, 253)  # fully connected 1
        self.conv_mono = tnn.Linear(253, 60)  # fully connected last layer (60 = 4 * 15)
        self.relu = tnn.ReLU()

    def forward(self, residues, **kwargs):
        """
        Args:
            residues (torch.Tensor): Batch-first input of shape (n_seq, n_residues, input_size).
        Returns:
            list[torch.Tensor]: A single tensor of shape (n_seq, 4, 15).
        """
        h_0 = torch.zeros(self.num_layers, residues.size(0), self.hidden_size, device=residues.device)  # hidden state
        c_0 = torch.zeros(self.num_layers, residues.size(0), self.hidden_size, device=residues.device)  # internal state
        # propagate input through LSTM; hn is (num_layers, n_seq, hidden_size)
        _, (hn, _) = self.lstm(residues, (h_0, c_0))
        # use the last layer only, so the batch size is preserved when num_layers > 1
        out = self.relu(hn[-1])
        out = self.linear(out)  # first Dense
        out = self.relu(out)
        out = self.conv_mono(out)  # final output
        return [out.reshape(residues.shape[0], 4, 15)]
class Decoder(tnn.Module):
    """
    Fully connected decoder mapping a flattened input to an (enc_size, seq_length)-shaped reconstruction.

    Args:
        input_size (int): Number of input features after flattening. Default: 60
        enc_size (int): Size of the encoding axis of the output. Default: 21
        seq_length (int): Length of the sequence axis of the output. Default: 88
    Keyword Args:
        layers (list[int]): Hidden layer widths. Default (also used when None): [200, 500, 1000]
    """

    def __init__(self, input_size=60, enc_size=21, seq_length=88, **kwargs):
        super().__init__()
        hidden = kwargs.get("layers")
        if hidden is None:
            hidden = [200, 500, 1000]
        self.input_size = input_size
        self.enc_size = enc_size
        self.seq_length = seq_length
        self.output_size = enc_size * seq_length
        # Linear, then (ReLU, Linear) for each hidden transition, then ReLU and the output Linear
        stack = [tnn.Linear(input_size, hidden[0])]
        for width_in, width_out in zip(hidden[:-1], hidden[1:]):
            stack += [tnn.ReLU(), tnn.Linear(width_in, width_out)]
        stack += [tnn.ReLU(), tnn.Linear(hidden[-1], self.output_size)]
        self.decoder = tnn.Sequential(*stack)

    def forward(self, x):
        flat = x.reshape(x.shape[0], -1)
        decoded = self.decoder(flat)
        return decoded.reshape(decoded.shape[0], self.enc_size, -1)
# This class should be deleted in the future
class ProteinDNABinding(tnn.Module):
    """
    Predicts binding modes from residues (BMPrediction), reconstructs the residues (Decoder) and scores DNA
    sequences with the predicted modes (MubindFlexibleWeights).
    """

    def __init__(self, n_rounds, n_batches, num_classes=1, input_size=21, hidden_size=2, num_layers=1, seq_length=88,
                 datatype="pbm", **kwargs):
        super().__init__()
        self.datatype = datatype
        self.bm_prediction = BMPrediction(num_classes, input_size, hidden_size, num_layers, seq_length)
        self.decoder = mb.models.Decoder(enc_size=input_size, seq_length=seq_length, **kwargs)
        self.mubind = MubindFlexibleWeights(n_rounds, n_batches, datatype=datatype)
        self.best_model_state = None
        self.best_loss = None
        self.loss_history = []
        self.r2_history = []
        self.crit_history = []
        self.rec_history = []
        self.loss_color = []
        self.total_time = 0

    def forward(self, x):
        """
        Args:
            x (tuple): ``(mono, batch, countsum, residues)`` (the reverse strand is then derived from ``mono``) or
                ``(mono, mono_rev, batch, countsum, residues)``.
        Returns:
            tuple: Flattened predictions and the residue reconstruction.
        Raises:
            ValueError: If ``x`` has neither 4 nor 5 elements.
        """
        if len(x) == 4:
            mono, batch, countsum, residues = x
            mono_rev = mb.tl.mono2revmono(mono)
        elif len(x) == 5:
            mono, mono_rev, batch, countsum, residues = x
        else:
            raise ValueError("expected 4 or 5 inputs, got %d" % len(x))
        weights = self.bm_prediction(residues)
        # BMPrediction returns its weight matrix wrapped in a list (bm_generator interface)
        if isinstance(weights, (list, tuple)):
            weights = weights[0]
        reconstruction = torch.transpose(self.decoder(weights), 1, 2)
        # keep the weights in the autograd graph: wrapping them in a new tnn.Parameter would create a fresh
        # leaf and stop the prediction loss from reaching bm_prediction
        weights = torch.unsqueeze(weights, 1)
        pred = self.mubind((mono, mono_rev, batch, countsum, weights))
        return pred.view(-1), reconstruction

    # expects msa as tensor with dims (n_seq, 21, n_residues)
    def get_predicted_bm(self, msa):
        msa = torch.transpose(msa, 1, 2)
        with torch.no_grad():
            weights = self.bm_prediction(msa)
        return weights
# Multiple datasets
class DinucMulti(tnn.Module):
    """
    Single-kernel model over multiple datasets: the summed exponentiated mononucleotide (and optionally
    dinucleotide) convolution is weighted by a dataset embedding.

    Args:
        use_dinuc (bool): Add the dinucleotide contribution. Default: False
        n_datasets (int): Number of datasets (rows of the embedding). Default: 1
        n_latent (int): Embedding size. Default: 1
        w (int): Kernel width. Default: 8
    """

    def __init__(self, use_dinuc=False, n_datasets=1, n_latent=1, w=8):
        super().__init__()
        self.use_dinuc = use_dinuc
        # Create and initialise weights and biases for the layers.
        self.conv_mono = tnn.Conv2d(1, 1, kernel_size=(4, w), bias=False)
        self.conv_di = tnn.Conv2d(1, 1, kernel_size=(16, w), bias=False)
        self.embedding = tnn.Embedding(n_datasets, n_latent)
        self.best_model_state = None

    def forward(self, x):
        """
        Args:
            x (sequence): ``(mono, di, batch)``. ``batch`` is multiplied with the embedding as
                ``embedding.weight.T @ batch.T``, so it is presumably one-hot of shape (n_samples, n_datasets)
                — TODO confirm against callers.
        Returns:
            torch.Tensor: One score per sample.
        """
        mono, di, batch = x[0], x[1], x[2]
        mono = self.conv_mono(torch.unsqueeze(mono, 1).type(torch.float32))
        x = torch.sum(torch.exp(mono).view(mono.shape[0], -1), dim=1)
        if self.use_dinuc:
            di = self.conv_di(torch.unsqueeze(di, 1).type(torch.float32))
            # the dinucleotide scores are truncated to integers as before, but kept on their own device
            # (torch.LongTensor is a CPU type and moved the tensor off the GPU)
            di = di.long()
            x = x + torch.sum(torch.exp(di).view(di.shape[0], -1), dim=1)
        x = x.view(-1)  # Flatten tensor.
        b = self.embedding.weight.T @ batch.T.type(torch.float32)
        return torch.sum(x.reshape(x.shape[0], 1) * b.T, dim=1)