cell2net.prediction.data.get_dataloader#
- cell2net.prediction.data.get_dataloader(mdata, rna_mod='rna', atac_mod='atac', with_variants=False, idx=None, covariates=None, batch_size=128, num_workers=4, pin_memory=True, shuffle=True, drop_last=True, persistent_workers=True, **kwargs)#
Creates a PyTorch DataLoader from a MuData object.
This function converts a MuData object into a PyTorch DataLoader for training or evaluation. It allows customization of data loading parameters, such as batch size, shuffling, and number of workers.
- Parameters:
  - mdata (MuData) – A MuData object containing multimodal data.
  - rna_mod (str, default: 'rna') – The name of the RNA modality in the MuData object.
  - atac_mod (str, default: 'atac') – The name of the ATAC modality in the MuData object.
  - idx (Sequence[int] | Sequence[str] | None, default: None) – Indices or keys to subset the MuData object. If None, the entire dataset is used.
  - covariates (Sequence[str] | None, default: None) – Covariates to include from the MuData object.
  - batch_size (int, default: 128) – The number of samples per batch.
  - num_workers (int, default: 4) – The number of worker processes for data loading.
  - pin_memory (bool, default: True) – Whether to pin memory in the DataLoader for faster GPU transfers.
  - shuffle (bool, default: True) – Whether to shuffle the dataset.
  - drop_last (bool, default: True) – Whether to drop the last incomplete batch.
  - persistent_workers (bool, default: True) – Whether to keep data loading workers alive between epochs.
  - **kwargs – Additional keyword arguments passed to torch.utils.data.DataLoader.
- Return type:
  DataLoader
- Returns:
  A PyTorch DataLoader for the specified MuData dataset.
Examples

>>> from mudata import read_h5mu
>>> from cell2net.prediction.data import get_dataloader
>>> mdata = read_h5mu("data.h5mu")
>>> dataloader = get_dataloader(mdata, batch_size=32, shuffle=True)
>>> for batch in dataloader:
...     print(batch)
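The `batch_size` and `drop_last` parameters above jointly determine how many batches one epoch yields. As a quick sanity check, the count can be computed with plain Python (a minimal sketch independent of cell2net; the helper `num_batches` is illustrative, not part of the API):

```python
import math

def num_batches(n_obs: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader yields for a dataset of n_obs samples."""
    if drop_last:
        # An incomplete final batch is discarded.
        return n_obs // batch_size
    return math.ceil(n_obs / batch_size)

# With the defaults (batch_size=128, drop_last=True), 1000 observations
# yield 7 full batches; the trailing 104 samples are dropped each epoch.
print(num_batches(1000, 128, drop_last=True))   # 7
print(num_batches(1000, 128, drop_last=False))  # 8
```

With `drop_last=True` (the default), the silently dropped remainder can matter for small datasets, so consider `drop_last=False` during evaluation.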