
Specification


Introduction

As outlined in the README, we intend to learn and apply the rules underlying regulatory element sequences using Stable Diffusion.

DNA-Diffusion is developed to loosely follow the principles of Test-Driven Development (TDD).

Here are the main principles we strive to follow:

  1. Write tests first: In TDD, you write a failing test before writing any production code. The test should be small, specific, and test only one aspect of the code.

  2. Write the simplest code that passes the test: Once you've written the test, write the production code that will make the test pass. The code should be the simplest possible code that satisfies the test.

  3. Refactor the code: Once the test has passed, you should refactor the code to improve its quality and maintainability. Refactoring means making changes to the code without changing its behavior.

  4. Repeat the process: Once you've refactored the code, you should write another test and repeat the process.

  5. Test everything that could possibly break: In TDD, you should write tests for all of the functionality that could potentially break. This includes boundary conditions, edge cases, and error conditions.

  6. Use test automation: TDD relies on automated tests to verify that the code works as expected. Running tests by hand is time-consuming and error-prone, so use test automation tools to run your tests.

  7. Keep the feedback loop short: TDD is based on a short feedback loop, where you write a test, run it, and get immediate feedback on whether the code works as expected. This short feedback loop helps you catch errors early and makes it easier to debug problems.
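As an illustration of the cycle above, here is a sketch of tests written before the code they exercise, followed by the simplest implementation that makes them pass. The helper validate_sequence is hypothetical and not part of dnadiffusion; the third test covers an error condition, as principle 5 asks. The tests use plain assertions, so pytest collects them as-is.

```python
# Step 1: write the tests first. They fail until validate_sequence exists.
def test_validate_sequence_accepts_acgt():
    assert validate_sequence("ACGTTGCA") == "ACGTTGCA"


def test_validate_sequence_normalizes_case():
    assert validate_sequence("acgt") == "ACGT"


def test_validate_sequence_rejects_unknown_bases():
    try:
        validate_sequence("ACGX")
    except ValueError:
        return
    raise AssertionError("expected ValueError for a non-ACGT character")


# Step 2: the simplest code that makes the tests pass (hypothetical helper).
VALID_BASES = frozenset("ACGT")


def validate_sequence(seq: str) -> str:
    """Return the upper-cased sequence, or raise ValueError on non-ACGT characters."""
    upper = seq.upper()
    invalid = set(upper) - VALID_BASES
    if invalid:
        raise ValueError(f"invalid nucleotides: {sorted(invalid)}")
    return upper
```

Step 3 would refactor validate_sequence (for example, to allow N for unknown bases) while keeping all three tests green.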

Hypothetical usage

First, a hypothetical usage example for training:

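import hydra
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from hydra_zen import make_config
from omegaconf import MISSING, DictConfig
from pytorch_lightning import seed_everything
from dnadiffusion.configs import LightningTrainer, sample
Config = make_config(
    hydra_defaults=[
        "_self_",
        {"data": "LoadingData"},
        {"model": "Unet"},
    ],
    # Groups selected in hydra_defaults must be declared as fields
    data=MISSING,
    model=MISSING,
    trainer=LightningTrainer,
    sample=sample,
    # Constants
    data_dir="dna_diffusion/data",
    random_seed=42,
    ckpt_path=None,  # set to a checkpoint path to resume training
)
# Register the top-level config so @hydra.main can find it by name
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
def train(config):
    seed_everything(config.random_seed)
    data = instantiate(config.data)
    # The sampling callback needs access to the data module
    sampling_callback = instantiate(config.sample, data_module=data)
    model = instantiate(config.model)
    trainer = instantiate(config.trainer)
    # Adding custom callbacks
    trainer.callbacks.append(sampling_callback)
    # ckpt_path=None starts training from scratch
    trainer.fit(model, datamodule=data, ckpt_path=config.ckpt_path)
    return model
@hydra.main(config_path=None, config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    return train(cfg)
if __name__ == "__main__":
    main()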

A second hypothetical example covers sampling:

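import hydra
import torch
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from hydra_zen import make_config
from omegaconf import MISSING, DictConfig
Config = make_config(
    hydra_defaults=[
        "_self_",
        {"model": "Unet"},
    ],
    model=MISSING,
    args=MISSING,  # sampling arguments passed to model.sample (hypothetical field)
    # Constants
    random_seed=42,
    ckpt_path=MISSING,  # a trained checkpoint is required for sampling
)
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
def sample(config):
    torch.manual_seed(config.random_seed)
    # Rebuild the architecture, then restore the trained weights
    # (assumes config.model builds the same LightningModule that was trained)
    model = instantiate(config.model)
    checkpoint = torch.load(config.ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    with torch.no_grad():
        dna_samples = model.sample(config.args)
    return dna_samples
@hydra.main(config_path=None, config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    return sample(cfg)
if __name__ == "__main__":
    main()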

Functional Requirements

Architecture and Design

TODO finish after final refactoring is in place.

Data
  - LoadingData: handles loading the data and generating FASTA files and motifs.
  - LoadingDataModule: exposes the loaded data in PyTorch Lightning DataModule format.

Metrics
  - sampling_metrics.py: functions (metrics) used to judge the sampling process, such as KL divergence.
  - validation_metrics.py: functions (metrics) used to judge the generated sequences themselves; tools such as Enformer and BPNet will be used for this.

Model
  - UnetDiffusion: defines the diffusion model built around the U-Net, with all bells and whistles (scheduler, etc.).
  - Unet: defines the bare-bones U-Net model.
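sampling_metrics.py is described as holding metrics such as KL divergence. The following is a minimal sketch, assuming the metric compares how often each motif occurs in training versus generated sequences, with the training distribution as the reference P. The function name motif_kl_divergence, the motif-count inputs and the smoothing constant eps are assumptions, not the module's actual API.

```python
import math
from collections import Counter


def motif_kl_divergence(train_counts: Counter, generated_counts: Counter, eps: float = 1e-9) -> float:
    """Hypothetical metric: KL(P_train || P_generated) over the union of observed motifs.

    eps is added to every count so that motifs missing from one side keep the log finite.
    """
    motifs = set(train_counts) | set(generated_counts)
    train_total = sum(train_counts[m] + eps for m in motifs)
    generated_total = sum(generated_counts[m] + eps for m in motifs)
    kl = 0.0
    for m in motifs:
        p = (train_counts[m] + eps) / train_total
        q = (generated_counts[m] + eps) / generated_total
        kl += p * math.log(p / q)
    return kl
```

A value of 0 means the generated sequences reproduce the training motif frequencies exactly; because KL divergence is asymmetric, swapping the arguments generally gives a different number.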

Developer

The library is packaged with Hatch. Developer usage is documented in README.md.

Commands to package and publish the software are wrapped as Makefile targets and can be run with make.

User

Here is a short hypothetical example of conditional generation, i.e. generation guided by a text prompt such as:

"A sequence that will correspond to open (or closed) chromatin in cell type X"

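import hydra
import torch
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from hydra_zen import make_config
from omegaconf import MISSING, DictConfig
TEXT_PROMPT = "A sequence that will correspond to open (or closed) chromatin in cell type X"
Config = make_config(
    hydra_defaults=[
        "_self_",
        {"model": "Unet"},
    ],
    model=MISSING,
    args=MISSING,  # sampling arguments passed to model.sample (hypothetical field)
    # Constants
    random_seed=42,
    ckpt_path=MISSING,  # a trained checkpoint is required for sampling
)
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
def sample(config):
    torch.manual_seed(config.random_seed)
    # Rebuild the architecture, then restore the trained weights
    model = instantiate(config.model)
    checkpoint = torch.load(config.ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    # Conditional sampling based on the text prompt
    with torch.no_grad():
        dna_samples = model.sample(config.args, TEXT_PROMPT)
    return dna_samples
@hydra.main(config_path=None, config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    return sample(cfg)
if __name__ == "__main__":
    main()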

Non-functional Requirements

  1. Ethics and Security: The system should protect sensitive data and prevent unauthorized access or tampering. It should also be able to detect and respond to security threats or attacks.

  2. Usability: The system should be easy to use and intuitive for users, including data scientists, developers, and end-users. It should have a user-friendly interface and provide clear feedback on errors and warnings.

  3. Maintainability: The system should be easy to maintain and update over time, including managing data, updating the model, and fixing bugs. It should also be compatible with existing infrastructure and tools.

  4. Performance: The system should be able to process data and generate predictions in a timely manner, meeting specific performance requirements or benchmarks. This includes factors like throughput, latency, and response time.

  5. Scalability: The system should be able to handle large volumes of data and users, and be able to scale up or down as needed. This includes considerations such as system capacity, access to the HPC cluster and data storage.

Data Model

The data described below is one-hot encoded before it is passed to the network (the UNET, or the UNET combined with a VQ-VAE).

Concretely, each sequence (200 bp in Wouter Meuleman's dataset) is one-hot encoded nucleotide by nucleotide: an input of length 200 becomes a 200x4 array, with one column per nucleotide.
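The encoding step can be sketched as follows. The function name one_hot_encode and the column order A, C, G, T are assumptions for illustration; the project's own encoder may order the columns differently or handle ambiguous bases such as N.

```python
import numpy as np

NUCLEOTIDES = "ACGT"  # assumed column order
BASE_TO_INDEX = {base: i for i, base in enumerate(NUCLEOTIDES)}


def one_hot_encode(seq: str) -> np.ndarray:
    """Hypothetical encoder: map a DNA string to a (len(seq), 4) float32 array.

    Each row has a single 1.0 in the column of its nucleotide.
    Characters outside ACGT raise a KeyError.
    """
    encoded = np.zeros((len(seq), len(NUCLEOTIDES)), dtype=np.float32)
    for position, base in enumerate(seq.upper()):
        encoded[position, BASE_TO_INDEX[base]] = 1.0
    return encoded
```

For a 200 bp sequence this yields a (200, 4) array. PyTorch 1D convolutions expect channels first, so a model may consume the transposed (4, 200) view instead.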

As stated in the README: Chromatin (DNA + associated proteins) that is actively used for the regulation of genes (i.e. "regulatory elements") is typically accessible to DNA-binding proteins such as transcription factors (review, relevant paper). Through the use of a technique called DNase-seq, we've measured which parts of the genome are accessible across 733 human biosamples encompassing 438 cell and tissue types and states, resulting in more than 3.5 million DNase Hypersensitive Sites (DHSs). Using Non-Negative Matrix Factorization, we've summarized these data into 16 components, each corresponding to a different cellular context (e.g. 'cardiac', 'neural', 'lymphoid').
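To make the factorization step concrete, here is a sketch of Non-Negative Matrix Factorization using Lee and Seung's multiplicative updates on a toy matrix whose rows stand for DHSs and whose columns stand for biosamples. The toy data, the matrix orientation, the iteration count and the rule of labelling each DHS by its largest-loading component are illustrative assumptions, not the pipeline that produced the 16 components.

```python
import numpy as np


def nmf(X: np.ndarray, n_components: int, n_iter: int = 200, seed: int = 0, eps: float = 1e-10):
    """Factorize non-negative X (n_dhs x n_biosamples) as W @ H with W, H >= 0.

    Multiplicative updates keep both factors non-negative and do not increase
    the Frobenius reconstruction error ||X - W @ H||.
    """
    rng = np.random.default_rng(seed)
    n_rows, n_cols = X.shape
    W = rng.random((n_rows, n_components))
    H = rng.random((n_components, n_cols))
    for _ in range(n_iter):
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ H @ H.T + eps)
    return W, H


# Toy data: 300 "DHSs" across 40 "biosamples", random and for illustration only.
rng = np.random.default_rng(1)
X = rng.random((300, 4)) @ rng.random((4, 40))
W, H = nmf(X, n_components=4)
# Hypothetical grouping rule: label each DHS by its dominant component.
dominant_component = W.argmax(axis=1)
```

Each row of H describes one component's activity profile across biosamples (a cellular context), and each row of W says how strongly a DHS loads on each component.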

For the efforts described in this proposal, and as part of an earlier ongoing project in the research group of Wouter Meuleman, we've put together smaller subsets of these data that can be used to train models to generate synthetic sequences for each NMF component.

Please find these data, along with a data dictionary, here.

External interfaces

Hosted and exposed on Hugging Face.

Project structure

DNA-Diffusion
├─ .editorconfig
├─ .gitignore
├─ CITATION.cff
├─ CODE_OF_CONDUCT.md
├─ dockerfiles
│  └─ Dockerfile
├─ docs
│  ├─ contributors.md
│  ├─ images
│  │  ├─ diff_first.gif
│  │  └─ diff_first_lossy.gif
│  ├─ index.md
│  ├─ reference
│  │  └─ dnadiffusion.md
│  └─ specification.md
├─ environments
│  ├─ cluster
│  │  ├─ create_conda.sh
│  │  ├─ dnadiffusion_run.sh
│  │  ├─ install_mambaforge.sh
│  │  ├─ README.md
│  │  ├─ slurm_interactive.sh
│  │  └─ test_path.sh
│  └─ conda
│     └─ environment.yml
├─ LICENSE.md
├─ Makefile
├─ mkdocs.yml
├─ notebooks
│  ├─ experiments
│  │  ├─ conditional_diffusion
│  │  │  ├─ accelerate_diffusion_conditional_4_cells.ipynb
│  │  │  ├─ dna_diff_baseline_conditional_UNET.ipynb
│  │  │  ├─ dna_diff_baseline_conditional_UNET_with_time_warping.ipynb
│  │  │  ├─ easy_training_Conditional_Code_to_refactor_UNET_ANNOTATED_v4 (2).ipynb
│  │  │  ├─ full_script_version_from_accelerate_notebook
│  │  │  │  ├─ dnadiffusion.py
│  │  │  │  ├─ filter_data.ipynb
│  │  │  │  ├─ master_dataset.ipynb
│  │  │  │  └─ README.MD
│  │  │  ├─ previous_version
│  │  │  │  └─ Conditional_Code_to_refactor_UNET_ANNOTATED_v3 (2).ipynb
│  │  │  ├─ README.MD
│  │  │  ├─ vq_vae_accelerate_diffusion_conditional_4_cells.ipynb
│  │  │  └─ VQ_VAE_LATENT_SPACE_WITH_METRICS.ipynb
│  │  └─ README.md
│  ├─ README.md
│  ├─ refactoring
│  │  └─ README.md
│  └─ tutorials
│     └─ README.md
├─ pyproject.toml
├─ README.md
├─ src
│  ├─ dnadiffusion
│  │  ├─ callbacks
│  │  │  ├─ ema.py
│  │  │  └─ sampling.py
│  │  ├─ cli
│  │  │  └─ __init__.py
│  │  ├─ configs.py
│  │  ├─ data
│  │  │  ├─ dataloader.py
│  │  │  ├─ encode_data.npy
│  │  │  ├─ encode_data.pkl
│  │  │  ├─ K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt
│  │  │  ├─ model4cells_train_split_3_50_dims.pkl
│  │  │  ├─ README.md
│  │  │  └─ __init__.py
│  │  ├─ losses
│  │  │  ├─ README.md
│  │  │  └─ __init__.py
│  │  ├─ metrics
│  │  │  ├─ README.md
│  │  │  ├─ sampling_metrics.py
│  │  │  ├─ validation_metrics.py
│  │  │  └─ __init__.py
│  │  ├─ models
│  │  │  ├─ diffusion.py
│  │  │  ├─ modules.py
│  │  │  ├─ networks.py
│  │  │  ├─ README.md
│  │  │  ├─ training_modules.py
│  │  │  ├─ unet.py
│  │  │  └─ __init__.py
│  │  ├─ README.md
│  │  ├─ sample.py
│  │  ├─ trainer.py
│  │  ├─ utils
│  │  │  ├─ ema.py
│  │  │  ├─ README.md
│  │  │  ├─ scheduler.py
│  │  │  ├─ utils.py
│  │  │  └─ __init__.py
│  │  ├─ __about__.py
│  │  ├─ __init__.py
│  │  └─ __main__.py
│  └─ refactor
│     ├─ config.py
│     ├─ configs
│     │  ├─ callbacks
│     │  │  └─ default.yaml
│     │  ├─ data
│     │  │  ├─ sequence.yaml
│     │  │  └─ vanilla_sequences.yaml
│     │  ├─ logger
│     │  │  └─ wandb.yaml
│     │  ├─ main.yaml
│     │  ├─ model
│     │  │  ├─ dnaddpmdiffusion.yaml
│     │  │  ├─ dnadiffusion.yaml
│     │  │  ├─ lr_scheduler
│     │  │  │  └─ MultiStepLR.yaml
│     │  │  ├─ optimizer
│     │  │  │  └─ adam.yaml
│     │  │  └─ unet
│     │  │     ├─ unet.yaml
│     │  │     └─ unet_conditional.yaml
│     │  ├─ paths
│     │  │  └─ default.yaml
│     │  └─ trainer
│     │     ├─ ddp.yaml
│     │     └─ default.yaml
│     ├─ data
│     │  ├─ sequence_dataloader.py
│     │  └─ sequence_datamodule.py
│     ├─ main.py
│     ├─ models
│     │  ├─ diffusion
│     │  │  ├─ ddpm.py
│     │  │  └─ diffusion.py
│     │  ├─ encoders
│     │  │  └─ vqvae.py
│     │  └─ networks
│     │     ├─ unet_lucas.py
│     │     └─ unet_lucas_cond.py
│     ├─ README.md
│     ├─ sample.py
│     ├─ tests
│     │  ├─ data
│     │  │  └─ test_sequence_dataloader.py
│     │  ├─ models
│     │  │  ├─ diffusion
│     │  │  │  ├─ test_ddim.py
│     │  │  │  └─ test_ddpm.py
│     │  │  ├─ encoders
│     │  │  │  └─ test_vqvae.py
│     │  │  └─ networks
│     │  │     ├─ test_unet_bitdiffusion.py
│     │  │     └─ test_unet_lucas.py
│     │  ├─ utils
│     │  │  ├─ test_ema.py
│     │  │  ├─ test_misc.py
│     │  │  ├─ test_network.py
│     │  │  └─ test_schedules.py
│     │  └─ __init__.py
│     └─ utils
│        ├─ ema.py
│        ├─ metrics.py
│        ├─ misc.py
│        ├─ network.py
│        └─ schedules.py
├─ tests
│  ├─ conftest.py
│  ├─ test_add.py
│  ├─ test_main.py
│  └─ __init__.py
├─ test_environment.py
└─ train.py

Testing

dnadiffusion will be tested using the pytest framework.

A Makefile target runs the full test suite.

Deployment and Maintenance

dnadiffusion will be distributed as a Python package that can be installed and run on any system with Python 3.10 or later.

Makefile targets build the package and distribute it.