Module flowcon.datasets.base

Based on https://github.com/bayesiains/nsf

Functions

def batch_generator(loader, num_batches=10000000000)
def load_num_batches(loader, num_batches)

A generator that returns num_batches batches from the loader, irrespective of the length of the dataset.
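The cycling behavior described above can be sketched in plain Python. Here `load_num_batches_sketch` is a hypothetical stand-in, not the library function, and an ordinary list of batches replaces the torch `DataLoader`:

```python
from itertools import cycle, islice

def load_num_batches_sketch(loader, num_batches):
    # Cycle through the loader's batches until exactly num_batches
    # have been yielded, regardless of how many batches one pass holds.
    return islice(cycle(loader), num_batches)

# Two "batches" stand in for a torch DataLoader.
batches = list(load_num_batches_sketch([[1, 2], [3, 4]], 5))
# batches is [[1, 2], [3, 4], [1, 2], [3, 4], [1, 2]]
```

Note that `itertools.cycle` caches the first pass, so with a real shuffling `DataLoader` the library instead re-iterates the loader to get a fresh shuffle each epoch.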

def load_plane_dataset(name, num_points, flip_axes=False, return_label=False)

Loads and returns a plane dataset.

Args

name
string, the name of the dataset.
num_points
int, the number of points the dataset should have.
flip_axes
bool, flip x and y axes if True.

Possible names: 'gaussian', 'crescent', 'crescent_cubed', 'sine_wave', 'abs', 'sign', 'four_circles', 'diamond', 'two_spirals', 'checkerboard', 'eight_gaussians', 'two_circles', 'two_moons', 'pinwheel', 'swissroll', 'rings'.

Returns

A Dataset object, the requested dataset.

Raises

ValueError
If name is an unknown dataset.
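A minimal sketch, assuming a name-dispatch pattern, of how the ValueError above might arise; the point generators below are placeholders, not the real Dataset classes from flowcon:

```python
def load_plane_dataset_sketch(name, num_points, flip_axes=False):
    # Placeholder generators; the real function constructs Dataset objects.
    makers = {
        "gaussian": lambda n: [(0.0, 0.0)] * n,
        "checkerboard": lambda n: [(0.5, 0.25)] * n,
    }
    try:
        points = makers[name](num_points)
    except KeyError:
        raise ValueError(f"Unknown dataset name: {name}")
    if flip_axes:
        # Swap x and y for each point.
        points = [(y, x) for x, y in points]
    return points

pts = load_plane_dataset_sketch("checkerboard", 3, flip_axes=True)
# pts is [(0.25, 0.5), (0.25, 0.5), (0.25, 0.5)]
```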

Classes

class InfiniteLoader (num_epochs=None, *args, **kwargs)

A data loader that can load a dataset repeatedly.

Constructor.

Args

dataset
A Dataset object to be loaded.
batch_size
int, the size of each batch.
shuffle
bool, whether to shuffle the dataset after each epoch.
drop_last
bool, whether to drop last batch if its size is less than batch_size.
num_epochs
int or None, number of epochs to iterate over the dataset. If None, defaults to infinity.
class InfiniteLoader(data.DataLoader):
    """A data loader that can load a dataset repeatedly."""

    def __init__(self, num_epochs=None, *args, **kwargs):
        """Constructor.

        Args:
            dataset: A `Dataset` object to be loaded.
            batch_size: int, the size of each batch.
            shuffle: bool, whether to shuffle the dataset after each epoch.
            drop_last: bool, whether to drop last batch if its size is less than
                `batch_size`.
            num_epochs: int or None, number of epochs to iterate over the dataset.
                If None, defaults to infinity.
        """
        super().__init__(
            *args, **kwargs
        )
        self.finite_iterable = super().__iter__()
        self.counter = 0
        self.num_epochs = float('inf') if num_epochs is None else num_epochs

    def __next__(self):
        try:
            return next(self.finite_iterable)
        except StopIteration:
            self.counter += 1
            if self.counter >= self.num_epochs:
                raise StopIteration
            self.finite_iterable = super().__iter__()
            return next(self.finite_iterable)

    def __iter__(self):
        return self

    def __len__(self):
        # An infinite loader has no meaningful length; calling the
        # built-in len() on this object raises TypeError because
        # __len__ does not return an int.
        return None

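The epoch-counting logic in `__next__` above can be exercised without torch. This stand-in (a hypothetical class, not part of the library) replaces the `DataLoader` machinery with a plain list but keeps the same counter and restart behavior:

```python
class InfiniteIterSketch:
    """Mirrors InfiniteLoader's __next__: restart the underlying
    iterable after each pass, up to num_epochs passes."""

    def __init__(self, data, num_epochs=None):
        self.data = data
        self.num_epochs = float("inf") if num_epochs is None else num_epochs
        self.counter = 0
        self.finite_iterable = iter(data)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.finite_iterable)
        except StopIteration:
            self.counter += 1
            if self.counter >= self.num_epochs:
                raise  # limit reached: stop for good
            # Start a fresh pass over the data.
            self.finite_iterable = iter(self.data)
            return next(self.finite_iterable)

items = list(InfiniteIterSketch([1, 2, 3], num_epochs=2))
# items is [1, 2, 3, 1, 2, 3]
```

With `num_epochs=None` the loop never terminates on its own, which is why `InfiniteLoader` is typically consumed through `load_num_batches` or a bounded training loop rather than exhausted.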
Ancestors

  • torch.utils.data.dataloader.DataLoader
  • typing.Generic

Class variables

var batch_size : Optional[int]
var dataset : torch.utils.data.dataset.Dataset[+T_co]
var drop_last : bool
var num_workers : int
var pin_memory : bool
var pin_memory_device : str
var prefetch_factor : Optional[int]
var sampler : Union[torch.utils.data.sampler.Sampler, Iterable[+T_co]]
var timeout : float