Training Cell2Net models for K562#
This tutorial demonstrates how to train Cell2Net models for predicting gene expression from chromatin accessibility data in K562 cells. We’ll leverage a pretrained sequence encoder to improve model performance and training efficiency.
import warnings
warnings.filterwarnings("ignore")
import os
import mudata as md
import cell2net as cn
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
md.set_options(pull_on_update=False)
<mudata._core.config.set_options at 0x7ff16e7dede0>
cn.__version__
'0.13'
Setup Directories and Paths#
We define input and output paths for this training session. The multiome data should be prepared from the previous tutorial (02_prepare_data).
input_data = "./02_prepare_data/mdata.h5mu"
out_dir = './03_train'
os.makedirs(out_dir, exist_ok=True)
Load Data and Extract Target Genes#
Load the preprocessed multiome data and extract the list of genes that have valid peak-to-gene associations. These will be our modeling targets.
mdata = md.read_h5mu(input_data)
genes = mdata.uns['peak_to_gene']['gene'].unique().tolist()
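If you want a quick look at the association table itself, a minimal sketch (the columns beyond gene depend on how the table was built in 02_prepare_data) is:

# inspect the first few peak-to-gene associations
mdata.uns['peak_to_gene'].head()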
Let’s check how many genes we’ll be training models for:
len(genes)
1608
Create Train/Validation Split#
We split the cells into training (80%) and validation (20%) sets to properly evaluate model performance and prevent overfitting.
train_idx, valid_idx = train_test_split(
mdata.obs_names.values.tolist(),
train_size=0.8,
random_state=42)
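As a quick sanity check, you can confirm the split sizes before training:

# should be roughly an 80/20 split across all cells
len(train_idx), len(valid_idx)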
Load Pretrained Sequence Encoder#
We load the pretrained sequence encoder weights from the first tutorial. This transfer learning approach helps models converge faster and achieve better performance by leveraging learned sequence-to-accessibility patterns.
pretrained_model_path = "./pretrained_seq2acc.pth"
pretrained_state_dict = torch.load(pretrained_model_path, map_location="cpu")
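If you want to verify what was loaded, a minimal sketch for inspecting the checkpoint (assuming it is a plain state dict, as saved in the first tutorial) is:

# print a few parameter names and shapes from the pretrained encoder
for name, tensor in list(pretrained_state_dict.items())[:5]:
    print(name, tuple(tensor.shape))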
Setup Output Directories#
Create directories to organize our training outputs:
model/: Trained model checkpoints

plot/: Training curves and evaluation plots

prediction/: Numerical results and predictions
os.makedirs(f"{out_dir}/model", exist_ok=True)
os.makedirs(f"{out_dir}/plot", exist_ok=True)
os.makedirs(f"{out_dir}/prediction", exist_ok=True)
Train Cell2Net Models#
This is the main training loop where we:
Initialize models for each gene with the Cell2Net architecture
Load pretrained weights into the sequence encoder component
Train models for 40 epochs with early stopping based on validation loss
Evaluate performance on both training and validation sets
Generate visualization showing loss curves and prediction accuracy
Save results including model checkpoints, plots, and numerical predictions
Note that it usually takes 1-2 minutes to train a model for an individual gene, so depending on how many genes you want to model, the total training time can be a few hours.
# We only train Cell2Net for one gene to save time
for gene in genes[:1]:
print('Training model for gene:', gene)
cn.utils.set_random_seed(42)
model = cn.pd.model.Cell2Net(mdata=mdata,
gene=gene,
covariates=['total_counts_rna_log', 'total_counts_atac_log'])
# load pretrained weights for the sequence encoder
model.module.seq_encoder.load_state_dict(pretrained_state_dict)
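    # adjust device_name below to match your hardware (e.g. 'cuda:0' or 'cpu')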
model.train(max_epochs=40, device_name='cuda:1',
batch_size=16, train_idx=train_idx,
valid_idx=valid_idx,
num_workers=4,
verbose=False)
model.save(dir_path=f"{out_dir}/model")
# set the model with the best checkpoint
model.module.load_state_dict(model.check_point)
# Evaluate the model for training and validation dataset
train_pred = model.predict(model.mdata[train_idx])
train_true = model.mdata[train_idx]["rna"].layers["counts"].todense().A1
valid_pred = model.predict(model.mdata[valid_idx])
valid_true = model.mdata[valid_idx]["rna"].layers["counts"].todense().A1
df_train = pd.DataFrame({'true': train_true, 'pred': train_pred, 'data': 'train'})
df_valid = pd.DataFrame({'true': valid_true, 'pred': valid_pred, 'data': 'valid'})
df_train['true'] = np.log1p(df_train['true'])
df_valid['true'] = np.log1p(df_valid['true'])
df_train['pred'] = np.log1p(df_train['pred'])
df_valid['pred'] = np.log1p(df_valid['pred'])
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
sns.scatterplot(data=df_train, x='true', y='pred', ax=axes[0], label='train')
sns.scatterplot(data=df_valid, x='true', y='pred', ax=axes[1], label='valid')
axes[0].set_xlabel("Observation (log1p)")
axes[0].set_ylabel("Prediction (log1p)")
axes[1].set_xlabel("Observation (log1p)")
axes[1].set_ylabel("Prediction (log1p)")
# save figure
plt.savefig(f'{out_dir}/plot/{gene}.png', dpi=300)
plt.close()
np.savez(f'{out_dir}/prediction/{gene}.npz',
best_valid_corr=model.best_valid_corr,
train_pred=train_pred,
train_true=train_true,
valid_pred=valid_pred,
valid_true=valid_true)
Training model for gene: PLEKHN1
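You can reload the saved results for a gene directly from the .npz file written above; the keys match what the training loop saves. For example, for PLEKHN1:

res = np.load(f'{out_dir}/prediction/PLEKHN1.npz')
print('best validation correlation:', res['best_valid_corr'])
# recompute the validation-set correlation on the log1p scale used in the plots
corr = np.corrcoef(np.log1p(res['valid_true']), np.log1p(res['valid_pred']))[0, 1]
print('log1p-scale validation correlation:', corr)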
Next Steps#
After training is complete, you can:
Analyze model performance by examining the correlation scores and loss curves
Compare different genes to identify which are most predictable from chromatin accessibility
Perform model interpretation using attention weights and feature importance
The trained models and results are saved in the 03_train/ directory and can be loaded for downstream analysis and biological interpretation.
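For example, once you have trained models for many genes, a minimal sketch for ranking genes by their saved validation correlation (reading the .npz files written by the loop above) might look like this:

import glob

# collect the best validation correlation for every trained gene
records = []
for path in glob.glob(f'{out_dir}/prediction/*.npz'):
    gene = os.path.splitext(os.path.basename(path))[0]
    res = np.load(path)
    records.append({'gene': gene, 'best_valid_corr': float(res['best_valid_corr'])})

df_corr = pd.DataFrame(records).sort_values('best_valid_corr', ascending=False)
df_corr.head()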