Training Cell2Net for the PBMC Dataset
This tutorial demonstrates how to train Cell2Net models for gene expression prediction using peripheral blood mononuclear cell (PBMC) data.
Workflow Summary
Data Loading: Load prepared multiome data with peak-to-gene associations
Model Initialization: Create Cell2Net models with pre-trained sequence encoders
Training: Train individual models for each target gene using paired RNA/ATAC data
Validation: Evaluate model performance on a held-out validation set
Results: Generate predictions and visualize training metrics
This approach enables Cell2Net to learn complex regulatory relationships between chromatin accessibility patterns, DNA sequence motifs, and gene expression in immune cells.
import warnings
warnings.filterwarnings("ignore")  # hide library warnings to keep the notebook output readable

import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import mudata as md
from sklearn.model_selection import train_test_split

import cell2net as cn

# don't pull modality-level obs/var columns into the global MuData obs/var on update()
md.set_options(pull_on_update=False)
Check Cell2Net Version
Verify the installed version of Cell2Net to ensure compatibility with this tutorial.
cn.__version__
'0.13'
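To make the compatibility check explicit instead of visual, you can compare the installed version against the one this tutorial was run with ('0.13', per the output above). The sketch below uses small hypothetical helpers, `parse_version` and `is_compatible`, not Cell2Net functions; `parse_version` keeps only the leading digits of each dot-separated component, so it is a rough check rather than a full PEP 440 parser.

```python
import re


def parse_version(version: str) -> tuple[int, ...]:
    """Turn a version string such as '0.13' into a comparable tuple (0, 13).

    Hypothetical helper: keeps the leading digits of each dot-separated
    component (e.g. '13rc1' -> 13) and stops at the first non-numeric one.
    """
    parts = []
    for component in version.split("."):
        match = re.match(r"\d+", component)
        if match is None:
            break  # stop at the first component without leading digits
        parts.append(int(match.group()))
    return tuple(parts)


TUTORIAL_VERSION = "0.13"  # version shown in the output above


def is_compatible(installed: str, required: str = TUTORIAL_VERSION) -> bool:
    """Return True if the installed version is at least the required one."""
    return parse_version(installed) >= parse_version(required)


# Usage in the notebook (assumes `cn` is imported as above):
# assert is_compatible(cn.__version__), f"Cell2Net {cn.__version__} < {TUTORIAL_VERSION}"
```

Treating "at least 0.13" as compatible is itself an assumption; pin an exact version instead if later releases change the API.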
2. Set Up Output Directory
Create the output directory structure for storing:
model/: Trained Cell2Net models for each gene (saved as .pt files)
plot/: Training visualization plots showing loss curves and prediction scatter plots
prediction/: Numerical results including predictions and performance metrics
This organized structure facilitates result analysis and model deployment.
out_dir = './03_train_cell2net'
os.makedirs(out_dir, exist_ok=True)
os.makedirs(f"{out_dir}/model", exist_ok=True)
os.makedirs(f"{out_dir}/plot", exist_ok=True)
os.makedirs(f"{out_dir}/prediction", exist_ok=True)
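All three subdirectories are keyed by gene name, so each gene maps to exactly one file per directory. A minimal sketch of that layout with `pathlib` follows; `make_output_dirs` and `gene_output_paths` are hypothetical helper names, and the `.pt`/`.png`/`.npz` suffixes mirror the file names used in the training loop below.

```python
from pathlib import Path


def make_output_dirs(out_dir: str) -> Path:
    """Create out_dir/{model,plot,prediction}; existing directories are left untouched."""
    root = Path(out_dir)
    for sub in ("model", "plot", "prediction"):
        (root / sub).mkdir(parents=True, exist_ok=True)
    return root


def gene_output_paths(out_dir: str, gene: str) -> dict[str, Path]:
    """Return the per-gene file paths for the saved model, plot and predictions."""
    root = Path(out_dir)
    return {
        "model": root / "model" / f"{gene}.pt",
        "plot": root / "plot" / f"{gene}.png",
        "prediction": root / "prediction" / f"{gene}.npz",
    }
```

Because `exist_ok=True` is used, rerunning the cell is safe, and checking `gene_output_paths(out_dir, gene)["model"].exists()` is equivalent to the skip test in the training loop.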
3. Load Prepared Data
Load the multiome dataset prepared in the previous tutorial step. This MuData object contains:
RNA modality: Gene expression counts for immune cell populations
ATAC modality: Chromatin accessibility peaks across the genome
Peak-to-gene associations: Regulatory links between accessible regions and target genes
Sequence information: DNA sequences around regulatory peaks for motif analysis
The `genes` list holds the unique target genes in `mdata.uns['peak_to_gene']`: the genes with sufficient peak-to-gene associations for training reliable Cell2Net models.
mdata = md.read_h5mu("./02_prepare_data/mdata.h5mu")
genes = mdata.uns['peak_to_gene']['gene'].unique().tolist()
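Because `genes` is built from the unique values of the `gene` column, it can be useful to check how many peaks support each gene. The sketch below does this on a toy table. It assumes `mdata.uns['peak_to_gene']` behaves like a pandas DataFrame with one row per peak-gene link and a `peak` column (only the `gene` column is confirmed by the code above); the `min_peaks` threshold is a hypothetical illustration, not a Cell2Net setting.

```python
import pandas as pd

# Toy stand-in for mdata.uns['peak_to_gene']: one row per peak-gene link.
# The 'peak' column name is an assumption; only 'gene' appears in the tutorial code.
peak_to_gene = pd.DataFrame({
    "peak": ["chr1:100-600", "chr1:900-1400", "chr1:2000-2500",
             "chr2:50-550", "chr3:10-510", "chr3:700-1200"],
    "gene": ["GENE_A", "GENE_A", "GENE_A", "GENE_B", "GENE_C", "GENE_C"],
})

# Same construction as the tutorial: unique genes, in order of first appearance
genes = peak_to_gene["gene"].unique().tolist()

# Number of distinct linked peaks per gene
peaks_per_gene = peak_to_gene.groupby("gene")["peak"].nunique()

# Hypothetical filter: keep genes with at least `min_peaks` linked peaks
min_peaks = 2
genes_filtered = [g for g in genes if peaks_per_gene[g] >= min_peaks]
print(genes_filtered)  # ['GENE_A', 'GENE_C']
```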
Inspect the Data Structure
Examine the loaded MuData object to understand:
Number of cells: Metacells representing cell type populations
Number of genes: Target genes for expression prediction
Number of peaks: Accessible chromatin regions from ATAC-seq
Data modalities: RNA and ATAC measurements integrated in a single object
mdata
MuData object with n_obs × n_vars = 1000 × 131516
obs: 'cell_type', 'cell_type_v2', 'total_counts_rna', 'total_counts_atac', 'total_counts_rna_log', 'total_counts_atac_log'
uns: 'motifs', 'peak_to_gene'
2 modalities
rna: 1000 x 15932
obs: 'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'cell_type_v2'
var: 'genes', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
uns: 'cell_type_v2_colors', 'gene_tf', 'gene_tss_coord', 'neighbors', 'pca', 'umap'
obsm: 'X_pca', 'X_umap'
varm: 'PCs'
layers: 'counts'
obsp: 'connectivities', 'distances'
atac: 1000 x 115584
obs: 'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'cell_type_v2'
var: 'peaks', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
uns: 'motif_match', 'peaks'
varm: 'motif_match'
layers: 'counts'
Number of Target Genes
Display the total number of target genes that will be modeled. Each gene requires sufficient peak-to-gene associations to train a robust Cell2Net model; in the PBMC dataset, these are typically highly expressed genes with well-characterized regulatory regions.
len(genes)
1927
4. Train-Validation Split
Divide the data into training (80%) and validation (20%) sets by random sampling (the split is not stratified by cell type). This provides:
Training set: Used for model parameter optimization and learning regulatory patterns
Validation set: Independent evaluation to assess model generalization and prevent overfitting
Reproducibility: Fixed random seed (42) ensures consistent splits across runs
The split is performed at the metacell level, so the paired RNA and ATAC measurements of each metacell always end up in the same set.
train_idx, valid_idx = train_test_split(
mdata.obs_names.values.tolist(),
train_size=0.8,
random_state=42)
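The call above draws a plain random split. If you want the cell-type composition of the training and validation sets to match, `train_test_split` also accepts a `stratify` argument. The sketch below demonstrates this on toy labels; in the tutorial you would pass `mdata.obs['cell_type']` (or `cell_type_v2`) instead, which is an untested suggestion and not how the results on this page were produced.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy metacell names and labels (stand-ins for mdata.obs_names / mdata.obs['cell_type'])
obs = pd.DataFrame(
    {"cell_type": ["CD4 T"] * 50 + ["B"] * 30 + ["Mono"] * 20},
    index=[f"metacell_{i}" for i in range(100)],
)

train_idx, valid_idx = train_test_split(
    obs.index.values.tolist(),
    train_size=0.8,
    random_state=42,
    stratify=obs["cell_type"],  # preserve cell-type proportions in both sets
)

# Cell-type proportions in each split
train_frac = obs.loc[train_idx, "cell_type"].value_counts(normalize=True)
valid_frac = obs.loc[valid_idx, "cell_type"].value_counts(normalize=True)
print(pd.concat({"train": train_frac, "valid": valid_frac}, axis=1))
```

Stratification matters most for rare cell types, which a plain random 20% split can under-represent in the validation set.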
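5. Load Pre-trained Sequence Encoder
Load the weights of the pre-trained sequence encoder trained in the previous tutorial step. The state dict is loaded onto the CPU (`map_location="cpu"`) and copied into each gene's model inside the training loop. This encoder: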
Processes DNA sequences: Converts nucleotide sequences to dense embeddings
Captures motif patterns: Learns transcription factor binding site preferences
Transfer learning: Pre-trained weights provide strong initialization for gene-specific training
Starting from pre-trained weights lets each gene-specific model reuse sequence-level patterns learned across the entire genome, which is intended to improve both training efficiency and final model performance.
pretrained_model_path = "./pretrained_seq2acc.pth"
pretrained_state_dict = torch.load(pretrained_model_path, map_location="cpu")
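Loading weights into a single submodule works because `load_state_dict` matches parameter names relative to the module it is called on. The toy sketch below reproduces the pattern with plain PyTorch; `ToyModel`, its `seq_encoder` and `head` attributes, and the temporary file are hypothetical stand-ins for Cell2Net's internals and `pretrained_seq2acc.pth`.

```python
import os
import tempfile

import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in: a sequence encoder followed by a prediction head."""

    def __init__(self):
        super().__init__()
        self.seq_encoder = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
        self.head = nn.Linear(4, 1)


torch.manual_seed(0)
pretrained = ToyModel()

# Save only the encoder's weights, as a pre-training step might
path = os.path.join(tempfile.mkdtemp(), "toy_seq_encoder.pth")
torch.save(pretrained.seq_encoder.state_dict(), path)

# Load on CPU regardless of the device the weights were saved from
state_dict = torch.load(path, map_location="cpu")

# A freshly initialized model picks up the encoder weights; the head keeps its own
model = ToyModel()
model.seq_encoder.load_state_dict(state_dict)  # strict=True by default: key mismatches raise

same = all(
    torch.equal(a, b)
    for a, b in zip(model.seq_encoder.state_dict().values(),
                    pretrained.seq_encoder.state_dict().values())
)
print("encoder weights copied:", same)
```

With the default `strict=True`, loading the state dict into the wrong submodule (for example `model.head`) raises a `RuntimeError` listing missing and unexpected keys, which is a quick way to catch an architecture mismatch.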
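6. Train Cell2Net Models
This is the main training loop that creates and trains an individual Cell2Net model for each target gene. Genes that already have a saved model in model/ are skipped, so the loop can be resumed after an interruption. The process includes: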
Model Architecture
Cell2Net Framework: Integrates sequence and accessibility information for gene expression prediction
Sequence Encoder: Pre-trained transformer that processes DNA sequences around regulatory peaks
Accessibility Encoder: Neural network that processes ATAC-seq peak intensities
Covariates: Log-transformed total RNA and ATAC counts (`total_counts_rna_log`, `total_counts_atac_log`) to account for technical variation such as sequencing depth
Training Configuration
Epochs: Up to 40 training epochs (`max_epochs=40`) with early stopping based on validation performance
Batch Size: 32 metacells per batch for efficient GPU utilization
Learning Rate: 1e-4 with adaptive optimization
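Device: GPU acceleration for faster training (`device_name='cuda:1'`, i.e. the second GPU; adjust this to match your hardware)
Data Loading: 4 worker processes (`num_workers=4`)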
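The configuration combines a maximum epoch count with checkpoint selection on validation performance. The sketch below shows one common way such logic works, patience-based early stopping that remembers the best epoch, applied to a precomputed list of validation losses. It is a generic illustration, not Cell2Net's internal training loop; the `select_best_epoch` helper, `patience`, and the loss values are hypothetical.

```python
def select_best_epoch(valid_losses, max_epochs=40, patience=5):
    """Simulate early stopping on a sequence of per-epoch validation losses.

    Returns (best_epoch, stopped_epoch), both 0-based. Training stops once the
    validation loss has not improved for `patience` consecutive epochs, or after
    `max_epochs` epochs; the best epoch is the one whose weights would be restored.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    stopped_epoch = -1
    for epoch, loss in enumerate(valid_losses[:max_epochs]):
        stopped_epoch = epoch
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0  # a real loop would save a checkpoint here
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_epoch, stopped_epoch


# Validation loss improves until epoch 3, then plateaus
losses = [1.0, 0.8, 0.7, 0.65, 0.66, 0.67, 0.66, 0.68, 0.70, 0.71]
print(select_best_epoch(losses, patience=3))  # (3, 6)
```

Restoring the best checkpoint rather than the last one is what the training loop's `model.module.load_state_dict(model.check_point)` step corresponds to.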
Model Outputs
For each gene, the training produces:
Saved Model: Best checkpoint saved as .pt file for future use
Performance Plots: Training/validation loss curves and prediction scatter plots
Predictions: Numerical results including correlation metrics and raw predictions
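Each `prediction/{gene}.npz` file can be reloaded with `np.load` to recompute metrics. As a sketch, the snippet below writes a toy file with a subset of the keys used by the training loop and computes the Pearson correlation between observed and predicted values on the log1p scale used in the scatter plots. The arrays are synthetic, `log1p_pearson` is a hypothetical helper, and whether Cell2Net computes `best_valid_corr` the same way is an assumption.

```python
import os
import tempfile

import numpy as np

# Synthetic stand-in for one gene's prediction file (subset of the training-loop keys)
rng = np.random.default_rng(0)
valid_true = rng.poisson(5.0, size=200).astype(float)
valid_pred = 0.9 * valid_true + rng.uniform(0.0, 1.0, size=200)  # non-negative, so log1p is safe

path = os.path.join(tempfile.mkdtemp(), "TOYGENE.npz")
np.savez(path, valid_true=valid_true, valid_pred=valid_pred)


def log1p_pearson(true, pred):
    """Pearson correlation between log1p-transformed observed and predicted values."""
    return float(np.corrcoef(np.log1p(true), np.log1p(pred))[0, 1])


with np.load(path) as data:
    r_valid = log1p_pearson(data["valid_true"], data["valid_pred"])
print(f"validation Pearson r (log1p scale): {r_valid:.3f}")
```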
Biological Interpretation
Each model learns how chromatin accessibility patterns and DNA sequence motifs around regulatory peaks contribute to gene expression in different immune cell types. The integration of sequence and accessibility enables Cell2Net to capture complex regulatory logic governing immune cell gene programs.
for gene in genes:
    # skip genes that already have a saved model, so an interrupted run can be resumed
    if os.path.exists(f'{out_dir}/model/{gene}.pt'):
        continue

    print('Training model for gene:', gene)
    cn.utils.set_random_seed(42)

    # gene-specific model; the log total counts are used as technical covariates
    model = cn.pd.model.Cell2Net(mdata=mdata,
                                 gene=gene,
                                 covariates=['total_counts_rna_log', 'total_counts_atac_log'])

    # load pretrained weights for the sequence encoder
    model.module.seq_encoder.load_state_dict(pretrained_state_dict)

    model.train(max_epochs=40,
                device_name='cuda:1',
                batch_size=32,
                num_workers=4,
                lr=1e-4,
                verbose=False)
    model.save(dir_path=f"{out_dir}/model")

    # set the model with the best checkpoint
    model.module.load_state_dict(model.check_point)

    # evaluate the model on the training and validation sets
    train_pred = model.predict(model.mdata[train_idx])
    train_true = np.asarray(model.mdata[train_idx]["rna"].layers["counts"].todense()).ravel()
    valid_pred = model.predict(model.mdata[valid_idx])
    valid_true = np.asarray(model.mdata[valid_idx]["rna"].layers["counts"].todense()).ravel()

    df_train = pd.DataFrame({'true': train_true, 'pred': train_pred, 'data': 'train'})
    df_valid = pd.DataFrame({'true': valid_true, 'pred': valid_pred, 'data': 'valid'})
    # compare observed and predicted expression on a log1p scale
    for df in (df_train, df_valid):
        df[['true', 'pred']] = np.log1p(df[['true', 'pred']])

    # loss curves (top row) and observed-vs-predicted scatter plots (bottom row)
    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    sns.lineplot(data=model.history, x='epochs', y='train_loss', label='train', ax=axes[0, 0])
    sns.lineplot(data=model.history, x='epochs', y='valid_loss', label='valid', ax=axes[0, 1])
    sns.scatterplot(data=df_train, x='true', y='pred', ax=axes[1, 0], label='train')
    sns.scatterplot(data=df_valid, x='true', y='pred', ax=axes[1, 1], label='valid')
    for ax in axes[0]:
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
    for ax in axes[1]:
        ax.set_xlabel('log1p(observed counts)')
        ax.set_ylabel('log1p(predicted)')
    fig.suptitle(gene)
    fig.tight_layout()

    # save figure
    fig.savefig(f'{out_dir}/plot/{gene}.png', dpi=300)
    plt.close(fig)

    np.savez(f'{out_dir}/prediction/{gene}.npz',
             best_valid_corr=model.best_valid_corr,
             train_pred=train_pred,
             train_true=train_true,
             valid_pred=valid_pred,
             valid_true=valid_true)
Training model for gene: ATRN
Training model for gene: PLCB1
Training model for gene: LAMP5
Training model for gene: AL050403.2
Training model for gene: ISM1
Training model for gene: MACROD2
Training model for gene: SLC24A3
Training model for gene: RIN2
Training model for gene: CST7
Training model for gene: FAM182A
Training model for gene: ID1
Training model for gene: TPX2
Training model for gene: HCK
Training model for gene: MYBL2
Training model for gene: TOX2
Training model for gene: PKIG
Training model for gene: CEBPB
Training model for gene: SMIM25
Training model for gene: LINC01524
Training model for gene: TSHZ2
Training model for gene: MIR646HG
Training model for gene: NRSN2-AS1
Training model for gene: AL110114.1
Training model for gene: SIRPB2
Training model for gene: SIRPD
Training model for gene: C20orf194
Training model for gene: SIGLEC1
Training model for gene: RNF24
Training model for gene: GPCPD1
Training model for gene: TMX4
Training model for gene: CST3
Training model for gene: ZNF341-AS1
Training model for gene: AL035458.2
Training model for gene: MAFB
Training model for gene: ZHX3
Training model for gene: SULF2
Training model for gene: B4GALT5
Training model for gene: AL109930.1
Training model for gene: BCAS1
Training model for gene: PMEPA1
Training model for gene: CTSZ
Training model for gene: ZBTB46
Training model for gene: AP001347.1
Training model for gene: MIR99AHG
Training model for gene: CHODL
Training model for gene: LINC01684
Training model for gene: MIR155HG
Training model for gene: BACH1
Training model for gene: EVA1C
Training model for gene: ITSN1
Training model for gene: AP000317.1
Training model for gene: KCNJ15
Training model for gene: BACE2
Training model for gene: MX2
Training model for gene: MX1
Training model for gene: TRPM2
Training model for gene: PCBP3
Training model for gene: COL6A2
Training model for gene: SAMSN1
Training model for gene: NRIP1
Training model for gene: AF130417.1
Training model for gene: APP
Training model for gene: ADAMTS5
Training model for gene: AF165147.1
Training model for gene: TIAM1
Training model for gene: LINC00159
Training model for gene: AP000282.1
Training Complete
The Cell2Net training process is now complete. It generated the following outputs.
Generated Outputs
Trained Models: Individual Cell2Net models for each target gene, stored in ./03_train_cell2net/model/
Performance Visualizations: Training plots showing loss curves and prediction accuracy, stored in ./03_train_cell2net/plot/
Prediction Results: Numerical predictions and performance metrics, stored in ./03_train_cell2net/prediction/
The trained models capture the complex regulatory landscape of immune cells, enabling prediction of gene expression from chromatin accessibility and DNA sequence information, together with the sequencing-depth covariates.
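To compare genes after training, the per-gene `.npz` files can be collected into one table. The sketch below assumes every file in `prediction/` stores a scalar `best_valid_corr`, as written by the training loop; the helper name `summarize_predictions` is hypothetical, and the usage runs on toy files with made-up gene names and values in a temporary directory.

```python
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd


def summarize_predictions(pred_dir):
    """Collect best_valid_corr from every <gene>.npz in pred_dir, best genes first."""
    rows = []
    for path in sorted(Path(pred_dir).glob("*.npz")):
        with np.load(path) as data:
            # path.stem drops only the final suffix, so 'AL050403.2.npz' -> 'AL050403.2'
            rows.append({"gene": path.stem, "best_valid_corr": float(data["best_valid_corr"])})
    df = pd.DataFrame(rows, columns=["gene", "best_valid_corr"])
    return df.sort_values("best_valid_corr", ascending=False, ignore_index=True)


# Usage on toy files; in the tutorial, call summarize_predictions(f"{out_dir}/prediction")
toy_dir = tempfile.mkdtemp()
for gene, corr in [("GENE_A", 0.72), ("GENE_B", 0.55), ("GENE.1", 0.31)]:
    np.savez(str(Path(toy_dir) / f"{gene}.npz"), best_valid_corr=corr)
summary = summarize_predictions(toy_dir)
print(summary)
```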