Module flowcon.datasets.base
Based on https://github.com/bayesiains/nsf
Functions
def batch_generator(loader, num_batches=10000000000)
def load_num_batches(loader, num_batches)
-
A generator that returns num_batches batches from the loader, irrespective of the length of the dataset.
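The behaviour described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the module's source: the name take_batches is hypothetical, and a plain list stands in for a data loader.

```python
def take_batches(loader, num_batches):
    """Yield exactly num_batches batches, restarting the loader when it runs out.

    Hypothetical stand-in for the behaviour described above. `loader` may be
    any re-iterable object, for example a list or a torch DataLoader.
    """
    if num_batches <= 0:
        return
    batch_counter = 0
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
            batch_counter += 1
            if batch_counter == num_batches:
                return
        if not yielded_any:
            # Guard added for this sketch: an empty loader would loop forever.
            return


# A "loader" with three batches, asked for seven batches in total.
batches = list(take_batches(["a", "b", "c"], 7))
print(batches)
```

Because the count is in batches rather than epochs, the last pass over the loader may stop partway through.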
def load_plane_dataset(name, num_points, flip_axes=False, return_label=False)
-
Loads and returns a plane dataset.
Args
name
- string, the name of the dataset.
num_points
- int, the number of points the dataset should have.
flip_axes
- bool, flip x and y axes if True.
Possible names: 'gaussian', 'crescent', 'crescent_cubed', 'sine_wave', 'abs', 'sign', 'four_circles', 'diamond', 'two_spirals', 'checkerboard', 'eight_gaussians', 'two_circles', 'two_moons', 'pinwheel', 'swissroll', 'rings'
Returns
A Dataset object, the requested dataset.
Raises
ValueError
- If name is an unknown dataset.
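The contract above (dispatch on name, optional axis flip, ValueError for unknown names) can be illustrated with a small standalone sketch. Everything here is an assumption made for illustration: the function make_plane_points, its seed parameter, and the two toy parametrisations are hypothetical and are not taken from flowcon, which returns a Dataset object rather than a list of tuples.

```python
import math
import random


def make_plane_points(name, num_points, flip_axes=False, seed=0):
    """Return num_points (x, y) tuples for a toy 2D distribution (hypothetical)."""
    rng = random.Random(seed)
    if name == "gaussian":
        # Standard normal in both coordinates.
        points = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))
                  for _ in range(num_points)]
    elif name == "sine_wave":
        # x uniform on [-pi, pi], y = sin(x) plus small noise (illustrative only).
        points = []
        for _ in range(num_points):
            x = rng.uniform(-math.pi, math.pi)
            points.append((x, math.sin(x) + rng.gauss(0.0, 0.1)))
    else:
        raise ValueError(f"Unknown dataset: {name!r}")
    if flip_axes:
        # Swap the x and y coordinates of every point.
        points = [(y, x) for x, y in points]
    return points


plain = make_plane_points("sine_wave", 100)
flipped = make_plane_points("sine_wave", 100, flip_axes=True)
```

With the same seed, flipped holds the same points as plain with the two coordinates swapped, and an unrecognised name raises ValueError before any points are generated.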
Classes
class InfiniteLoader (num_epochs=None, *args, **kwargs)
-
A data loader that can load a dataset repeatedly.
Constructor.
Args
dataset
- A Dataset object to be loaded.
batch_size
- int, the size of each batch.
shuffle
- bool, whether to shuffle the dataset after each epoch.
drop_last
- bool, whether to drop last batch if its size is less than batch_size.
num_epochs
- int or None, number of epochs to iterate over the dataset. If None, defaults to infinity.
class InfiniteLoader(data.DataLoader):
    """A data loader that can load a dataset repeatedly."""

    def __init__(self, num_epochs=None, *args, **kwargs):
        """Constructor.

        Args:
            dataset: A `Dataset` object to be loaded.
            batch_size: int, the size of each batch.
            shuffle: bool, whether to shuffle the dataset after each epoch.
            drop_last: bool, whether to drop last batch if its size is less than
                `batch_size`.
            num_epochs: int or None, number of epochs to iterate over the dataset.
                If None, defaults to infinity.
        """
        super().__init__(
            *args, **kwargs
        )
        self.finite_iterable = super().__iter__()
        self.counter = 0
        self.num_epochs = float('inf') if num_epochs is None else num_epochs

    def __next__(self):
        try:
            return next(self.finite_iterable)
        except StopIteration:
            self.counter += 1
            if self.counter >= self.num_epochs:
                raise StopIteration
            self.finite_iterable = super().__iter__()
            return next(self.finite_iterable)

    def __iter__(self):
        return self

    def __len__(self):
        return None
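The restart-on-exhaustion pattern used by InfiniteLoader can be tried out without PyTorch. The sketch below mirrors the counter logic of the class above on an ordinary iterable. The class name RepeatingIterator is hypothetical and is not part of flowcon.

```python
import itertools


class RepeatingIterator:
    """Iterate over a re-iterable object for num_epochs passes (forever if None)."""

    def __init__(self, iterable, num_epochs=None):
        self.iterable = iterable
        self.finite_iterator = iter(iterable)
        self.counter = 0  # number of completed passes
        self.num_epochs = float("inf") if num_epochs is None else num_epochs

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.finite_iterator)
        except StopIteration:
            # One pass finished: stop if the epoch budget is used up, else restart.
            self.counter += 1
            if self.counter >= self.num_epochs:
                raise
            self.finite_iterator = iter(self.iterable)
            return next(self.finite_iterator)


# Two full passes over a three-element "dataset".
two_epochs = list(RepeatingIterator([1, 2, 3], num_epochs=2))

# With num_epochs=None the iterator never stops, so take a finite slice.
first_five = list(itertools.islice(RepeatingIterator([1, 2]), 5))
print(two_epochs, first_five)
```

One consequence of the InfiniteLoader signature shown above is that num_epochs is the first positional parameter, so a dataset passed positionally as the first argument would be bound to num_epochs. Passing dataset, batch_size and the other DataLoader options as keyword arguments avoids this.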
Ancestors
- torch.utils.data.dataloader.DataLoader
- typing.Generic
Class variables
var batch_size : Optional[int]
var dataset : torch.utils.data.dataset.Dataset[+T_co]
var drop_last : bool
var num_workers : int
var pin_memory : bool
var pin_memory_device : str
var prefetch_factor : Optional[int]
var sampler : Union[torch.utils.data.sampler.Sampler, Iterable[+T_co]]
var timeout : float