Pretrain Sequence Encoder for PBMC#

This tutorial demonstrates how to pretrain a sequence encoder for predicting chromatin accessibility from DNA sequences using Peripheral Blood Mononuclear Cells (PBMC) data. This approach creates a foundational model that can be used for downstream Cell2Net applications across immune cell types.

import warnings
warnings.filterwarnings("ignore")

import numpy as np
from scipy import stats
import torch
import cell2net as cn
import mudata as md
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns

md.set_options(pull_on_update=False)

Load PBMC Multiome Data#

Import the preprocessed PBMC multiome dataset containing paired RNA-seq and ATAC-seq measurements from the same cells. This dataset represents the cellular diversity of the immune system. You can download the MuData object here.

mdata = md.read_h5mu("./mdata.h5mu")

Examine the structure of the loaded MuData object:

mdata
MuData object with n_obs × n_vars = 11131 × 131516
  obs:	'cell_type', 'cell_type_v2', 'total_counts_rna', 'total_counts_atac', 'total_counts_rna_log', 'total_counts_atac_log'
  2 modalities
    rna:	11131 x 15932
      obs:	'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'cell_type_v2'
      var:	'genes', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
      uns:	'cell_type_colors', 'cell_type_v2_colors', 'hvg'
      obsm:	'X_umap'
      layers:	'counts'
    atac:	11131 x 115584
      obs:	'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'cell_type_v2'
      var:	'peaks', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
      uns:	'cell_type_colors', 'cell_type_v2_colors'
      obsm:	'X_umap'
      layers:	'counts'

Visualize PBMC Cell Type Diversity#

Explore the cellular composition and quality of our PBMC dataset using UMAP visualizations. This helps us understand the immune cell populations present and their transcriptional and epigenetic relationships.

RNA Expression Landscape#

Visualize the transcriptional diversity across PBMC cell types. Each point represents a single cell, colored by its annotated cell type.

sc.pl.umap(mdata["rna"], color="cell_type_v2")
../../_images/dca34d4ba311f3a47d9304b0d3443ac28b85249ffc7e1f85f31a8010a036ad07.png

Chromatin Accessibility Landscape#

Examine the epigenetic diversity across the same PBMC cell types. The ATAC-seq data reveals cell-type-specific regulatory elements and chromatin states.

sc.pl.umap(mdata["atac"], color="cell_type_v2")
../../_images/226ceadebbc297e81717c79a50c34a7818d55cf97505d6ada8c8909ef5ba1f40.png

Prepare Sequence Training Data#

Transform the ATAC-seq peaks into training data for sequence modeling:

  1. Extract peak regions: Define 256bp windows around accessibility peaks

  2. Retrieve DNA sequences: Get reference genome sequences for each peak

  3. Create sequence-accessibility pairs: Link DNA sequence with observed chromatin accessibility

This creates the fundamental training data pairing genomic sequence with regulatory activity.

cn.pp.add_peaks(mdata, mod_name='atac', peak_len=256)
cn.pp.add_dna_sequence(mdata, ref_fasta='../../../../data/refdata-gex-GRCh38-2020-A/fasta/genome.fa')
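To make the windowing step concrete, here is a minimal sketch of how a fixed 256bp window can be centered on a peak summit, and how the retrieved sequence might be one-hot encoded for a neural network. Both `summit_window` and `one_hot_encode` are hypothetical helpers written for illustration; they are not part of the Cell2Net API, and the encoding Cell2Net uses internally is an assumption here.

```python
import numpy as np

def summit_window(summit: int, peak_len: int = 256) -> tuple[int, int]:
    """Return (start, end) of a fixed-length window centered on a peak summit."""
    half = peak_len // 2
    return summit - half, summit + half

def one_hot_encode(sequence: str) -> np.ndarray:
    """Encode a DNA string as a (length, 4) array in A, C, G, T order.

    Bases outside ACGT (for example N) are left as all-zero rows.
    """
    index = {"A": 0, "C": 1, "G": 2, "T": 3}
    encoded = np.zeros((len(sequence), 4), dtype=np.float32)
    for position, base in enumerate(sequence.upper()):
        if base in index:
            encoded[position, index[base]] = 1.0
    return encoded

# A summit at 817344 gives a 256bp window spanning 128bp on each side.
start, end = summit_window(817344)
print(start, end, end - start)

# Illustrative 5-base input; the N produces an all-zero row.
print(one_hot_encode("ACGTN"))
```

Using a fixed window length means every training example has the same shape, regardless of how wide the original called peak was.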

Calculate Aggregated Accessibility Scores#

Compute total accessibility for each genomic region by summing counts across all PBMC cells, then apply a log1p transformation to compress the wide, right-skewed range of count totals. Because counts are pooled over every cell, the resulting score reflects the overall accessibility of each sequence across all cell types rather than in any single one.

df_seq = mdata['atac'].uns['peaks']
# sum raw counts over cells (rows) to get one total per peak; summing the
# sparse matrix directly avoids building a dense cells x peaks array in memory
total_counts = mdata['atac'].layers['counts'].sum(axis=0)
df_seq['acc'] = np.log1p(np.asarray(total_counts).ravel())

Examine the prepared sequence-accessibility training data:

df_seq.head()
                    chr   start   end     summit  sequence                                           acc
chr1-817094-817594  chr1  817216  817472  817344  TGAGAGAAACTATTTGCAAACTATGCATCTGACAAAGGCCTAATAT...  7.178545
chr1-821831-822331  chr1  821953  822209  822081  TACCCTGTCTACACTACCTACCTGCCCAGCAGATCCACCCTGTCTT...  5.645447
chr1-826616-827116  chr1  826738  826994  826866  TTGTTCAAATTCCTCAGACCCATCCACTTCCCATCTAAACTTCCCT...  6.137727
chr1-827302-827802  chr1  827424  827680  827552  ACGCCGCAGGGAGGGGACTGCGTGGCTGGGTTTGGCCACAAAAAGC...  8.764834
chr1-842067-842567  chr1  842189  842445  842317  GAAGCGCACGCACCAGCCGTTCCACCCACCGAAGTTGTTTGGTCCC...  5.153292
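The aggregation step can be checked on a small example. The sketch below uses a made-up 3 cells × 4 peaks sparse matrix (the values are invented) to show what the column sum and log1p transformation produce, including the behavior for a peak with zero counts.

```python
import numpy as np
from scipy import sparse

# Toy cells x peaks count matrix (3 cells, 4 peaks); values are made up.
counts = sparse.csr_matrix(np.array([
    [0, 2, 0, 1],
    [1, 0, 0, 3],
    [0, 5, 0, 0],
]))

# Column sums give one total per peak; no dense copy of the matrix is needed.
total_per_peak = np.asarray(counts.sum(axis=0)).ravel()

# log1p(x) = log(1 + x) keeps zero-count peaks at 0 and compresses large totals.
acc = np.log1p(total_per_peak)

print(total_per_peak)
print(acc.round(3))
```

Since log1p is inverted by `np.expm1`, an `acc` score can always be converted back to the pooled count it came from.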

Create Chromosome-Based Data Split#

Split the peaks by chromosome so that evaluation is not inflated by sequence-similarity leakage. This approach:

  • Prevents leakage: No peak, and no overlapping or neighboring window, appears in both training and validation

  • Tests generalization: Model must predict accessibility on completely unseen chromosomes

  • Mimics real applications: Evaluates performance on novel genomic regions

Training chromosomes (about 80% of chromosomes): Used for model learning
Validation chromosomes (about 20% of chromosomes): Used for performance assessment

# split data into training and validation sets by chromosome
chromosomes = df_seq['chr'].unique().tolist()
np.random.seed(42)
np.random.shuffle(chromosomes)
split_idx = int(len(chromosomes) * 0.8)
train_chromosomes = chromosomes[:split_idx]
valid_chromosomes = chromosomes[split_idx:]

print("Training chromosomes:", train_chromosomes)
print("Validation chromosomes:", valid_chromosomes)
Training chromosomes: ['chr16', 'chr10', 'chr1', 'chr9', 'chr18', 'chr13', 'chr2', 'chr14', 'chr6', 'chr3', 'chr12', 'chr21', 'chr4', 'chr5', 'chr19', 'chr17', 'chr22', 'chrX']
Validation chromosomes: ['chr8', 'chr11', 'chr15', 'chr20', 'chr7']

Extract essential columns (DNA sequence and accessibility score) for model training:

df_train = df_seq[df_seq['chr'].isin(train_chromosomes)]
df_valid = df_seq[df_seq['chr'].isin(valid_chromosomes)]

df_train = df_train[['sequence', 'acc']].reset_index(drop=True)
df_valid = df_valid[['sequence', 'acc']].reset_index(drop=True)
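Because the split is made over chromosomes rather than peaks, the share of peaks in each set depends on how many peaks each chromosome carries. The sketch below, on an invented toy table, shows one way to confirm that the two sets are disjoint and to report the actual fraction of peaks used for training. The table contents and chromosome assignment are illustrative only.

```python
import pandas as pd

# Toy peak table; chromosome labels and scores are invented for illustration.
df_toy = pd.DataFrame({
    "chr": ["chr1", "chr1", "chr1", "chr2", "chr2", "chr3"],
    "acc": [7.2, 5.6, 6.1, 8.8, 5.2, 4.9],
})
train_chromosomes = ["chr1", "chr3"]
valid_chromosomes = ["chr2"]

train_mask = df_toy["chr"].isin(train_chromosomes)
valid_mask = df_toy["chr"].isin(valid_chromosomes)

# No peak may land in both sets, and every peak should land in one of them.
assert not (train_mask & valid_mask).any()
assert (train_mask | valid_mask).all()

# Fraction of peaks (not chromosomes) that ends up in the training set.
train_fraction = train_mask.mean()
print(f"train peaks: {train_mask.sum()}, valid peaks: {valid_mask.sum()}")
print(f"train fraction: {train_fraction:.2f}")
```

In this toy case two of three chromosomes are used for training, and they happen to hold four of the six peaks.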

Examine Data Distributions#

Visualize the distribution of accessibility scores to check that the training and validation sets cover a similar range and shape.

Training Set Distribution:#

sns.histplot(df_train['acc'], bins=100)
../../_images/59b5cc09f71fe0db1bc63445d2ca503fe303848fa67c494634908d5393ab19d6.png

Validation Set Distribution:#

sns.histplot(df_valid['acc'], bins=100)
../../_images/5b4d1bfc6ddb4f073b21881e6d33c187920a8e650b9b615b6ca7286ae7c6bec5.png
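The visual comparison of the two histograms can be complemented with a few numbers. The sketch below is one possible check, not part of the original workflow: it prints quartiles for each set and the two-sample Kolmogorov-Smirnov statistic from `scipy.stats.ks_2samp`. The arrays are simulated stand-ins for `df_train['acc']` and `df_valid['acc']`.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
# Stand-ins for df_train['acc'] and df_valid['acc']; values are simulated.
acc_train = rng.normal(loc=6.0, scale=1.0, size=5000)
acc_valid = rng.normal(loc=6.0, scale=1.0, size=1200)

for name, values in [("train", acc_train), ("valid", acc_valid)]:
    q25, q50, q75 = np.percentile(values, [25, 50, 75])
    print(f"{name}: n={values.size} q25={q25:.2f} median={q50:.2f} q75={q75:.2f}")

# Two-sample Kolmogorov-Smirnov statistic: the largest gap between the two
# empirical cumulative distributions (0 means identical, 1 means disjoint).
result = stats.ks_2samp(acc_train, acc_valid)
print(f"KS statistic = {result.statistic:.3f}")
```

With tens of thousands of peaks the KS p-value becomes very small even for negligible differences, so the statistic itself is usually the more informative number.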

Initialize Seq2Acc Model#

Create a sequence-to-accessibility model with the following settings:

  • peak_len=256: Input DNA sequences of 256 base pairs

  • dropout_rate=0.25: Regularization to prevent overfitting across diverse cell types

The Seq2Acc model uses convolutional neural networks to learn sequence motifs and patterns that predict chromatin accessibility across the immune system.

model = cn.pd.model.Seq2Acc(peak_len=256, dropout_rate=0.25)
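To build intuition for how a convolutional layer can detect sequence motifs, the sketch below slides a single hand-written filter along a one-hot encoded sequence using NumPy only. This is not the Seq2Acc architecture: the helper names, the example motif, and the filter weights are all chosen for illustration, and in a trained network the filter weights would be learned from data.

```python
import numpy as np

BASES = "ACGT"

def one_hot(sequence: str) -> np.ndarray:
    """Encode a DNA string (A/C/G/T only) as a (length, 4) array."""
    encoded = np.zeros((len(sequence), 4), dtype=np.float32)
    for position, base in enumerate(sequence):
        encoded[position, BASES.index(base)] = 1.0
    return encoded

def conv1d_valid(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Slide a (width, 4) kernel along a (length, 4) input without padding."""
    width = kernel.shape[0]
    n_positions = x.shape[0] - width + 1
    return np.array([np.sum(x[i:i + width] * kernel) for i in range(n_positions)])

# A hand-written filter that responds maximally to the motif "GATA".
# In a trained network these weights would be learned rather than set by hand.
kernel = one_hot("GATA")

sequence = "TTGATAAGC"
activations = conv1d_valid(one_hot(sequence), kernel)
best = int(np.argmax(activations))
print(activations)
print(best, sequence[best:best + 4])
```

Each activation counts how many bases match the motif at that offset, so the peak activation marks where the motif occurs. A convolutional layer applies many such filters in parallel and passes their activations to later layers.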

Train the Sequence Encoder#

Train the model to learn sequence-to-accessibility relationships using optimized hyperparameters:

  • max_epochs=100: Upper limit on the number of training epochs; early stopping can end training sooner

  • batch_size=512: Large batches for stable gradient estimates across diverse PBMC sequences

  • weight_decay=1e-04: L2 regularization to prevent overfitting

  • lr=3e-04: Learning rate for the optimizer

The model learns to recognize DNA sequence patterns (transcription factor binding sites, GC content, repeats) that predict chromatin accessibility across the diverse PBMC cellular landscape.

model.train(df_train=df_train, df_valid=df_valid, max_epochs=100, 
            batch_size=512, weight_decay=1e-04, lr=3e-04)
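The interaction between `max_epochs` and early stopping can be illustrated with a small patience-based rule. This is a generic sketch, not Cell2Net's implementation: the function name, the `patience` and `min_delta` parameters, and the choice of validation loss as the monitored quantity are all assumptions made for illustration.

```python
def early_stopping_epoch(valid_losses, patience=10, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses.

    Training stops once the loss has failed to improve by more than
    `min_delta` for `patience` consecutive epochs. Epochs are 0-indexed.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best_loss - min_delta:
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    # The loss list ran out (max_epochs reached) before patience was exhausted.
    return len(valid_losses) - 1, best_epoch

# Made-up validation losses: improvement stalls after epoch 3.
losses = [1.00, 0.80, 0.70, 0.65, 0.66, 0.67, 0.66, 0.68]
stop_epoch, best_epoch = early_stopping_epoch(losses, patience=3)
print(stop_epoch, best_epoch)
```

Under a rule like this, the weights worth keeping are those from `best_epoch`, not from the epoch at which training stopped.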

Save Pretrained Encoder Weights#

Save the trained sequence encoder for transfer learning in downstream Cell2Net applications. The pretrained weights capture fundamental sequence-to-accessibility relationships learned from PBMC diversity.

model.save_module('./pretrained_seq2acc.pth')

Load Best Model for Evaluation#

Reload the optimal model weights (selected based on validation performance) for final evaluation and visualization.

# load best model weights
pretrained_state_dict = torch.load('./pretrained_seq2acc.pth')
model.module.seq_encoder.load_state_dict(pretrained_state_dict)
<All keys matched successfully>

Generate Model Predictions#

Apply the trained model to predict accessibility scores for both training and validation datasets.

df_train['pred_acc'] = model.predict(df_train)
df_valid['pred_acc'] = model.predict(df_valid)

Evaluate Model Performance#

Training Set Performance#

Assess how well the model learned sequence-to-accessibility relationships on training data. The Pearson correlation measures the linear relationship between predicted and true accessibility values.

pearsonR, pv = stats.pearsonr(df_train['acc'], df_train['pred_acc'])

ax = sns.jointplot(
    data=df_train,
    x="acc",
    y="pred_acc",
    kind = 'scatter',
    joint_kws={'marker':'o', 's':10, 'alpha':0.2, 'linewidth':0},
    marginal_kws={'bins':20, 'element':'step', 'kde':True, 'linewidth':0},
)
ax.plot_joint(sns.regplot, color="r", scatter=False, 
              line_kws={"color": "orange", 'linestyle':'dashed'})
_min = np.min((np.min(df_train['acc'].values), np.min(df_train['pred_acc'].values)))
_max = np.max((np.max(df_train['acc'].values), np.max(df_train['pred_acc'].values)))

# use the same scale on the x and y axes
plt.xlim(_min, _max)
plt.ylim(_min, _max)
plt.title('pearsonR = {:.3f}'.format(pearsonR))
plt.xlabel('True Accessibility (log1p counts)')
plt.ylabel('Predicted Accessibility (log1p counts)')
plt.tight_layout()
../../_images/a81aa898e693e8b9c9ed4aabeb70c920a04cfa444d784856dd54bd3b419108ee.png

Validation Set Performance#

Evaluate model generalization on completely unseen chromosomes. This tests the model’s ability to predict accessibility for novel genomic regions across the PBMC immune system diversity.

pearsonR, pv = stats.pearsonr(df_valid['acc'], df_valid['pred_acc'])

ax = sns.jointplot(
    data=df_valid,
    x="acc",
    y="pred_acc",
    kind = 'scatter',
    joint_kws={'marker':'o', 's':10, 'alpha':0.1, 'linewidth':0},
    marginal_kws={'bins':20, 'element':'step', 'kde':True, 'linewidth':0},
)
ax.plot_joint(sns.regplot, color="r", scatter=False, 
              line_kws={"color": "orange", 'linestyle':'dashed'})

_min = np.min((np.min(df_valid['acc'].values), np.min(df_valid['pred_acc'].values)))
_max = np.max((np.max(df_valid['acc'].values), np.max(df_valid['pred_acc'].values)))

# use the same scale on the x and y axes
plt.xlim(_min, _max)
plt.ylim(_min, _max)
plt.title('pearsonR = {:.3f}'.format(pearsonR))
plt.gca().set_aspect('equal', adjustable='box')
plt.xlabel('True Accessibility (log1p counts)')
plt.ylabel('Predicted Accessibility (log1p counts)')
plt.tight_layout()
../../_images/b601b33a24b7c9581e4a2583fc04c18f9b8f3636d7ee0adf54c021c41ad2e6e8.png
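For reference, the Pearson correlation used in the evaluation can be written out directly. The sketch below computes it from its definition with NumPy and compares the result with `scipy.stats.pearsonr` on a handful of made-up values; `pearson_r` is a hypothetical helper written for this illustration.

```python
import numpy as np
from scipy import stats

def pearson_r(x, y):
    """Pearson correlation: covariance of x and y over the product of their spreads."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_centered = x - x.mean()
    y_centered = y - y.mean()
    return float(np.sum(x_centered * y_centered)
                 / np.sqrt(np.sum(x_centered ** 2) * np.sum(y_centered ** 2)))

# Made-up true and predicted scores.
true_acc = np.array([5.2, 5.6, 6.1, 7.2, 8.8])
pred_acc = np.array([5.5, 5.4, 6.3, 6.9, 8.0])

r_manual = pearson_r(true_acc, pred_acc)
r_scipy = stats.pearsonr(true_acc, pred_acc)[0]
print(round(r_manual, 6) == round(r_scipy, 6))

# Pearson r ignores positive rescaling and shifting of the predictions,
# so a high r does not by itself show that predictions are calibrated.
r_scaled = pearson_r(true_acc, 0.5 * pred_acc + 1.0)
print(round(r_scaled, 6) == round(r_manual, 6))
```

This invariance is why plotting predictions against true values on matched axes is a useful complement to the correlation value: the plot shows systematic over- or under-prediction that the coefficient alone would hide.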

Summary and Next Steps#

✅ Accomplishments:#

  1. Loaded PBMC multiome data representing immune system cellular diversity

  2. Prepared sequence-accessibility training pairs from 256bp genomic windows

  3. Trained a sequence encoder to predict chromatin accessibility from DNA sequence

  4. Evaluated performance on held-out chromosomes demonstrating generalization

  5. Saved pretrained weights for transfer learning in downstream Cell2Net applications

🎯 Key Results:#

  • Successfully learned sequence-to-accessibility relationships across immune cell types

  • Demonstrated generalization to unseen genomic regions

  • Created a reusable sequence encoder capturing PBMC regulatory grammar

The trained encoder has learned fundamental sequence patterns that drive chromatin accessibility across PBMC cell types. It serves as a foundation for downstream Cell2Net applications focused on immune system regulation and function.