Training Cell2Net models for K562

This tutorial demonstrates how to train Cell2Net models for predicting gene expression from chromatin accessibility data in K562 cells. We’ll leverage a pretrained sequence encoder to improve model performance and training efficiency.

import warnings
warnings.filterwarnings("ignore")

import os
import mudata as md
import cell2net as cn
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
md.set_options(pull_on_update=False)
cn.__version__
'0.13'

Setup Directories and Paths

We define the input and output paths for this training session. The multiome data should already have been prepared in the previous tutorial (02_prepare_data).

input_data = "./02_prepare_data/mdata.h5mu"
out_dir = './03_train'

os.makedirs(out_dir, exist_ok=True)
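
Because everything that follows depends on the file written by the previous tutorial, an optional existence check can fail early with a clearer message. This is a minimal sketch; the helper name `require_file` is hypothetical and not part of Cell2Net.

```python
from pathlib import Path


def require_file(path):
    """Return `path` as a Path, or raise if it is not an existing file."""
    p = Path(path)
    if not p.is_file():
        raise FileNotFoundError(
            f"{p} not found; run the data preparation tutorial first."
        )
    return p


# Usage with the path defined above (commented out so the sketch runs anywhere):
# require_file("./02_prepare_data/mdata.h5mu")
```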

Load Data and Extract Target Genes

Load the preprocessed multiome data and extract the list of genes that have valid peak-to-gene associations. These will be our modeling targets.

mdata = md.read_h5mu(input_data)
genes = mdata.uns['peak_to_gene']['gene'].unique().tolist()

Let’s check how many genes we’ll be training models for:

len(genes)
1608
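
Beyond the number of genes, it can be useful to see how many peaks are linked to each gene. The sketch below assumes only that the `gene` column holds one entry per peak-to-gene link; the gene labels are toy values, not results from this dataset, and `peaks_per_gene` is a hypothetical helper.

```python
from collections import Counter


def peaks_per_gene(gene_labels):
    """Count how many times each gene appears in a list of link labels."""
    return Counter(gene_labels)


# Toy labels standing in for mdata.uns['peak_to_gene']['gene'].tolist()
toy_links = ["GENE_A", "GENE_A", "GENE_B", "GENE_A", "GENE_C"]
counts = peaks_per_gene(toy_links)

# Genes sorted by number of linked peaks, most first
for gene_name, n_peaks in counts.most_common():
    print(gene_name, n_peaks)
```

With the real data, the same helper could be called as `peaks_per_gene(mdata.uns['peak_to_gene']['gene'].tolist())`.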

Create Train/Validation Split

We split the cells into training (80%) and validation (20%) sets so that model performance is evaluated on held-out cells and overfitting can be detected.

train_idx, valid_idx = train_test_split(
    mdata.obs_names.values.tolist(),
    train_size=0.8,
    random_state=42)
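
A quick sanity check on the split can catch accidental overlap between the two sets. The sketch below is an optional addition that uses toy cell names; `check_split` is a hypothetical helper, not a Cell2Net or scikit-learn function.

```python
def check_split(train_ids, valid_ids, all_ids):
    """Verify that the two sets are disjoint and together cover all cells.

    Returns the fraction of cells assigned to the training set.
    """
    train, valid = set(train_ids), set(valid_ids)
    overlap = train & valid
    if overlap:
        raise ValueError(f"{len(overlap)} cells appear in both sets")
    if train | valid != set(all_ids):
        raise ValueError("split does not cover exactly the input cells")
    return len(train) / len(set(all_ids))


# Toy cell names standing in for mdata.obs_names
all_cells = [f"cell_{i}" for i in range(10)]
train_fraction = check_split(all_cells[:8], all_cells[8:], all_cells)
print(train_fraction)
```

With the variables defined above, the call would be `check_split(train_idx, valid_idx, mdata.obs_names)`.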

Load Pretrained Sequence Encoder

We load the pretrained sequence encoder weights from the first tutorial. This transfer learning approach helps models converge faster and achieve better performance by leveraging learned sequence-to-accessibility patterns.

pretrained_model_path = "./pretrained_seq2acc.pth"
pretrained_state_dict = torch.load(pretrained_model_path, map_location="cpu")
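
PyTorch's `load_state_dict` raises a `RuntimeError` by default (`strict=True`) when parameter names do not match, so comparing key sets first can make a mismatch easier to diagnose. The sketch below works on any two collections of key names; `compare_keys` is a hypothetical helper and the layer names are toy values, not the actual encoder layout.

```python
def compare_keys(expected_keys, loaded_keys):
    """Return (missing, unexpected) parameter names as sorted lists.

    missing: names the model expects but the checkpoint lacks.
    unexpected: names in the checkpoint that the model does not have.
    """
    expected, loaded = set(expected_keys), set(loaded_keys)
    return sorted(expected - loaded), sorted(loaded - expected)


# Toy state dicts (values omitted); real ones map names to tensors
model_state = {"conv.weight": None, "conv.bias": None, "fc.weight": None}
checkpoint_state = {"conv.weight": None, "conv.bias": None, "head.weight": None}

missing, unexpected = compare_keys(model_state.keys(), checkpoint_state.keys())
print("missing:", missing)
print("unexpected:", unexpected)
```

Once a model has been created in the training loop further down, the equivalent call would be `compare_keys(model.module.seq_encoder.state_dict().keys(), pretrained_state_dict.keys())`.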

Setup Output Directories

Create directories to organize our training outputs:

  • model/: Trained model checkpoints

  • plot/: Training curves and evaluation plots

  • prediction/: Numerical results and predictions

os.makedirs(f"{out_dir}/model", exist_ok=True)
os.makedirs(f"{out_dir}/plot", exist_ok=True)
os.makedirs(f"{out_dir}/prediction", exist_ok=True)
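
Since each gene writes its own files into these directories, an interrupted run over many genes could be resumed by skipping genes that already have a prediction file. This is a sketch under the assumption that the `{gene}.npz` file in `prediction/` is the last artifact written for a gene, as in the training loop of this tutorial; `pending_genes` is a hypothetical helper.

```python
import tempfile
from pathlib import Path


def pending_genes(genes, out_dir):
    """Return the genes that have no saved prediction file yet."""
    pred_dir = Path(out_dir) / "prediction"
    return [g for g in genes if not (pred_dir / f"{g}.npz").exists()]


# Demonstration in a temporary directory with toy gene names
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "prediction").mkdir()
    (Path(tmp) / "prediction" / "GENE_A.npz").touch()
    remaining = pending_genes(["GENE_A", "GENE_B"], tmp)

print(remaining)  # ['GENE_B']
```

In the tutorial's setting, the loop could then iterate over `pending_genes(genes, out_dir)` instead of `genes`.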

Train Cell2Net Models

This is the main training loop where we:

  1. Initialize models for each gene with the Cell2Net architecture

  2. Load pretrained weights into the sequence encoder component

  3. Train models for up to 40 epochs with early stopping based on validation loss

  4. Evaluate performance on both training and validation sets

  5. Generate visualizations showing loss curves and prediction accuracy

  6. Save results including model checkpoints, plots, and numerical predictions

Note that it usually takes 1-2 minutes to train a model for an individual gene, so depending on how many genes you want to model, the total training time can be a few hours.
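
The early-stopping step in the list above can be illustrated with a generic patience rule. This is a sketch of the general technique, not Cell2Net's actual implementation, which the tutorial does not describe; the `patience` parameter and the loss values are hypothetical.

```python
def early_stopping_epoch(valid_losses, patience=5):
    """Return (best_epoch, stop_epoch) for a sequence of validation losses.

    Training stops once the loss has failed to improve on its best value
    for `patience` consecutive epochs; epochs are 0-indexed.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best_loss:
            # New best: a real loop would also save the model weights here
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    # Never triggered: training ran through every epoch
    return best_epoch, len(valid_losses) - 1


toy_losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
best_epoch, stop_epoch = early_stopping_epoch(toy_losses, patience=3)
print(best_epoch, stop_epoch)  # 2 5
```

Keeping the weights from the best epoch rather than the last one is what makes it meaningful to restore a best checkpoint after training, as the loop below does.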

# Train Cell2Net for only the first gene to save time
for gene in genes[:1]:
    print('Training model for gene:', gene)

    cn.utils.set_random_seed(42)
    model = cn.pd.model.Cell2Net(mdata=mdata,
                                 gene=gene, 
                                 covariates=['total_counts_rna_log', 'total_counts_atac_log'])

    # load pretrained weights for the sequence encoder
    model.module.seq_encoder.load_state_dict(pretrained_state_dict)

    # 'cuda:1' selects the second GPU; adjust device_name to match your hardware
    model.train(max_epochs=40, device_name='cuda:1',
                batch_size=16, train_idx=train_idx, 
                valid_idx=valid_idx, 
                num_workers=4,
                verbose=False)

    model.save(dir_path=f"{out_dir}/model")

    # restore the model weights from the best checkpoint
    model.module.load_state_dict(model.check_point)

    # Evaluate the model on the training and validation sets
    train_pred = model.predict(model.mdata[train_idx])
    train_true = model.mdata[train_idx]["rna"].layers["counts"].todense().A1

    valid_pred = model.predict(model.mdata[valid_idx])
    valid_true = model.mdata[valid_idx]["rna"].layers["counts"].todense().A1

    df_train = pd.DataFrame({'true': train_true, 'pred': train_pred, 'data': 'train'})
    df_valid = pd.DataFrame({'true': valid_true, 'pred': valid_pred, 'data': 'valid'})

    df_train['true'] = np.log1p(df_train['true'])
    df_valid['true'] = np.log1p(df_valid['true'])
    df_train['pred'] = np.log1p(df_train['pred'])
    df_valid['pred'] = np.log1p(df_valid['pred'])

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))

    sns.scatterplot(data=df_train, x='true', y='pred', ax=axes[0], label='train')
    sns.scatterplot(data=df_valid, x='true', y='pred', ax=axes[1], label='valid')

    axes[0].set_xlabel("Observation (log1p)")
    axes[0].set_ylabel("Prediction (log1p)")
    axes[1].set_xlabel("Observation (log1p)")
    axes[1].set_ylabel("Prediction (log1p)")

    # save figure
    plt.savefig(f'{out_dir}/plot/{gene}.png', dpi=300)
    plt.close()

    np.savez(f'{out_dir}/prediction/{gene}.npz', 
             best_valid_corr=model.best_valid_corr,
             train_pred=train_pred,
             train_true=train_true,
             valid_pred=valid_pred,
             valid_true=valid_true)
Training model for gene: PLEKHN1

Next Steps

After training is complete, you can:

  1. Analyze model performance by examining the correlation scores and loss curves

  2. Compare different genes to identify which are most predictable from chromatin accessibility

  3. Perform model interpretation using attention weights and feature importance

The trained models and results are saved in the 03_train/ directory and can be loaded for downstream analysis and biological interpretation.
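
As a starting point for such analysis, the saved `.npz` files can be reloaded to compute a correlation between observed and predicted values on the log1p scale used in the plots. This is a minimal sketch with toy numbers; `log1p_pearson` is a hypothetical helper, and since the tutorial does not state how `best_valid_corr` is computed, the value here need not match it.

```python
import os
import tempfile

import numpy as np


def log1p_pearson(true, pred):
    """Pearson correlation between log1p-transformed observations and predictions."""
    true = np.log1p(np.asarray(true, dtype=float).ravel())
    pred = np.log1p(np.asarray(pred, dtype=float).ravel())
    return float(np.corrcoef(true, pred)[0, 1])


# Write and reload a toy file with the same keys the training loop saves
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "GENE_A.npz")
    np.savez(path,
             valid_true=np.array([0, 1, 3, 7]),
             valid_pred=np.array([0.2, 1.1, 2.5, 6.0]))
    with np.load(path) as result:
        valid_corr = log1p_pearson(result["valid_true"], result["valid_pred"])

print(round(valid_corr, 3))
```

For a trained gene, the file path would be `f"{out_dir}/prediction/{gene}.npz"`, and the same function can be applied to the `train_true` and `train_pred` arrays.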