Module flowcon.datasets.base
Based on https://github.com/bayesiains/nsf
Functions
def batch_generator(loader, num_batches=10000000000)
def load_num_batches(loader, num_batches)
A generator that returns num_batches batches from the loader, irrespective of the length of the dataset.
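The cycling behaviour can be sketched in pure Python (a plain list stands in for the DataLoader here; this is an illustration of the documented behaviour, not the package source):

```python
def load_num_batches(loader, num_batches):
    # Yield exactly num_batches batches, restarting iteration over the
    # loader whenever it is exhausted, so the dataset length never limits
    # how many batches are produced.
    counter = 0
    while counter < num_batches:
        for batch in loader:
            if counter == num_batches:
                return
            yield batch
            counter += 1

loader = [[1, 2], [3, 4]]  # stand-in for a torch DataLoader
batches = list(load_num_batches(loader, 5))
# batches == [[1, 2], [3, 4], [1, 2], [3, 4], [1, 2]]
```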
def load_plane_dataset(name, num_points, flip_axes=False, return_label=False)-
Loads and returns a plane dataset.
Args
name: string, the name of the dataset.
num_points: int, the number of points the dataset should have.
flip_axes: bool, flip x and y axes if True.
Possible names: 'gaussian', 'crescent', 'crescent_cubed', 'sine_wave', 'abs', 'sign', 'four_circles', 'diamond', 'two_spirals', 'checkerboard', 'eight_gaussians', 'two_circles', 'two_moons', 'pinwheel', 'swissroll', 'rings'.
Returns
A Dataset object, the requested dataset.
Raises
ValueError: If name is an unknown dataset.
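The name-based dispatch and the ValueError on unknown names can be sketched as follows. The toy samplers (`_gaussian`, `_sine_wave`) and the `_PLANE_DATASETS` table are hypothetical stand-ins for the package's real datasets, kept minimal for illustration:

```python
import math
import random

# Hypothetical toy samplers standing in for the package's real datasets.
def _gaussian(num_points):
    return [(random.gauss(0.0, 1.0), random.gauss(0.0, 1.0))
            for _ in range(num_points)]

def _sine_wave(num_points):
    xs = [random.uniform(-4.0, 4.0) for _ in range(num_points)]
    return [(x, math.sin(x) + random.gauss(0.0, 0.1)) for x in xs]

_PLANE_DATASETS = {"gaussian": _gaussian, "sine_wave": _sine_wave}

def load_plane_dataset(name, num_points, flip_axes=False):
    # Dispatch on the dataset name; unknown names raise ValueError,
    # mirroring the documented behaviour.
    try:
        points = _PLANE_DATASETS[name](num_points)
    except KeyError:
        raise ValueError(f"Unknown dataset: {name}") from None
    if flip_axes:
        points = [(y, x) for (x, y) in points]
    return points
```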
Classes
class InfiniteLoader (num_epochs=None, *args, **kwargs)
A data loader that can load a dataset repeatedly.
Constructor.
Args
dataset: A Dataset object to be loaded.
batch_size: int, the size of each batch.
shuffle: bool, whether to shuffle the dataset after each epoch.
drop_last: bool, whether to drop last batch if its size is less than batch_size.
num_epochs: int or None, number of epochs to iterate over the dataset. If None, defaults to infinity.
class InfiniteLoader(data.DataLoader):
    """A data loader that can load a dataset repeatedly."""

    def __init__(self, num_epochs=None, *args, **kwargs):
        """Constructor.

        Args:
            dataset: A `Dataset` object to be loaded.
            batch_size: int, the size of each batch.
            shuffle: bool, whether to shuffle the dataset after each epoch.
            drop_last: bool, whether to drop last batch if its size is less
                than `batch_size`.
            num_epochs: int or None, number of epochs to iterate over the
                dataset. If None, defaults to infinity.
        """
        super().__init__(*args, **kwargs)
        self.finite_iterable = super().__iter__()
        self.counter = 0
        self.num_epochs = float('inf') if num_epochs is None else num_epochs

    def __next__(self):
        try:
            return next(self.finite_iterable)
        except StopIteration:
            self.counter += 1
            if self.counter >= self.num_epochs:
                raise StopIteration
            self.finite_iterable = super().__iter__()
            return next(self.finite_iterable)

    def __iter__(self):
        return self

    def __len__(self):
        return None

Ancestors
- torch.utils.data.dataloader.DataLoader
- typing.Generic
Class variables
var batch_size : Optional[int]
var dataset : torch.utils.data.dataset.Dataset[+T_co]
var drop_last : bool
var num_workers : int
var pin_memory : bool
var pin_memory_device : str
var prefetch_factor : Optional[int]
var sampler : Union[torch.utils.data.sampler.Sampler, Iterable[+T_co]]
var timeout : float
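The restart-on-StopIteration pattern in InfiniteLoader.__next__ can be illustrated without torch. InfiniteIterable below is a hypothetical pure-Python analogue that wraps any iterable and mirrors the same counter and num_epochs logic:

```python
from itertools import islice

class InfiniteIterable:
    """Sketch of the InfiniteLoader pattern: re-create the underlying
    iterator each time it is exhausted, up to num_epochs passes."""

    def __init__(self, iterable, num_epochs=None):
        self.iterable = iterable
        self.num_epochs = float('inf') if num_epochs is None else num_epochs
        self.counter = 0
        self.finite_iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.finite_iterator)
        except StopIteration:
            self.counter += 1
            if self.counter >= self.num_epochs:
                raise
            # Restart a fresh pass over the underlying iterable.
            self.finite_iterator = iter(self.iterable)
            return next(self.finite_iterator)

# Two full epochs over a three-element "dataset".
two_epochs = list(InfiniteIterable([1, 2, 3], num_epochs=2))
# two_epochs == [1, 2, 3, 1, 2, 3]

# With num_epochs=None the loader never stops, so take a finite slice.
first_five = list(islice(InfiniteIterable([1, 2], num_epochs=None), 5))
# first_five == [1, 2, 1, 2, 1]
```

Note that the real InfiniteLoader restarts the parent DataLoader's iterator (via super().__iter__()), which is what re-shuffles the data each epoch when shuffle=True.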