Mouse neurogenesis scRNA-seq (Noack et al. 2022) | Training with an RNA-dynamics kNN-graph

This tutorial uses the mouse neurogenesis scRNA-seq dataset from Noack et al. (2022).

%load_ext autoreload
%autoreload 2
cd ~/workspace/theislab/mubind/docs/notebooks/single_cell
/home/ilibarra/workspace/theislab/mubind/docs/notebooks/single_cell
/home/ilibarra/miniconda3/envs/mubind/lib/python3.9/site-packages/IPython/core/magics/osm.py:393: UserWarning: using bookmarks requires you to install the `pickleshare` library.
  bkms = self.shell.db.get('bookmarks', {})
/home/ilibarra/miniconda3/envs/mubind/lib/python3.9/site-packages/IPython/core/magics/osm.py:417: UserWarning: using dhist requires you to install the `pickleshare` library.
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
import pandas as pd
import numpy as np
import os
import seaborn
import itertools
import glob

import pandas as pd
from pathlib import Path
import bindome as bd
bd.constants.ANNOTATIONS_DIRECTORY = 'annotations'
import mubind as mb
from tqdm.notebook import tqdm_notebook as tqdm
# Motifs (PWMs) used during training as initial kernel weights (pre-weights).
# Alternative motif collection:
# pwms = mb.datasets.cisbp_hs()
pwms = mb.datasets.archetypes()
# Uncomment to work with a small subset of the motifs:
# pwms = pwms[:20]
len(pwms)
# NOTE(review): an earlier note here said "motifs loaded 2179"; that number
# presumably belongs to another motif collection - confirm or remove.
286
import random
# reduced_groups = [p.to_numpy() for p in random.sample(pwms, 20) if p.shape[-1] != 0]
# suppress numba deprecations warnings
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings

warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
import mubind as mb
import numpy as np
import pandas as pd
import torch
# mb.models.Mubind
import torch
import torch
import torch.optim as topti
import torch.utils.data as tdata
import matplotlib.pyplot as plt
import logomaker
import os
import scipy
import pickle


# Train on the first CUDA device when one is visible, otherwise stay on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")
Using device: cuda:0
# device = 'cpu'
print('here...')
import scanpy as sc
print('here...')
here...
here...
# pybiomart is required for this query
annot = sc.queries.biomart_annotations(
    "mmusculus",
    ["ensembl_gene_id", "start_position", "end_position", "chromosome_name"],
).set_index("ensembl_gene_id")
# NOTE(review): `annot` is not used in the cells shown here; the assignment
# that would consume it is commented out.
# rna.var[annot.columns] = annot
# Read the mm10 TSS table (tab-separated, no header row) and index it by gene
# symbol, keeping only the first record of any symbol that occurs repeatedly.
tss_path = os.path.join(bd.constants.ANNOTATIONS_DIRECTORY, 'mm10', 'genome', 'TSS.in')
tss = (
    pd.read_csv(tss_path, sep='\t', header=None)
    .set_axis(['chrom', 'pos', 'strand', 'symbol'], axis=1)
    .set_index('symbol')
)
is_repeated_symbol = tss.index.duplicated(keep='first')
tss = tss.loc[~is_repeated_symbol]
# Load the mouse neurogenesis scRNA-seq AnnData (Noack et al., 2022).
rna = sc.read_h5ad('noack_et_al_velocity_graph.h5ad')
# rna, atac = mb.datasets.pancreas_multiome() # data_directory='../../../annotations/scatac')
# Attach the TSS annotation columns to the genes; pandas aligns them on the
# index (gene symbol), so genes without a TSS record get NaN.
rna.var[tss.columns] = tss
# Keep only genes with a known TSS position. The .copy() turns the AnnData view
# into an independent object, so the .var assignments that follow do not
# operate on a view. No second null filter is needed: after astype(int) the
# 'tss_start' column cannot contain missing values.
rna = rna[:, rna.var['pos'].notna()].copy()
rna.var['tss_start'] = rna.var['pos'].astype(int)
# Build a bp-wide window per gene: it ends at the TSS for genes on the '+'
# strand and starts at the TSS for every other gene. The window is then
# encoded as a 'chrom:start-end' key in rna.var['k'].
bp = 100
tss_start = rna.var['tss_start'].astype(int)
on_plus_strand = rna.var['strand'] == '+'
window_start = np.where(on_plus_strand, tss_start - bp, tss_start)
window_end = np.where(on_plus_strand, tss_start, tss_start + bp)
rna.var['tss_start'] = window_start
rna.var['tss_end'] = window_end

rna.var['chromosome_name'] = rna.var['chrom']
chrom_txt = rna.var['chromosome_name'].astype(str)
start_txt = rna.var['tss_start'].astype(str)
end_txt = rna.var['tss_end'].astype(str)
rna.var['k'] = chrom_txt + ':' + start_txt + '-' + end_txt
rna.var['k']
index
Xkr4           chr1:3671498-3671598
Rgs20          chr1:5018734-5018834
Slco5a1      chr1:12991135-12991235
Jph1         chr1:17097889-17097989
Mcm3         chr1:20820213-20820313
                     ...           
Sh3kbp1    chrX:159708501-159708601
Rai2       chrX:161717453-161717553
Ap1s2      chrX:163908917-163909017
Ddx3y          chrY:1286613-1286713
Eif2s3y        chrY:1010512-1010612
Name: k, Length: 1948, dtype: object
# rna.shape, atac.shape
# atac_path = '../../../annotations/scatac/noack_et_al_velocity_graph.h5ad'
# os.path.exists(atac_path)
# ad = sc.read_h5ad(atac_path)
ad = rna
ad.var_names
Index(['Xkr4', 'Rgs20', 'Slco5a1', 'Jph1', 'Mcm3', 'Ogfrl1', 'Prim2', 'Bag2',
       'Bend6', 'Aff3',
       ...
       'Kif4', 'Tmem35a', 'Pak3', 'Alas2', 'Smc1a', 'Sh3kbp1', 'Rai2', 'Ap1s2',
       'Ddx3y', 'Eif2s3y'],
      dtype='object', name='index', length=1948)
# if rapids is available, use it
# import rapids_singlecell as rapids
# rapids.pp.neighbors(rna)

# if rapids is not available, fall back to scanpy for the neighbors graph
sc.pp.neighbors(rna)
from matplotlib import rcParams
rcParams['figure.figsize'] = 5, 5

# RNA: UMAP coloured by the 'clusters' annotation
sc.pl.umap(rna, color='clusters')

# ATAC (no ATAC object is loaded in the cells shown here)
# sc.pl.umap(atac, color='celltype')
../../_images/8a215a0beb45df997d83fcf0efc8d62a41d45d3065f491b84fd742b4fa1c626d.png
rna
AnnData object with n_obs × n_vars = 7469 × 1948
    obs: 'Clusters', '_X', '_Y', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts', 'louvain', 'clusters', 'velocity_self_transition', 'root_cells', 'end_points', 'velocity_pseudotime', 'latent_time'
    var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand', 'means', 'dispersions', 'dispersions_norm', 'velocity_gamma', 'velocity_r2', 'velocity_genes', 'fit_alpha', 'fit_beta', 'fit_gamma', 'fit_t_', 'fit_scaling', 'fit_std_u', 'fit_std_s', 'fit_likelihood', 'fit_u0', 'fit_s0', 'fit_pval_steady', 'fit_steady_u', 'fit_steady_s', 'fit_alignment_scaling', 'chrom', 'pos', 'strand', 'tss_start', 'tss_end', 'chromosome_name', 'k'
    uns: 'clusters_colors', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'recover_dynamics', 'velocity_graph', 'velocity_graph_neg', 'velocity_settings'
    obsm: 'X_pca', 'X_umap', 'velocity_umap'
    varm: 'PCs', 'loss'
    layers: 'Ms', 'Mu', 'ambiguous', 'fit_t', 'fit_tau', 'fit_tau_', 'matrix', 'spliced', 'unspliced', 'variance_velocity', 'velocity', 'velocity_u'
    obsp: 'distances', 'connectivities'
# scv.pl.velocity_embedding_stream(rna, color='celltype')
acc = (rna.X != 0).sum(axis=0) / rna.shape[0]
acc_score = abs(.5 - acc)
rna.var['acc_score'] = acc_score.A.T
rna.var['acc_score_rank'] = rna.var['acc_score'].rank(ascending=False)
plt.hist(rna.var['acc_score'], bins=35)
# var_sample = atac.var.sort_values('acc_score_rank').index[:n_sample_peaks]
(array([ 41.,  50.,  34.,  39.,  49.,  47.,  49.,  54.,  45.,  46.,  50.,
         46.,  62.,  52.,  46.,  50.,  49.,  48.,  45.,  62.,  53.,  56.,
         54.,  51.,  58.,  41.,  55.,  45.,  60.,  47.,  38.,  47.,  86.,
         91., 202.]),
 array([0.00060249, 0.01487099, 0.02913949, 0.04340799, 0.05767649,
        0.07194499, 0.08621349, 0.10048199, 0.11475049, 0.12901899,
        0.14328749, 0.15755599, 0.17182449, 0.18609299, 0.20036149,
        0.21462999, 0.22889849, 0.243167  , 0.2574355 , 0.271704  ,
        0.2859725 , 0.300241  , 0.3145095 , 0.328778  , 0.3430465 ,
        0.357315  , 0.3715835 , 0.385852  , 0.4001205 , 0.414389  ,
        0.4286575 , 0.442926  , 0.4571945 , 0.471463  , 0.4857315 ,
        0.5       ]),
 <BarContainer object of 35 artists>)
../../_images/6656f51ef92f90d260030857353ef00b5d05e9ace25fa1ca47c03e736fe693d4.png
# n_sample_cells = 750
# n_sample_peaks = 1500
n_cell_types = len(set(rna.obs['clusters']))
n_cell_types
9
rna.shape
(7469, 1948)
rna.obs['clusters'].value_counts()
clusters
N1       2192
N2       1614
NSC      1233
IPC       979
NSC_M     853
IPC_M     337
IN        198
CR         43
HC         20
Name: count, dtype: int64
rna.shape
(7469, 1948)
 # int(n_sample_cells * 1.5)
# Number of draws per cluster (here: the total number of cells) and number of
# genes to keep (here: all of them).
n_select = rna.shape[0] # 50
n_sample_cells = n_select
n_sample_peaks = rna.shape[1]
# n_select =  int(n_sample_cells / n_cell_types) # bug if > 20
# Cells are drawn per cluster WITH replacement and repeated draws are dropped
# afterwards, so fewer distinct cells than the cluster holds may be retained.
obs_sample = rna.obs.groupby('clusters').sample(n_select, random_state=n_sample_cells, replace=True).index.drop_duplicates()
# Drawing all n_sample_peaks names only shuffles them; the isin() masks below
# select in rna's original order, so the shuffle has no effect on `ad`.
var_sample = pd.Series(rna.var_names).sample(n_sample_peaks, random_state=n_sample_peaks)
# var_sample = atac.var.sort_values('acc_score_rank').index[:n_sample_peaks]

# Independent copy restricted to the sampled cells and genes.
ad = rna[rna.obs_names.isin(obs_sample),rna.var_names.isin(var_sample)].copy()
ad.shape
(7359, 1948)
ad.obs['clusters'].value_counts()
clusters
N1       2098
N2       1602
NSC      1229
IPC       979
NSC_M     853
IPC_M     337
IN        198
CR         43
HC         20
Name: count, dtype: int64
# ad.uns['velocity_graph'] = rna_sample.uns['velocity_graph']
# # ad.layers['velocity'] = rna_sample.layers['velocity']
# ad.uns['velocity_graph'].shape
# View of rna restricted to the cells that were retained in `ad`.
rna_sample = rna[ad.obs_names]

# One flag per cell of rna: True when the cell belongs to the sample.
mask = rna.obs_names.isin(rna_sample.obs_names)

# Restrict the velocity graph to the sampled cells on both axes: filter the
# rows, transpose so the former columns can be filtered as rows as well, then
# transpose back to restore the orientation.
keep = mask == 1
vgraph = rna.uns['velocity_graph'][keep].T[keep].T
vgraph.shape
(7359, 7359)
# pip install scvelo
# scvelo is only needed for the velocity stream visualisation below; it is not
# a priority for the rest of the tutorial.
import scvelo as scv
# Attach the velocity graph restricted to the sampled cells.
rna_sample.uns['velocity_graph'] = vgraph
try:
    # UMAP of the sampled cells
    sc.pl.umap(ad)

    # velocity stream of the sampled cells (drawn once; the call used to be
    # duplicated and produced the same figure twice)
    scv.pl.velocity_embedding_stream(rna_sample)
except Exception as err:
    # Plotting is best-effort, but the actual error is reported so that causes
    # other than a small sample are not hidden.
    print(f'sample too small to plot graph ({type(err).__name__}: {err})')
../../_images/2685e0a2e53f96d06b35dd6973bb543613c5d82efe2c50bbfb8f73bde0b3bae1.png ../../_images/93d40f479929bef74f2451b7f4ade09fbab48eff44b80617feb2fa218db68d9e.png ../../_images/93d40f479929bef74f2451b7f4ade09fbab48eff44b80617feb2fa218db68d9e.png
# ad.var['chr'] = 'chr' + ad.var_names.str.split('-').str[0]
# ad.var['start'] = ad.var_names.str.split('-').str[1].astype(int)
# ad.var['end'] = ad.var_names.str.split('-').str[2].astype(int)

# print('features by chromosome')
# extend = 50
# ad.var['chr'].value_counts()
# ad.var['summit'] = ((ad.var['end'] + ad.var['start']) / 2).astype(int)
# ad.var['summit.start'] = ad.var['summit'] - extend
# ad.var['summit.end'] = ad.var['summit'] + extend
# ad.var['k.summit'] = ad.var['chr'] + ':' + ad.var['summit.start'].astype(str) + '-' + ad.var['summit.end'].astype(str)
# Region columns (chromosome, start, end) built from the TSS windows in ad.var.
# NOTE(review): 'Chromosome' is prefixed with 'chr' here, while the 'chrom'
# column already holds names such as 'chr1' - confirm 'Chromosome' is unprefixed.
ad.var['chr'] = 'chr' + ad.var['Chromosome'].astype(str)
ad.var['summit.start'] = ad.var['tss_start'].astype(str)
ad.var['summit.end'] = ad.var['tss_end'].astype(str)
# Region key in 'chr:start-end' form.
ad.var['k.summit'] = ad.var['chr'] + ':' + ad.var['summit.start'].astype(str) + '-' + ad.var['summit.end'].astype(str)
ad.shape
(7359, 1948)
!ls /home/ilibarra/miniconda3/envs/mubind/bin/bedtools
/home/ilibarra/miniconda3/envs/mubind/bin/bedtools
# Retrieve the genomic sequence of (at most n_seqs) regions through bedtools.
n_seqs = 10000 # adata.shape[1] # None # 1000
bed_regions = ad.var[['chr', 'summit.start', 'summit.end']].head(n_seqs)
seqs = mb.bindome.tl.get_sequences_from_bed(
    bed_regions,
    genome='mm10',
    uppercase=True,
    gen_path='annotations/mm10/genome/mm10.fa',
    bin='/home/ilibarra/miniconda3/envs/mubind/bin/bedtools',
)

# Keep only the genes whose region key is among the returned entries.
keys = {s[0] for s in seqs}
has_sequence = ad.var['k.summit'].isin(keys)
ad = ad[:, has_sequence]
len(seqs)
/tmp/tmpflh7ti9k
genome mm10 True
annotations/mm10/genome/mm10.fa
True annotations/mm10/genome/mm10.fa
running bedtools...
['/home/ilibarra/miniconda3/envs/mubind/bin/bedtools', 'getfasta', '-fi', 'annotations/mm10/genome/mm10.fa', '-bed', '/tmp/tmpflh7ti9k', '-fo', '/tmp/tmpflqgv7wa']
/home/ilibarra/miniconda3/envs/mubind/bin/bedtools getfasta -fi annotations/mm10/genome/mm10.fa -bed /tmp/tmpflh7ti9k -fo /tmp/tmpflqgv7wa
1948
# Make sure that no retrieved sequence contains an ambiguous base (N).
# Each entry of seqs is a (key, sequence) pair, so the sequence s[1] has to be
# inspected; `'N' in s` only compared 'N' against the two elements of the pair
# and therefore never triggered.
seqs_with_n = [s[0] for s in seqs if 'N' in s[1]]
assert not seqs_with_n, f'{len(seqs_with_n)} sequences contain N, e.g. {seqs_with_n[:3]}'
# seqs = [[s[0], s[1].replace('N', '')] for s in seqs]
counts = ad.X.T
counts.shape
(1948, 7359)
# Dense sequence-by-cell table. `.A` on SciPy sparse matrices is deprecated and
# absent from recent SciPy releases, so toarray() is used; an input that is
# already dense is accepted as well.
dense_counts = counts.toarray() if hasattr(counts, 'toarray') else np.asarray(counts)
next_data = pd.DataFrame(dense_counts)
# Variance of every row across the columns, kept as an extra 'var' column.
next_data['var'] = next_data.var(axis=1)
# next_data = next_data[range(100)].copy()
# Index the rows by the retrieved sequence (second element of each seqs entry).
next_data.index = [s[1] for s in seqs]
next_data.index.name = 'seq'
next_data.shape
(1948, 7360)
next_data = next_data[~next_data.index.str.contains('N')]
next_data.shape
(1948, 7360)
# sum_index = next_data[next_data.columns[:-1]].var(axis=1).sort_values(ascending=False).index
# NOTE: despite its name, n_cells caps the number of rows (sequences) kept here.
n_cells = 10000
# Row labels ordered by decreasing 'var', truncated to the first n_cells.
top_var = next_data[['var']].sort_values('var', ascending=False).index[:n_cells]
# next_data.index
# next_data = next_data.head(10000)
# NOTE(review): this selects by label, so it presumably requires the sequences
# in the index to be unique - confirm for other inputs.
next_data_sel = next_data.reindex(top_var) # .reset_index(drop=True)
# The helper column is dropped so that only the value columns remain.
del next_data_sel['var']
# next_data_sel.index = next_data_sel['seq']
# del next_data_sel['seq']
df = next_data_sel.copy() # sample
# df = df[df.columns[:5000]] # .head(100) # sample

# shorten sequences/remove duplicates
# df.index = df.index.astype(str).str[35:-35]
# df = df[~df.index.duplicated(keep='first')]

# Rows whose values sum to zero are removed.
zero_counts = df.sum(axis=1) == 0
df = df[~zero_counts] # remove zeroes

df.shape
(1948, 7359)
ad.shape, df.shape
((7359, 1948), (1948, 7359))
# NOTE(review): zero_counts is indexed like df, whose rows were reordered by
# variance, whereas ad keeps its original variable order; confirm that this
# mask lines up with ad's variables before relying on it to drop genes.
ad = ad[:,~zero_counts].copy()
# Long format: one row per (sequence, column of df) pair, indexed by sequence.
df2 = df.reset_index().melt('seq').set_index('seq')
# add a baseline count
df2[0] = 1.0
# The three columns (melted column label, melted value, baseline) are renamed
# positionally to 'batch', 1 and 0.
cols = ['batch', 1, 0]
df2.columns = cols
df2 = df2[[0, 1, 'batch']] # cols[::-1]]
df2['batch'] = df2['batch'].astype(int)
print(df2.shape)
print(len(set(df2['batch'])))
# NOTE: n_cells is reused here as the number of batches kept (ids 0..n_cells-1).
n_cells = 200
df2 = df2[df2['batch'].isin(range(0, n_cells))]
print(df2['batch'].value_counts())
df2.head()
(14335332, 3)
7359
batch
0      1948
137    1948
127    1948
128    1948
129    1948
       ... 
69     1948
70     1948
71     1948
72     1948
199    1948
Name: count, Length: 200, dtype: int64
0 1 batch
seq
CCATTGGCCGCCGAGTGTGGGAACGGCCGCGGCGCCCGGACCCCAGGCGCCGGGCCGCTGCCCGCGCCTATATAGGGCTGGCGTGCTGGGGTCCAGGTCG 1.0 2.807686 0
TTGCACGCCTTTAACTGATCCGCGGCGCTGAAGGCGGAGAGACCAGAAGAGGAGTGGCCACAAAGATTGCAATTCTGACATCTTATTGGACCTTTAGGTC 1.0 0.000000 0
CCACGGCAGCCTGCTGGAAAGAGACAAGTTAGTCATTCAAATGATTGGCTTTTTGCCCGCTTCTTCTCTAAATAAGAAGGCAGCAGCTTCTGCTGAGGTG 1.0 4.635506 0
GATAGGGAAAAAATCTGCACCAGCCTTTTACATACAGTAGGTGTGTTTCTCTTCTAGCTGCTCGGGTCTCACTGGCAATCTGTAGAAAGAGCTGTGTTAG 1.0 1.004364 0
AGAAGTCGGATCTGCTTTCGGGGGCGTGTCCATCCTAGCGGAGCGCTCCAACCCTTCCCGGACCCAAGAGGGAGTTTATTGCGCACGCGCGGGGAGGAGC 1.0 4.091436 0

Simple motif enrichment

# Optional step (disabled by default): read the '8mers' and 'PWM' sheets of
# every .xlsx workbook in the glossary directory, keyed by module name.
use_kmers = False
if use_kmers:
    kmers_by_module = {}
    pwm_by_module = {}
    glossary_kmers_dir = '../../../../../zaugglab/apobec2_data_analysis/src/analyses_2023_new_sample/Glossary Modules'
    for f in os.listdir(glossary_kmers_dir):
        if f.endswith('.xlsx'):
            workbook_path = os.path.join(glossary_kmers_dir, f)
            # module name = file name without its fixed prefix and extension
            name = f.replace('Glossary_module_', '').replace('.xlsx', '')
            kmers = pd.read_excel(workbook_path, sheet_name='8mers')
            first_column = kmers.columns[0]
            kmers_by_module[name] = kmers[first_column]
            pwm = pd.read_excel(workbook_path, sheet_name='PWM', index_col=0)
            pwm_by_module[name] = pwm

Generate pos and neg

# import ushuffle

# fg = df.index
# bg = []

# for i, s in enumerate(fg):
#     # print(i, s)
#     bg.append(ushuffle.shuffle(str.encode(s), 2))
# for a, b in zip(fg, bg):
#     # print(len(a[1]), len(b))
#     assert len(a) == len(b)
# Fraction of rows (sequences) with a non-zero value in each column of df.
# The per-column count is normalised by the number of rows (df.shape[0]);
# dividing by df.shape[1] mixed the two axes and did not yield a fraction.
# The result keeps one entry per column, as before.
frac_acc = (df > 0).sum(axis=0) / df.shape[0]
reduced_groups = pwms
df
0 1 2 3 4 5 6 7 8 9 ... 7349 7350 7351 7352 7353 7354 7355 7356 7357 7358
seq
CCATTGGCCGCCGAGTGTGGGAACGGCCGCGGCGCCCGGACCCCAGGCGCCGGGCCGCTGCCCGCGCCTATATAGGGCTGGCGTGCTGGGGTCCAGGTCG 2.807686 2.804581 1.063736 0.533166 0.000000 1.137600 4.033162 1.340032 3.597629 0.451104 ... 2.569399 4.006182 0.661377 0.000000 2.204438 0.000000 3.098087 3.075634 4.121038 0.000000
TTGCACGCCTTTAACTGATCCGCGGCGCTGAAGGCGGAGAGACCAGAAGAGGAGTGGCCACAAAGATTGCAATTCTGACATCTTATTGGACCTTTAGGTC 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 2.977316 2.622818 2.505485 2.629052 2.508885 2.023399 3.225960 3.322917 3.054366 2.640159
CCACGGCAGCCTGCTGGAAAGAGACAAGTTAGTCATTCAAATGATTGGCTTTTTGCCCGCTTCTTCTCTAAATAAGAAGGCAGCAGCTTCTGCTGAGGTG 4.635506 3.794510 4.562876 2.168771 0.794902 0.000000 2.475682 1.057597 1.693586 0.760848 ... 4.142088 3.331074 0.000000 0.000000 1.764575 0.000000 1.391833 1.768496 3.676070 1.515053
GATAGGGAAAAAATCTGCACCAGCCTTTTACATACAGTAGGTGTGTTTCTCTTCTAGCTGCTCGGGTCTCACTGGCAATCTGTAGAAAGAGCTGTGTTAG 1.004364 0.000000 0.490046 3.128247 2.478890 2.354867 0.000000 2.961377 0.000000 3.216548 ... 0.000000 0.000000 2.934479 3.501090 2.930838 3.130262 2.681020 2.821474 0.000000 2.865914
AGAAGTCGGATCTGCTTTCGGGGGCGTGTCCATCCTAGCGGAGCGCTCCAACCCTTCCCGGACCCAAGAGGGAGTTTATTGCGCACGCGCGGGGAGGAGC 4.091436 3.851065 3.559601 0.533166 1.767620 1.137600 3.653754 1.654172 2.290299 0.760848 ... 3.402821 3.774636 0.661377 0.000000 2.204438 1.681944 2.146579 1.768496 3.518010 1.213856
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
CCCACGGCTCAGGAGAAGGTTTGCCAAAGGGGAGAAGAATGATCTTCAGTTTCTGACCTCCTGGCACCGAGCGAGCGAGTGAGAGAGCGGTGTTTATCAG 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
CTGCCCTTCCTCCCCCCTTTTCTCCCTCTCTGCTCCACTAGCCCCCTCCCTTCTTGTCTCCCTCTCCGCCCCGTCCCCTCCCTGTCGGCCCGCCCGCCCA 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
GCTCGGTGCGTGCCAGGCCGGGCGGCGCCGCGGGCGGCGGAGGGAGCAGAGCCCTCCCCCCGGGAGCTGCGCGGCCAAGGTGGAAATGCGGAAGTTTCCG 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
CTCTGATGGTGGCAACTCAGCAGAAAGCTGCTGGGCTCAGTCTGGCTTTGTTGAGCAACCCTGACTCCACCCCTTTTCTTCCCCACAAAGCAAGCTTTTA 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
GAGCCGAGCTGAGCAGCAGGCCGACACGAGCTGCCAGCCAATCAGAGACGCGCTCGGGAGCTGCAGAGACCCAGCCTGGGAGCCTGGGGCTGGAGGTGCG 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000

1948 rows × 7359 columns

# Wrap df as a dataset; n_rounds is set to the number of columns of df.
dataset = mb.datasets.SelexDataset(df, n_rounds=df.shape[1], enr_series=False)
# Number of distinct values in dataset.batch.
n_batch = len(set(dataset.batch))
# One kernel per motif in pwms.
n_kernels = len(pwms)
# Shuffled mini-batches of 256 dataset entries.
train = tdata.DataLoader(dataset=dataset, batch_size=256, shuffle=True)
# df_neg = df.copy()
# df_neg.index = bg
# df_neg.index = df_neg.index.astype(str)

# dataset_neg = mb.datasets.SelexDataset(df_neg, n_rounds=df_neg.shape[1], enr_series=False)
# train_neg = tdata.DataLoader(dataset=dataset_neg, batch_size=512, shuffle=True)

The parameter `optimize_log_dynamic` defines whether a kNN graph is used by the model.

len(pwms)
286
# for unit tests
# torch.save(train, '../../../tests/_data/pancreas_multiome.pth')
# ad.write('../../../tests/_data/pancreas_multiome.h5ad', compression='lzf')
rna_sample.shape, ad.shape
((7359, 1948), (7359, 1948))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
cuda:0
# Build a model once, with the velocity graph, to check that the configuration
# is accepted; 'no err' is printed when construction succeeds.
criterion = mb.tl.PoissonLoss()
# Kernel widths: the number of columns of each motif matrix in pwms.
w = [r.shape[1] for r in pwms]
optimize_log_dynamic = True
model = mb.models.Mubind.make_model(train, n_kernels, criterion, kernels=[0, 2,] + w, # [0, 2] + w,
                                        # use_dinuc=True, dinuc_mode='full',
                                        optimize_sym_weight=False,
                                        optimize_exp_barrier=True,
                                        optimize_prob_act=True,
                                        optimize_log_dynamic=optimize_log_dynamic,
                                        use_dinuc=False,
                                        device=device,
                                        p_dropout=0.8,
                                        prepare_knn=optimize_log_dynamic,
                                        velocity_graph=vgraph,
                                        knn_free_weights=False,
                                        adata=None if not optimize_log_dynamic else ad,
                                        dinuc_mode=None)
print('no err')
(7359, 7359)
setting up log dynamic
no err
model.graph_module.conn_sparse.shape
torch.Size([7359, 7359])
import torch.optim as topti
import warnings
import torch.nn as tnn

# Train one model with the graph-based log-dynamic term
# (optimize_log_dynamic=True) and one without it (False); both are stored in
# model_by_logdynamic, keyed by that flag.
model_by_logdynamic = {}
n_epochs_intercept = 10
n_epochs_kernel = 20
log_each=75
n_unfreeze_kernels=70

# Loop-invariant setup: silence warnings once, compute the kernel widths once.
warnings.filterwarnings("ignore")
w = [r.shape[1] for r in pwms]

for optimize_log_dynamic in [True, False]:
    criterion = mb.tl.PoissonLoss()
    # Pass the velocity graph whenever the kNN graph is prepared, exactly as in
    # the standalone make_model call of this notebook. The loop used to omit
    # it, so the trained models never received the velocity graph. The run
    # without the graph is left unchanged.
    graph_kwargs = {'velocity_graph': vgraph} if optimize_log_dynamic else {}
    model = mb.models.Mubind.make_model(train, n_kernels, criterion, kernels=[0, 2] + w,
                                        optimize_sym_weight=False,
                                        optimize_exp_barrier=True,
                                        optimize_prob_act=True,
                                        optimize_log_dynamic=optimize_log_dynamic,
                                        use_dinuc=False,
                                        device=device,
                                        p_dropout=0.8,
                                        prepare_knn=optimize_log_dynamic,
                                        knn_free_weights=False,
                                        adata=None if not optimize_log_dynamic else ad,
                                        dinuc_mode=None,
                                        **graph_kwargs)

    # Initialise the mono kernels with the motif matrices. The offset of 2
    # skips the first two entries of `kernels` ([0, 2]) that precede the motifs.
    for i, mono_best in enumerate(pwms):
        mono_best = mono_best.to_numpy()
        if mono_best.shape[-1] == 0:
            # empty motif: keep the kernel as created by make_model
            continue
        new_w = mono_best.reshape([1, 1] + list(mono_best.shape))
        model.binding_modes.conv_mono[i + 2].weight = torch.nn.Parameter(torch.tensor(new_w, dtype=torch.float).squeeze(1))
    # The replaced parameters are created on the CPU; move the model a final
    # time to the target device.
    model = model.to(device)

    # Logos of the initial kernels (stops after the first ten).
    mb.pl.set_rcParams({'figure.figsize': [15, 3], 'figure.dpi': 90})
    mb.pl.logo(model,
               title=False,
               xticks=False,
               rowspan_dinuc=0,
               rowspan_mono=1,
               n_rows=5,
               n_cols=12,
               stop_at=10)

    mb.pl.set_rcParams({'figure.figsize': [20, 5], 'figure.dpi': 100})

    model.graph_module.use_hadamard = True
    # Epoch schedule: n_epochs_intercept for the first entry, n_epochs_kernel
    # for each of the remaining n_kernels + 1 entries.
    model, best_loss = model.optimize_iterative(train, n_epochs=[n_epochs_intercept] + [n_epochs_kernel] * (n_kernels + 1),
                                                show_logo=False, use_mono=True, use_dinuc=False, dinuc_mode='local',
                                                opt_kernel_shift=[0, 0] + [0] * (n_kernels),
                                                opt_kernel_length=[0, 0] + [0] * (n_kernels),
                                                opt_one_step=True,
                                                shift_max=1, shift_step=1,
                                                n_unfreeze_kernels=n_unfreeze_kernels,
                                                optimiser=topti.Adam,
                                                n_batches=1, n_rounds=2, num_epochs_shift_factor=1,
                                                kernels = [0] + [2] + [20] * (n_kernels), r2_per_epoch=True,
                                                exp_max=8,
                                                early_stopping=250, log_each=log_each, w=20, max_w=20)

    model_by_logdynamic[optimize_log_dynamic] = model
setting up log dynamic
True
12
order None
0 out of 288...
break
../../_images/d2d44401c89f166e9c66dba09d4fb222932a46193435a3a80b5fd054f9e85371.png
verbose=2
current kernel 0

### next filter to optimize 0 (intercept)

FREEZING KERNELS
optimizing feature type mono
next kernels 0-69, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 10 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:08<00:00, 12.84s/it]
Current time: 2024-07-08 02:07:21.212463
	Loss: 462.783 , R2: 0.564
	Training time (model/function): (128.447s / 128.447s)
		per epoch (model/function): (14.272s/ 14.272s)
		per 1k samples: 7.326s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.9989549262135297, pvalue=0.0))
optimization of dinuc is not valid for the intercept (filter=0). Skip...
current kernel 70

### next filter to optimize 70 

FREEZING KERNELS
optimizing feature type mono
next kernels 70-139, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [09:07<00:00, 27.40s/it]
Current time: 2024-07-08 02:16:33.104925
	Loss: 396.374 , R2: 0.564
	Training time (model/function): (676.464s / 548.017s)
		per epoch (model/function): (35.603s/ 28.843s)
		per 1k samples: 14.806s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.9989549262135297, pvalue=0.0))


final refinement step (after shift)...unfreezing all layers
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [09:41<00:00, 29.08s/it]
Current time: 2024-07-08 02:26:18.761646
	Loss: 151.898 , R2: 0.564
	Training time (model/function): (1258.246s / 581.782s)
		per epoch (model/function): (66.223s/ 30.620s)
		per 1k samples: 15.719s
best loss 151.898
last five r2 values, by sequential filter optimization: ['0.564']
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.9989549262135297, pvalue=0.0))
optimizing feature type dinuc
the optimization of dinucleotide features is skipped...
current kernel 140

### next filter to optimize 140 

FREEZING KERNELS
optimizing feature type mono
next kernels 140-209, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [11:18<00:00, 33.94s/it]
Current time: 2024-07-08 02:37:45.806687
	Loss: 374.997 , R2: 0.564
	Training time (model/function): (1937.147s / 678.900s)
		per epoch (model/function): (101.955s/ 35.732s)
		per 1k samples: 18.343s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.9989549262135297, pvalue=0.0))


final refinement step (after shift)...unfreezing all layers
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [09:23<00:00, 28.19s/it]
Current time: 2024-07-08 02:47:14.220486
	Loss: 145.567 , R2: 0.564
	Training time (model/function): (2501.125s / 563.978s)
		per epoch (model/function): (131.638s/ 29.683s)
		per 1k samples: 15.238s
best loss 145.567
last five r2 values, by sequential filter optimization: ['0.564', '0.564']
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.9989549262135297, pvalue=0.0))
optimizing feature type dinuc
the optimization of dinucleotide features is skipped...
current kernel 210

### next filter to optimize 210 

FREEZING KERNELS
optimizing feature type mono
next kernels 210-279, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [08:28<00:00, 25.43s/it]
Current time: 2024-07-08 02:55:50.651272
	Loss: 373.124 , R2: 0.564
	Training time (model/function): (3009.795s / 508.670s)
		per epoch (model/function): (158.410s/ 26.772s)
		per 1k samples: 13.743s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.9989549262135297, pvalue=0.0))


final refinement step (after shift)...unfreezing all layers
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [09:19<00:00, 28.00s/it]
Current time: 2024-07-08 03:05:14.580285
	Loss: 145.011 , R2: 0.564
	Training time (model/function): (3569.923s / 560.128s)
		per epoch (model/function): (187.891s/ 29.480s)
		per 1k samples: 15.134s
best loss 145.011
last five r2 values, by sequential filter optimization: ['0.564', '0.564', '0.564']
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.9989549262135297, pvalue=0.0))
optimizing feature type dinuc
the optimization of dinucleotide features is skipped...
current kernel 280

### next filter to optimize 280 

FREEZING KERNELS
optimizing feature type mono
next kernels 280-349, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [10:11<00:00, 30.59s/it]
Current time: 2024-07-08 03:15:34.393812
	Loss: 372.959 , R2: 0.564
	Training time (model/function): (4181.818s / 611.895s)
		per epoch (model/function): (220.096s/ 32.205s)
		per 1k samples: 16.532s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.9989549262135297, pvalue=0.0))


final refinement step (after shift)...unfreezing all layers
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [09:38<00:00, 28.94s/it]
Current time: 2024-07-08 03:25:17.124641
	Loss: 144.963 , R2: 0.564
	Training time (model/function): (4760.813s / 578.995s)
		per epoch (model/function): (250.569s/ 30.473s)
		per 1k samples: 15.643s
best loss 144.962
last five r2 values, by sequential filter optimization: ['0.564', '0.564', '0.564', '0.564']
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.9989549262135297, pvalue=0.0))
optimizing feature type dinuc
the optimization of dinucleotide features is skipped...

optimization finished:
total time: 4760.813280582428s
Time per epoch (total): 0.828 s
True
12
order None
0 out of 288...
break
../../_images/e2585e916b97b0a74b230faac6fe0b79a914c6b2f90677afd31488f44d7d75a6.png
verbose=2
current kernel 0

### next filter to optimize 0 (intercept)

FREEZING KERNELS
optimizing feature type mono
next kernels 0-69, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 10 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.93it/s]
Current time: 2024-07-08 03:25:29.620020
	Loss: 451.237 , R2: 0.564
	Training time (model/function): (3.554s / 3.554s)
		per epoch (model/function): (0.395s/ 0.395s)
		per 1k samples: 0.203s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.996821185007101, pvalue=0.0))
optimization of dinuc is not valid for the intercept (filter=0). Skip...
current kernel 70

### next filter to optimize 70 

FREEZING KERNELS
optimizing feature type mono
next kernels 70-139, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.75it/s]
Current time: 2024-07-08 03:25:40.951607
	Loss: 396.375 , R2: 0.565
	Training time (model/function): (10.948s / 7.394s)
		per epoch (model/function): (0.576s/ 0.389s)
		per 1k samples: 0.200s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.996821185007101, pvalue=0.0))


final refinement step (after shift)...unfreezing all layers
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.55it/s]
Current time: 2024-07-08 03:25:52.831560
	Loss: 151.900 , R2: 0.566
	Training time (model/function): (18.924s / 7.976s)
		per epoch (model/function): (0.996s/ 0.420s)
		per 1k samples: 0.215s
best loss 151.900
last five r2 values, by sequential filter optimization: ['0.565']
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.996821185007101, pvalue=0.0))
optimizing feature type dinuc
the optimization of dinucleotide features is skipped...
current kernel 140

### next filter to optimize 140 

FREEZING KERNELS
optimizing feature type mono
next kernels 140-209, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.73it/s]
Current time: 2024-07-08 03:26:07.778576
	Loss: 377.339 , R2: 0.566
	Training time (model/function): (26.455s / 7.531s)
		per epoch (model/function): (1.392s/ 0.396s)
		per 1k samples: 0.203s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.996821185007101, pvalue=0.0))


final refinement step (after shift)...unfreezing all layers
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.69it/s]
Current time: 2024-07-08 03:26:19.169246
	Loss: 145.566 , R2: 0.566
	Training time (model/function): (34.045s / 7.590s)
		per epoch (model/function): (1.792s/ 0.399s)
		per 1k samples: 0.205s
best loss 145.566
last five r2 values, by sequential filter optimization: ['0.565', '0.565']
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.996821185007101, pvalue=0.0))
optimizing feature type dinuc
the optimization of dinucleotide features is skipped...
current kernel 210

### next filter to optimize 210 

FREEZING KERNELS
optimizing feature type mono
next kernels 210-279, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.80it/s]
Current time: 2024-07-08 03:26:33.566108
	Loss: 375.578 , R2: 0.566
	Training time (model/function): (41.309s / 7.264s)
		per epoch (model/function): (2.174s/ 0.382s)
		per 1k samples: 0.196s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.996821185007101, pvalue=0.0))


final refinement step (after shift)...unfreezing all layers
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.70it/s]
Current time: 2024-07-08 03:26:45.033025
	Loss: 145.014 , R2: 0.565
	Training time (model/function): (48.845s / 7.536s)
		per epoch (model/function): (2.571s/ 0.397s)
		per 1k samples: 0.204s
best loss 145.014
last five r2 values, by sequential filter optimization: ['0.565', '0.565', '0.565']
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.996821185007101, pvalue=0.0))
optimizing feature type dinuc
the optimization of dinucleotide features is skipped...
current kernel 280

### next filter to optimize 280 

FREEZING KERNELS
optimizing feature type mono
next kernels 280-349, n=70
filters mask None
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.88it/s]
Current time: 2024-07-08 03:26:59.547257
	Loss: 375.403 , R2: 0.563
	Training time (model/function): (55.980s / 7.134s)
		per epoch (model/function): (2.946s/ 0.375s)
		per 1k samples: 0.193s
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.996821185007101, pvalue=0.0))


final refinement step (after shift)...unfreezing all layers
optimizer:  Adam , criterion: PoissonLoss 
epochs: 20 
early_stopping: 250
lr= 0.01, weight_decay= 0, dir weight= 0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.85it/s]
Current time: 2024-07-08 03:27:10.379136
	Loss: 144.965 , R2: 0.567
	Training time (model/function): (63.099s / 7.119s)
		per epoch (model/function): (3.321s/ 0.375s)
		per 1k samples: 0.192s
best loss 144.962
last five r2 values, by sequential filter optimization: ['0.565', '0.565', '0.565', '0.566']
('etas corr with lib_sizes (before refinement)', SignificanceResult(statistic=0.996821185007101, pvalue=0.0))
optimizing feature type dinuc
the optimization of dinucleotide features is skipped...

optimization finished:
total time: 63.09895467758179s
Time per epoch (total): 0.011 s
# Sum `train.dataset.rounds` over axis 0 and flatten the result to 1-D.
# NOTE(review): presumably these are per-cell library sizes (the training log
# reports a correlation of "etas" with lib_sizes) — confirm the axis
# orientation of `rounds` against the dataset class.
lib_sizes = train.dataset.rounds.sum(axis=0).flatten()
# Show the value of the model's `device` attribute.
model.device
device(type='cuda', index=0)
# Save one entry of `model_by_logdynamic` per `use_logdynamic` setting.
# The filename encodes the flag and the two dimensions of `ad`.
for use_logdynamic in (False, True):
    p = 'noack_2022_use_logdynamic_%i_obs%i_var%i.pth' % (
        use_logdynamic,  # %i renders the boolean as 0 / 1
        ad.shape[0],
        ad.shape[1],
    )
    print(p)
    torch.save(obj=model_by_logdynamic[use_logdynamic], f=p)
noack_2022_use_logdynamic_0_obs7359_var1948.pth
noack_2022_use_logdynamic_1_obs7359_var1948.pth
# Show the dimensions of `ad`; the same two numbers are embedded in the output filenames.
ad.shape
(7359, 1948)
# Build the output filename once so the written path and the printed path
# always agree, then write `ad` to disk and echo where it went.
h5ad_path = 'noack_2022_sample_train_obs%i_var%i.h5ad' % (ad.shape[0], ad.shape[1])
ad.write(h5ad_path)
print(h5ad_path)
noack_2022_sample_train_obs7359_var1948.h5ad
import pickle

# Serialize `train` (called a dataloader in the filename) next to the other
# artifacts; the filename encodes the two dimensions of `ad`.
# A context manager is used so the file handle is flushed and closed as soon
# as the dump finishes; the previous bare `open(...)` was never closed
# explicitly, which can leave a truncated file behind in a long-lived kernel.
# NOTE: pickle files execute arbitrary code on load — only unpickle this file
# from a trusted location.
with open('train_dataloader_obs%i_var%i.pkl' % (ad.shape[0], ad.shape[1]), 'wb') as fh:
    pickle.dump(train, fh)