inferno.utils package
Submodules
inferno.utils.exceptions module
Exceptions and Error Handling
inferno.utils.io_utils module
inferno.utils.io_utils.fromh5(path, datapath=None, dataslice=None, asnumpy=True, preptrain=None)
Opens an HDF5 file at path, loads the dataset at datapath, and returns the dataset as a NumPy array.
inferno.utils.io_utils.print_tensor(tensor, prefix, directory)
Prints an image or volume tensor to file as images.
inferno.utils.math_utils module
inferno.utils.math_utils.max_allowed_ds_steps(shape, factor)
Determines how often a shape can be downsampled by a given factor such that none of the divisions gives a non-integer.
Parameters:
- shape (listlike) – tensor shape
- factor (integer) – downsample factor
Returns: maximum allowed number of downsample operations
Return type: int
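The rule can be illustrated with a minimal pure-Python sketch. It assumes that a step is allowed only if every dimension stays an integer of at least 1 after the division, so the whole shape is limited by its least divisible dimension. The name `max_ds_steps_sketch` is hypothetical; this is not inferno's source.

```python
from typing import Sequence


def max_ds_steps_sketch(shape: Sequence[int], factor: int) -> int:
    """Count how often every entry of `shape` can be divided by `factor`
    while each result stays an integer >= 1 (assumed semantics)."""
    if factor < 2:
        # A factor of 0 or 1 would never terminate or is meaningless here.
        raise ValueError("factor must be an integer >= 2")

    def steps_for(size: int) -> int:
        steps = 0
        # Keep dividing while the division is exact and the result is >= 1.
        while size % factor == 0 and size // factor >= 1:
            size //= factor
            steps += 1
        return steps

    # The shape as a whole is limited by its least divisible dimension.
    return min(steps_for(s) for s in shape)


print(max_ds_steps_sketch((64, 48), 2))  # 48 -> 24 -> 12 -> 6 -> 3, so 4
```

For a 2D U-Net-style encoder with pooling factor 2, this kind of check tells you how many pooling levels an input of a given size can pass through without fractional feature-map sizes.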
inferno.utils.model_utils module
class inferno.utils.model_utils.ModelTester(input_shape, expected_output_shape)
Bases: object
inferno.utils.python_utils module
Utility functions with no external dependencies.
class inferno.utils.python_utils.delayed_keyboard_interrupt
Bases: object
Delays SIGINT while critical code runs. Borrowed from: https://stackoverflow.com/questions/842557/how-to-prevent-a-block-of-code-from-being-interrupted-by-keyboardinterrupt-in-py
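The Stack Overflow answer this class is borrowed from uses a context manager that temporarily swaps the SIGINT handler. A minimal sketch of that pattern follows; `DelayedKeyboardInterruptSketch` is a hypothetical name, and the sketch assumes it runs in the main thread, since Python only allows installing signal handlers there.

```python
import signal


class DelayedKeyboardInterruptSketch:
    """Defer SIGINT until the `with` block finishes (main thread only)."""

    def __enter__(self):
        self.signal_received = None
        # Install a handler that only records the signal instead of raising.
        self.old_handler = signal.signal(signal.SIGINT, self._record)
        return self

    def _record(self, signum, frame):
        self.signal_received = (signum, frame)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore whatever handler was active before the block.
        signal.signal(signal.SIGINT, self.old_handler)
        # Replay a delayed SIGINT through the original handler. With Python's
        # default handler this raises KeyboardInterrupt now, after the block.
        if self.signal_received is not None and callable(self.old_handler):
            self.old_handler(*self.signal_received)
        return False  # never suppress exceptions raised inside the block


with DelayedKeyboardInterruptSketch():
    checkpoint = {"step": 1}  # stand-in for work that must not be cut short
```

If the previous handler is not callable (for example `signal.SIG_IGN`), this sketch drops the delayed signal rather than replaying it.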
inferno.utils.python_utils.deprecated(reason)
A decorator that marks functions as deprecated. Calling a decorated function emits a warning.
Borrowed from https://stackoverflow.com/questions/2536307/decorators-in-the-python-standard-lib-deprecated-specifically by Laurent LAPORTE (https://stackoverflow.com/users/1513933/laurent-laporte).
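The idea can be sketched as a decorator factory built on the standard `warnings` module. This is a minimal sketch, not inferno's source: `deprecated_sketch` and `old_add` are hypothetical names, and using `DeprecationWarning` as the category is an assumption.

```python
import functools
import warnings


def deprecated_sketch(reason):
    """Decorator factory: wrap a function so each call warns with `reason`."""
    def decorator(func):
        @functools.wraps(func)  # keep the original name and docstring
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"Call to deprecated function {func.__name__} ({reason}).",
                category=DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator


@deprecated_sketch("use new_add instead")
def old_add(a, b):
    return a + b
```

Python hides `DeprecationWarning` by default unless it is triggered from code in `__main__`, so a filter such as `warnings.simplefilter("always", DeprecationWarning)` may be needed to see the message.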
inferno.utils.python_utils.ensure_dir(directory)
Ensures that a directory exists at the given path. If the directory does not exist, it is created.
Parameters: directory (str) – path of the directory
Returns: path of the directory
Return type: str
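A minimal sketch of this behavior with the standard library, assuming missing parent directories should be created too. `ensure_dir_sketch` is a hypothetical name, not inferno's implementation.

```python
import os


def ensure_dir_sketch(directory: str) -> str:
    """Create `directory` (and missing parents) if needed, then return it."""
    # exist_ok=True makes this a no-op when the directory already exists and
    # avoids a race between a separate existence check and the creation.
    os.makedirs(directory, exist_ok=True)
    return directory
```

Returning the path lets the call be used inline, e.g. `log_dir = ensure_dir_sketch(os.path.join(root, "logs"))`.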
inferno.utils.python_utils.require_dict_kwagrs(kwargs, msg=None)
Ensures that the value passed as kwargs is either None or a dict. If it is neither, a RuntimeError is raised.
Parameters:
- kwargs (object) – possible dict or None
- msg (None, optional) – error message
Returns: kwargs as a dict
Return type: dict
Raises: RuntimeError – if the passed value is neither a dict nor None
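A minimal sketch of this validation. It assumes that None is normalized to an empty dict and that msg, when given, becomes the error message; both are assumptions. The hypothetical name `require_dict_kwargs_sketch` is spelled conventionally, unlike the documented `require_dict_kwagrs`.

```python
def require_dict_kwargs_sketch(kwargs, msg=None):
    """Return `kwargs` as a dict; None becomes {} (assumed); else RuntimeError."""
    if kwargs is None:
        return {}
    if isinstance(kwargs, dict):
        return kwargs
    # Anything else is rejected, using the caller's message if provided.
    raise RuntimeError(msg if msg is not None
                       else "value passed as keyword-argument dict is neither None nor a dict")
```

A typical use is normalizing an optional options argument before unpacking it, e.g. `layer(**require_dict_kwargs_sketch(layer_kwargs))`.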
inferno.utils.test_utils module
inferno.utils.test_utils.generate_random_data(num_samples, shape, num_classes, hardness=0.3, dtype=None)
Generates a random dataset with a given hardness and number of classes.
inferno.utils.torch_utils module
inferno.utils.torch_utils.flatten_samples(tensor_or_variable)
Flattens a tensor or variable such that the channel axis comes first and all remaining axes (the sample axis and any spatial axes) are merged into the second axis. The shapes are transformed as follows:
- (N, C, H, W) -> (C, N * H * W)
- (N, C, D, H, W) -> (C, N * D * H * W)
- (N, C) -> (C, N)
The input must be at least 2D.
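The transformation amounts to swapping the sample and channel axes and then merging everything after the channel axis. The NumPy analogue below illustrates this; `flatten_samples_np` is a hypothetical name, and the sketch is not inferno's torch implementation.

```python
import numpy as np


def flatten_samples_np(array: np.ndarray) -> np.ndarray:
    """NumPy analogue of (N, C, ...) -> (C, N * ...) flattening."""
    if array.ndim < 2:
        raise ValueError("input must be at least 2D")
    num_channels = array.shape[1]
    # Swap the sample axis (0) and the channel axis (1): (N, C, ...) -> (C, N, ...)
    channels_first = np.swapaxes(array, 0, 1)
    # Merge all remaining axes into one; reshape copies the data when needed,
    # because the swapped view is not contiguous.
    return channels_first.reshape(num_channels, -1)


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
print(flatten_samples_np(x).shape)  # (3, 40)
```

In PyTorch, the same idea could be written as `tensor.transpose(0, 1).reshape(tensor.size(1), -1)`. Each row then holds every value of one channel, which is convenient for per-channel statistics or per-class losses.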
inferno.utils.torch_utils.unwrap(tensor_or_variable, to_cpu=True, as_numpy=False, extract_item=False)
inferno.utils.torch_utils.where(condition, if_true, if_false)
Torch equivalent of numpy.where.
Parameters:
- condition (torch.ByteTensor or torch.cuda.ByteTensor or torch.autograd.Variable) – condition to check.
- if_true (torch.Tensor or torch.cuda.Tensor or torch.autograd.Variable) – output value where condition is true.
- if_false (torch.Tensor or torch.cuda.Tensor or torch.autograd.Variable) – output value where condition is false.
Returns: elements taken from if_true where condition is true and from if_false elsewhere
Return type: torch.Tensor
Raises:
- AssertionError – if if_true and if_false are not both variables or both tensors.
- AssertionError – if if_true and if_false don't have the same datatype.
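One way to emulate `numpy.where` without a dedicated selection primitive is to blend the two inputs, using the condition as a 0/1 mask. The NumPy sketch below shows this arithmetic technique and checks it against `numpy.where`. Whether inferno uses exactly this formula is an assumption, and `masked_where` is a hypothetical name.

```python
import numpy as np


def masked_where(condition, if_true, if_false):
    """Select elementwise by blending with a 0/1 mask."""
    # Mirror the documented requirement that both inputs share a dtype.
    assert if_true.dtype == if_false.dtype, "if_true and if_false must share a dtype"
    mask = condition.astype(if_true.dtype)  # True -> 1, False -> 0
    return mask * if_true + (1 - mask) * if_false


cond = np.array([True, False, True, False])
a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([10.0, 20.0, 30.0, 40.0])
print(masked_where(cond, a, b))  # [ 1. 20.  3. 40.]
```

This blend differs from a true selection when the unselected input holds `inf` or `NaN`, because `0 * inf` is `NaN`. A native `where` does not have that problem.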
inferno.utils.train_utils module
Utilities for training.
class inferno.utils.train_utils.AverageMeter
Bases: object
Computes and stores the average and current value. Taken from https://github.com/pytorch/examples/blob/master/imagenet/main.py
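The referenced pytorch/examples ImageNet script keeps a small meter with a latest value, a running sum, a count, and their ratio. The sketch below follows that pattern. The attribute names and the `update(val, n=1)` signature reflect that style of meter and should be checked against inferno's source. `AverageMeterSketch` is a hypothetical name.

```python
class AverageMeterSketch:
    """Track the latest value and a running, count-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0  # most recent value
        self.sum = 0.0  # weighted sum of all values seen
        self.count = 0  # total weight, e.g. number of samples
        self.avg = 0.0  # sum / count

    def update(self, val, n=1):
        # `n` lets one call account for a batch of n samples with mean `val`.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeterSketch()
meter.update(2.0)        # one sample with loss 2.0
meter.update(4.0, n=3)   # a batch of 3 samples with mean loss 4.0
print(meter.avg)         # 3.5
```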
class inferno.utils.train_utils.Duration(value=None, units=None)
Bases: inferno.utils.train_utils.Frequency
Like Frequency, but measures a duration.
class inferno.utils.train_utils.Frequency(value=None, units=None)
Bases: object
- UNIT_PRIORITY = 'iterations'
- VALID_UNIT_NAME_MAPPING = {'epoch': 'epochs', 'epochs': 'epochs', 'iteration': 'iterations', 'iterations': 'iterations'}
- by_epoch
- by_iteration
- is_consistent
- units
- value
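VALID_UNIT_NAME_MAPPING suggests that singular and plural unit names are normalized to a canonical plural form. The sketch below illustrates that normalization together with by_epoch and by_iteration flags. `FrequencySketch` is hypothetical. Reading by_epoch as "units are epochs" and rejecting unknown units with a ValueError are assumptions, not documented behavior.

```python
VALID_UNIT_NAME_MAPPING = {'epoch': 'epochs', 'epochs': 'epochs',
                           'iteration': 'iterations', 'iterations': 'iterations'}


class FrequencySketch:
    """Hypothetical value/units pair with normalized unit names."""

    def __init__(self, value, units):
        if units not in VALID_UNIT_NAME_MAPPING:
            raise ValueError(f"invalid units {units!r}; expected one of "
                             f"{sorted(VALID_UNIT_NAME_MAPPING)}")
        self.value = value
        # Map 'epoch' -> 'epochs' and 'iteration' -> 'iterations'.
        self.units = VALID_UNIT_NAME_MAPPING[units]

    @property
    def by_epoch(self):
        return self.units == 'epochs'

    @property
    def by_iteration(self):
        return self.units == 'iterations'


every_five_epochs = FrequencySketch(5, 'epoch')
print(every_five_epochs.units)  # epochs
```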
class inferno.utils.train_utils.MovingAverage(momentum=0)
Bases: object
Computes the moving average of a given float.
- relative_change
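A moving average controlled by a momentum parameter is commonly an exponential moving average. The sketch below assumes that reading, with new_avg = momentum * old_avg + (1 - momentum) * value, so the default momentum=0 simply tracks the latest value. It also assumes that relative_change is the difference between the two most recent averages divided by the earlier one. Both are assumptions, and `MovingAverageSketch` is a hypothetical name.

```python
class MovingAverageSketch:
    """Exponential moving average of a float (assumed semantics)."""

    def __init__(self, momentum=0.0):
        self.momentum = momentum
        self.val = None       # current average
        self.previous = None  # average before the latest update

    def update(self, value):
        self.previous = self.val
        if self.val is None:
            self.val = value  # the first value initializes the average
        else:
            self.val = self.momentum * self.val + (1 - self.momentum) * value
        return self.val

    @property
    def relative_change(self):
        # Undefined until two averages exist, or when the earlier one is zero.
        if self.val is None or self.previous is None or self.previous == 0:
            return None
        return (self.val - self.previous) / self.previous


avg = MovingAverageSketch(momentum=0.5)
avg.update(10.0)
print(avg.update(20.0))      # 15.0
print(avg.relative_change)   # 0.5
```

A higher momentum smooths noisy values such as per-iteration losses more strongly, at the cost of reacting more slowly to real changes.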