Pretrain Sequence Encoder for K562

In this tutorial, we will train a sequence encoder to predict chromatin accessibility in K562 cells from DNA sequence. The approach is similar in spirit to ChromBPNet: the idea is to learn, across the whole genome, how variation in DNA sequence relates to accessibility.

Let’s get started!

import warnings
warnings.filterwarnings("ignore")

import numpy as np
from scipy import stats
import torch
import cell2net as cn
import mudata as md
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns

md.set_options(pull_on_update=False)

Load Multiome Data

Let’s first load the multiome single-cell data for K562 cells. This dataset contains paired measurements of gene expression (RNA-seq) and chromatin accessibility (ATAC-seq) from the same cells. The raw sequencing data were obtained from here, and we preprocessed them with Cell Ranger ARC. You can download the MuData object here.

mdata = md.read_h5mu(
    "../../../../results/37_K562_10x_multiome/05_create_mdata/mdata.h5mu"
)

Let’s examine the structure of our multiome data:

mdata
MuData object with n_obs × n_vars = 6508 × 151699
  obs:	'total_counts_rna', 'total_counts_atac', 'total_counts_rna_log', 'total_counts_atac_log'
  2 modalities
    rna:	6508 x 15735
      obs:	'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'Sample', 'TSSEnrichment', 'ReadsInTSS', 'ReadsInPromoter', 'ReadsInBlacklist', 'PromoterRatio', 'PassQC', 'NucleosomeRatio', 'nMultiFrags', 'nMonoFrags', 'nFrags', 'nDiFrags', 'BlacklistRatio', 'DoubletScore', 'DoubletEnrichment', 'ReadsInPeaks', 'FRIP', 'nCount_ATAC', 'nFeature_ATAC', 'percent.mt', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'
      var:	'genes', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
      uns:	'hvg'
      obsm:	'X_pca', 'X_umap'
      layers:	'counts'
    atac:	6508 x 135964
      obs:	'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'Sample', 'TSSEnrichment', 'ReadsInTSS', 'ReadsInPromoter', 'ReadsInBlacklist', 'PromoterRatio', 'PassQC', 'NucleosomeRatio', 'nMultiFrags', 'nMonoFrags', 'nFrags', 'nDiFrags', 'BlacklistRatio', 'DoubletScore', 'DoubletEnrichment', 'ReadsInPeaks', 'FRIP', 'nCount_ATAC', 'nFeature_ATAC', 'percent.mt', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'
      var:	'peaks', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
      obsm:	'X_umap'
      layers:	'counts'

Visualize data

Let’s visualize the RNA and ATAC data to understand the cellular diversity and data quality in our K562 dataset.

RNA UMAP visualization:

sc.pl.umap(mdata["rna"])
[Figure: UMAP of the RNA modality]

ATAC UMAP visualization:

sc.pl.umap(mdata["atac"])
[Figure: UMAP of the ATAC modality]

Prepare Training Data

Now we prepare the data for sequence encoder training by:

  1. Adding a 256 bp window centered on the summit of each accessibility peak

  2. Extracting DNA sequences from the reference genome for each peak region

This creates paired sequence-accessibility data for training the neural network.

cn.pp.add_peaks(mdata, mod_name='atac', peak_len=256)
cn.pp.add_dna_sequence(mdata, ref_fasta='../../../../data/refdata-gex-GRCh38-2020-A/fasta/genome.fa')
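Judging from the start, end, and summit columns in the peak table below (summit 812689 gives start 812561 and end 812817), each window spans peak_len // 2 bases on either side of the summit. The sketch below reproduces that arithmetic and a sequence lookup on a toy in-memory genome. It is an illustration under stated assumptions (0-based, half-open coordinates; a Python dict standing in for the FASTA file), not cell2net's implementation of add_peaks or add_dna_sequence, and the helper names are hypothetical.

```python
# Hypothetical helpers illustrating summit-centered windows; not cell2net's API.

def summit_window(summit: int, peak_len: int = 256) -> tuple[int, int]:
    """Return (start, end) of a window of length peak_len centered on the summit.

    Assumes 0-based, half-open coordinates and an even peak_len.
    """
    half = peak_len // 2
    return summit - half, summit + half


def extract_sequence(genome: dict[str, str], chrom: str, start: int, end: int) -> str:
    """Slice the window out of an in-memory genome (a stand-in for a FASTA file)."""
    return genome[chrom][start:end].upper()


# Reproduce the coordinates of the first row of the peak table.
start, end = summit_window(812689, peak_len=256)
print(start, end, end - start)  # 812561 812817 256

# Toy genome: a 1 kb repetitive chromosome, just to show the slicing.
toy_genome = {"chrToy": "acgt" * 250}
seq = extract_sequence(toy_genome, "chrToy", *summit_window(500, peak_len=8))
print(seq)  # ACGTACGT
```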

Calculate Accessibility Scores

Compute the total accessibility of each peak by summing its counts across all cells, then apply a log1p transform, log(1 + x), to compress the wide dynamic range of these totals.

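df_seq = mdata['atac'].uns['peaks']
# sum raw counts over cells; .sum(axis=0) works directly on the sparse matrix,
# avoiding a dense copy of the full cells x peaks matrix
total_counts = mdata['atac'].layers['counts'].sum(axis=0)
df_seq['acc'] = np.log1p(np.asarray(total_counts).ravel())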
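Summing over cells does not require a dense copy of the cells × peaks matrix: a SciPy sparse matrix sums its columns directly and returns a 1 × n_peaks numpy.matrix, which is flattened before log1p. The toy example below uses made-up counts to show the per-peak totals and the effect of log1p (zero-count peaks map to 0, large totals are compressed). It assumes the counts layer is a SciPy sparse matrix, as is typical for AnnData count layers.

```python
import numpy as np
from scipy import sparse

# Made-up counts: 3 cells (rows) x 4 peaks (columns).
counts = sparse.csr_matrix(
    np.array([
        [0, 2, 0, 5],
        [1, 0, 0, 3],
        [0, 1, 0, 2],
    ])
)

# Column sums computed on the sparse matrix; .sum(axis=0) returns a 1 x n_peaks matrix.
total_per_peak = np.asarray(counts.sum(axis=0)).ravel()
print(total_per_peak)  # [ 1  3  0 10]

# log1p(x) = log(1 + x): zero-count peaks map to 0, large totals are compressed.
acc = np.log1p(total_per_peak)
print(np.round(acc, 3))
```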

Let’s examine our prepared training data:

df_seq.head()
chr start end summit sequence acc
chr1-812439-812939 chr1 812561 812817 812689 GATAAACATGGAAGCAACCCACATGTCCATCAGTGGATGAATAGAT... 5.802118
chr1-817080-817580 chr1 817202 817458 817330 GACAACCTACTGAATGAGAGAAACTATTTGCAAACTATGCATCTGA... 5.961005
chr1-819720-820220 chr1 819842 820098 819970 TTTCATTCTGAGTAGCTTGATGAAGTCTCATGCCGTCCCACTCAGC... 5.552959
chr1-820480-820980 chr1 820602 820858 820730 ACACTACCTGCTTGTCCAGCAGGTCCACACTGTCTACACTACCTGC... 5.442418
chr1-821021-821521 chr1 821143 821399 821271 CAGCTGATCCGCCCTGTCTACACTACCTGCTTGTCGAGCAGATCTG... 5.361292

Create Train/Validation Split

We split the data by chromosome so that no genomic region appears in both the training and validation sets. This reduces data leakage and gives a more realistic estimate of performance on unseen genomic regions, although it cannot remove all sequence similarity between the sets (repeats and paralogous sequences, for example, occur across chromosomes).

# split data into training and validation sets by chromosome
chromosomes = df_seq['chr'].unique().tolist()
np.random.seed(42)
np.random.shuffle(chromosomes)
split_idx = int(len(chromosomes) * 0.8)
train_chromosomes = chromosomes[:split_idx]
valid_chromosomes = chromosomes[split_idx:]

df_train = df_seq[df_seq['chr'].isin(train_chromosomes)]
df_valid = df_seq[df_seq['chr'].isin(valid_chromosomes)]
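Because chromosomes differ greatly in size and peak density, holding out 20% of the chromosomes does not hold out exactly 20% of the peaks. The sketch below wraps the same kind of split in a hypothetical helper, split_by_chromosome, and reports the fraction of peaks that actually lands in the validation set; the toy data frame is made up. It uses NumPy's default_rng instead of the global seed, so its chromosome assignment will differ from the split above.

```python
import numpy as np
import pandas as pd


def split_by_chromosome(df: pd.DataFrame, train_frac: float = 0.8, seed: int = 42):
    """Hypothetical helper: assign whole chromosomes to train or validation.

    Returns (df_train, df_valid, valid_peak_fraction).
    """
    rng = np.random.default_rng(seed)
    chromosomes = df["chr"].unique()  # numpy array of chromosome names
    rng.shuffle(chromosomes)
    split_idx = int(len(chromosomes) * train_frac)
    train_chroms = set(chromosomes[:split_idx])

    in_train = df["chr"].isin(train_chroms)
    df_train, df_valid = df[in_train], df[~in_train]
    return df_train, df_valid, len(df_valid) / len(df)


# Toy peak table: chromosomes with very different numbers of peaks.
toy = pd.DataFrame(
    {"chr": ["chr1"] * 50 + ["chr2"] * 30 + ["chr3"] * 10 + ["chr4"] * 5 + ["chr5"] * 5}
)
tr, va, frac = split_by_chromosome(toy, train_frac=0.8, seed=0)
print(sorted(va["chr"].unique()), round(frac, 2))
```

With five chromosomes, one is held out, but the held-out peak fraction ranges from 5% to 50% depending on which one is drawn.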

Prepare final training datasets with only the essential columns (DNA sequence and accessibility score):

df_train = df_train[['sequence', 'acc']].reset_index(drop=True)
df_valid = df_valid[['sequence', 'acc']].reset_index(drop=True)

Visualize Data Distribution

Let’s examine the distribution of accessibility scores in the training and validation sets to check that the two splits have similar distributions.

Training set distribution:

sns.histplot(df_train['acc'], bins=100)
[Figure: histogram of accessibility scores (acc) in the training set]

Validation set distribution:

sns.histplot(df_valid['acc'], bins=100)
[Figure: histogram of accessibility scores (acc) in the validation set]
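Histograms give a visual check; a two-sample Kolmogorov-Smirnov statistic (scipy.stats.ks_2samp) and a few summary statistics put a number on how similar the two distributions are. The helper below, compare_distributions, is a hypothetical convenience function shown on synthetic data; with the tutorial's data you would pass df_train['acc'] and df_valid['acc'].

```python
import numpy as np
from scipy import stats


def compare_distributions(train_acc, valid_acc) -> dict:
    """Hypothetical helper: summary statistics plus a two-sample KS test."""
    train_acc = np.asarray(train_acc, dtype=float)
    valid_acc = np.asarray(valid_acc, dtype=float)
    ks = stats.ks_2samp(train_acc, valid_acc)
    return {
        "train_mean": train_acc.mean(),
        "valid_mean": valid_acc.mean(),
        "train_median": np.median(train_acc),
        "valid_median": np.median(valid_acc),
        # KS statistic: largest gap between the two empirical CDFs (0 = identical).
        "ks_statistic": ks.statistic,
        "ks_pvalue": ks.pvalue,
    }


# Synthetic example: two samples drawn from the same distribution.
rng = np.random.default_rng(0)
train = rng.normal(loc=4.0, scale=1.0, size=5000)
valid = rng.normal(loc=4.0, scale=1.0, size=1000)
summary = compare_distributions(train, valid)
print({k: round(float(v), 3) for k, v in summary.items()})
```

With tens of thousands of peaks, the p-value will flag even tiny differences, so the KS statistic itself is usually the more informative number.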

Initialize Sequence-to-Accessibility Model

Create a Seq2Acc model with:

  • peak_len=256: Input sequence length in base pairs; this should match the window size passed to add_peaks

  • dropout_rate=0.25: Dropout rate, used as regularization to reduce overfitting

This model uses a convolutional neural network to learn sequence patterns that predict chromatin accessibility.

model = cn.pd.model.Seq2Acc(peak_len=256, dropout_rate=0.25)
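Convolutional sequence models typically consume DNA as a one-hot matrix of shape (sequence length, 4). The sketch below shows that encoding in NumPy as an illustration of the general technique; we are assuming, not documenting, that Seq2Acc encodes its input in a comparable way internally. The function name is hypothetical, and bases other than A, C, G, T (for example N) become all-zero rows, which is one common convention.

```python
import numpy as np

# Column order A, C, G, T is a convention chosen for this sketch.
_BASE_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}


def one_hot_encode(sequence: str) -> np.ndarray:
    """Hypothetical helper: (len(sequence), 4) float32 one-hot matrix.

    Unknown bases such as N are left as all-zero rows.
    """
    encoded = np.zeros((len(sequence), 4), dtype=np.float32)
    for position, base in enumerate(sequence.upper()):
        index = _BASE_TO_INDEX.get(base)
        if index is not None:
            encoded[position, index] = 1.0
    return encoded


x = one_hot_encode("ACGTN")
print(x.shape)  # (5, 4)
print(x)
```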

Train the Sequence Encoder

Train the model with the following parameters:

  • max_epochs=100: Maximum number of training epochs (full passes over the training set)

  • batch_size=512: Number of sequences per mini-batch

  • weight_decay=1e-04: Weight-decay (L2-style) regularization strength

  • lr=3e-04: Learning rate for the optimizer

The model learns to predict accessibility from DNA sequence patterns through backpropagation.

model.train(df_train=df_train, df_valid=df_valid, max_epochs=100, 
            batch_size=512, weight_decay=1e-04, lr=3e-04)
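Two of these parameters translate into simple arithmetic. With batch_size=512, one epoch consists of ceil(n_train / 512) optimizer steps, and a weight-decay coefficient λ adds λ·w to each weight's gradient, which is equivalent to adding (λ/2)·‖w‖² to the loss and shrinks weights toward zero. The sketch below writes both out for plain SGD. Which optimizer model.train uses is not stated here, and adaptive optimizers such as AdamW apply weight decay in a decoupled way, so treat this as an illustration of the idea rather than cell2net's update rule; the function names are hypothetical.

```python
import math

import numpy as np


# Hypothetical helpers; not cell2net's training code.
def batches_per_epoch(n_samples: int, batch_size: int = 512) -> int:
    """Mini-batches (optimizer steps) per epoch, assuming the last partial batch is kept."""
    return math.ceil(n_samples / batch_size)


def sgd_step_with_weight_decay(w, grad, lr=3e-4, weight_decay=1e-4):
    """One plain-SGD step where weight decay adds weight_decay * w to the gradient.

    Equivalent to adding (weight_decay / 2) * ||w||^2 to the loss.
    """
    w = np.asarray(w, dtype=float)
    grad = np.asarray(grad, dtype=float)
    return w - lr * (grad + weight_decay * w)


print(batches_per_epoch(100_000))  # 196
w_new = sgd_step_with_weight_decay([1.0, -2.0], [0.5, 0.0], lr=0.1, weight_decay=0.01)
print(w_new)
```

Note that the second weight has zero gradient but still moves toward zero, which is the regularizing effect of weight decay.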

Save Pretrained Model

Save the trained sequence encoder weights for use in downstream Cell2Net modeling. This pretrained encoder captures fundamental sequence-to-accessibility relationships that can improve Cell2Net performance through transfer learning.

model.save_module('./pretrained_seq2acc.pth')

Load Best Model for Evaluation

Reload the best model weights (based on validation performance) for final evaluation and prediction visualization.

# load best model weights
pretrained_state_dict = torch.load('./pretrained_seq2acc.pth')
model.module.seq_encoder.load_state_dict(pretrained_state_dict)
<All keys matched successfully>
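Keeping the "best" weights usually means tracking the validation loss after every epoch and snapshotting the parameters whenever it improves, often combined with early stopping. The sketch below shows that bookkeeping on a made-up sequence of validation losses. The function name select_best_epoch and the patience rule are illustrative assumptions, not a description of what model.train or save_module do internally.

```python
# Hypothetical helper; not cell2net's training loop.
def select_best_epoch(valid_losses, patience=None):
    """Return (best_epoch, epochs_run).

    best_epoch is the 0-based epoch with the lowest validation loss.
    If patience is set, training stops after `patience` consecutive epochs
    without improvement (early stopping).
    """
    best_epoch, best_loss = None, float("inf")
    epochs_without_improvement = 0
    epochs_run = 0
    for epoch, loss in enumerate(valid_losses):
        epochs_run = epoch + 1
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
            epochs_without_improvement = 0  # a real loop would snapshot the weights here
        else:
            epochs_without_improvement += 1
            if patience is not None and epochs_without_improvement >= patience:
                break
    return best_epoch, epochs_run


losses = [1.00, 0.82, 0.75, 0.77, 0.74, 0.76, 0.79, 0.81]
print(select_best_epoch(losses))              # (4, 8)
print(select_best_epoch(losses, patience=2))  # (4, 7)
```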

Generate Predictions

Use the trained model to predict accessibility scores for both training and validation datasets.

df_train['pred_acc'] = model.predict(df_train)
df_valid['pred_acc'] = model.predict(df_valid)

Evaluate Model Performance

Training Set Performance

Plot the model's predictions against the true accessibility scores on the training set. The Pearson correlation coefficient measures the linear agreement between predicted and observed scores; on the training set it reflects how well the model fits the data it was trained on, not how well it generalizes.

pearsonR, pv = stats.pearsonr(df_train['acc'], df_train['pred_acc'])

g = sns.jointplot(
    data=df_train,
    x="acc",
    y="pred_acc",
    kind="scatter",
    joint_kws={'marker': 'o', 's': 10, 'alpha': 0.1, 'linewidth': 0},
    marginal_kws={'bins': 20, 'element': 'step', 'kde': True, 'linewidth': 0},
)
# overlay a dashed least-squares fit line without re-drawing the points
g.plot_joint(sns.regplot, scatter=False,
             line_kws={"color": "orange", "linestyle": "dashed"})

# use the same range on both axes so true and predicted values are directly comparable
_min = min(df_train['acc'].min(), df_train['pred_acc'].min())
_max = max(df_train['acc'].max(), df_train['pred_acc'].max())
g.ax_joint.set_xlim(_min, _max)
g.ax_joint.set_ylim(_min, _max)
g.ax_joint.set_title('pearsonR = {:.3f}'.format(pearsonR))
g.set_axis_labels('True Accessibility (log1p counts)',
                  'Predicted Accessibility (log1p counts)')
plt.tight_layout()
[Figure: predicted vs. true accessibility (log1p counts) on the training set, titled with the Pearson correlation]
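The Pearson correlation in the plot title is the covariance of true and predicted values divided by the product of their standard deviations. The sketch below computes it directly with NumPy (the helper name pearson_r is ours) and checks it against scipy.stats.pearsonr on made-up values. It also shows that r is unchanged when predictions are shifted and rescaled, so a high r alone does not mean the predictions are on the right scale; that is why the plot uses identical axis ranges.

```python
import numpy as np
from scipy import stats


def pearson_r(x, y) -> float:
    """Hypothetical helper: centered cross-product over the product of the norms."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    xc, yc = x - x.mean(), y - y.mean()
    return float((xc * yc).sum() / np.sqrt((xc ** 2).sum() * (yc ** 2).sum()))


# Made-up values standing in for acc and pred_acc.
true_acc = np.array([2.0, 3.5, 4.1, 5.0, 6.2])
pred_acc = np.array([2.4, 3.1, 4.5, 4.8, 5.9])

r_manual = pearson_r(true_acc, pred_acc)
r_scipy, _ = stats.pearsonr(true_acc, pred_acc)
print(round(r_manual, 6), round(float(r_scipy), 6))

# r ignores affine rescaling: predictions scaled by 10 and shifted by 3 give the same r.
print(round(pearson_r(true_acc, 10 * pred_acc + 3), 6))
```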

Validation Set Performance

Evaluate how well the model generalizes to the held-out chromosomes. Because these regions were never seen during training, validation performance indicates how well the model can predict accessibility for new genomic regions, which is what matters for downstream applications.

pearsonR, pv = stats.pearsonr(df_valid['acc'], df_valid['pred_acc'])

g = sns.jointplot(
    data=df_valid,
    x="acc",
    y="pred_acc",
    kind="scatter",
    joint_kws={'marker': 'o', 's': 10, 'alpha': 0.1, 'linewidth': 0},
    marginal_kws={'bins': 20, 'element': 'step', 'kde': True, 'linewidth': 0},
)
# overlay a dashed least-squares fit line without re-drawing the points
g.plot_joint(sns.regplot, scatter=False,
             line_kws={"color": "orange", "linestyle": "dashed"})

# use the same range on both axes so true and predicted values are directly comparable
_min = min(df_valid['acc'].min(), df_valid['pred_acc'].min())
_max = max(df_valid['acc'].max(), df_valid['pred_acc'].max())
g.ax_joint.set_xlim(_min, _max)
g.ax_joint.set_ylim(_min, _max)
g.ax_joint.set_aspect('equal', adjustable='box')
g.ax_joint.set_title('pearsonR = {:.3f}'.format(pearsonR))
g.set_axis_labels('True Accessibility (log1p counts)',
                  'Predicted Accessibility (log1p counts)')
plt.tight_layout()
[Figure: predicted vs. true accessibility (log1p counts) on the validation set, titled with the Pearson correlation]
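Pearson r measures only linear agreement. Reporting a rank correlation (Spearman) and the mean squared error alongside it, for both splits, makes the gap between training and validation performance easier to read. The helper below, regression_metrics, is a hypothetical convenience function; with the tutorial's data you would call it on the acc and pred_acc columns of df_train and df_valid. The example numbers are made up.

```python
import numpy as np
from scipy import stats


def regression_metrics(y_true, y_pred) -> dict:
    """Hypothetical helper: Pearson r, Spearman rho, and mean squared error."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    pearson, _ = stats.pearsonr(y_true, y_pred)
    spearman, _ = stats.spearmanr(y_true, y_pred)
    mse = float(np.mean((y_true - y_pred) ** 2))
    return {"pearson": float(pearson), "spearman": float(spearman), "mse": mse}


# Made-up values standing in for (acc, pred_acc) on the two splits.
train_metrics = regression_metrics([1.0, 2.0, 3.0, 4.0, 5.0], [1.1, 1.9, 3.2, 3.9, 5.1])
valid_metrics = regression_metrics([1.5, 2.5, 3.5, 4.5], [2.0, 2.2, 3.9, 3.8])
for name, metrics in [("train", train_metrics), ("valid", valid_metrics)]:
    print(name, {k: round(v, 3) for k, v in metrics.items()})
```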

Summary and Next Steps

✅ What we accomplished:

  1. Loaded and preprocessed K562 multiome single-cell data

  2. Prepared training datasets with DNA sequences and accessibility scores

  3. Trained a sequence encoder to predict chromatin accessibility from DNA sequence

  4. Saved pretrained weights for downstream Cell2Net applications

  5. Evaluated model performance on held-out chromosomes

🎯 Key results:

  • The training-set Pearson correlation quantifies how well the model captures sequence-to-accessibility relationships

  • Validation performance on held-out chromosomes measures how well the model generalizes to unseen genomic regions

  • The pretrained encoder is ready for transfer learning in Cell2Net models

💡 Biological insights:

The trained sequence encoder has learned fundamental patterns that link DNA sequence features to chromatin accessibility. This knowledge can now be transferred to Cell2Net models, potentially improving their ability to predict gene expression from regulatory sequences and cell-specific accessibility patterns.

The pretrained encoder serves as a foundation for understanding regulatory grammar and can be fine-tuned for specific cell types or experimental conditions in downstream analyses.