The PyTorch-FWD package
pytorchfwd.fwd module
Frechet Wavelet Distance computation.
- pytorchfwd.fwd.calculate_path_statistics(path: str, wavelet: str, max_level: int, log_scale: bool, batch_size: int) → Tuple[ndarray, ...]
Compute the mean and sigma for the given path.
- Parameters:
path (str) – Path to an .npz file or an image directory.
wavelet (str) – Choice of wavelet.
max_level (int) – Decomposition level.
log_scale (bool) – Apply log scale.
batch_size (int) – Batch size for packet decomposition.
- Raises:
ValueError – If mu and sigma cannot be calculated.
- Returns:
Tuple containing the mean and sigma for each packet.
- Return type:
Tuple[np.ndarray, …]
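The per-packet statistics can be illustrated with a minimal NumPy sketch. It assumes the packet coefficients have already been computed and arranged as an array of shape (num_images, num_packets, features); the helper name `packet_statistics`, the stacked return layout, and log scaling as log(|c| + eps) are hypothetical choices for illustration, not necessarily what pytorchfwd does internally.

```python
from typing import Tuple

import numpy as np


def packet_statistics(
    coeffs: np.ndarray, log_scale: bool, eps: float = 1e-12
) -> Tuple[np.ndarray, np.ndarray]:
    """Return per-packet means (P, D) and covariances (P, D, D).

    coeffs: hypothetical layout (num_images, num_packets, features_per_packet).
    """
    if log_scale:
        # Assumed log scaling: compress coefficient magnitudes with log(|c| + eps).
        coeffs = np.log(np.abs(coeffs) + eps)
    mus, sigmas = [], []
    for p in range(coeffs.shape[1]):
        packet = coeffs[:, p, :]  # (num_images, features) for one packet
        mus.append(packet.mean(axis=0))
        # rowvar=False: rows are observations (images), columns are features.
        sigmas.append(np.atleast_2d(np.cov(packet, rowvar=False)))
    return np.stack(mus), np.stack(sigmas)
```

With 50 images, 4 packets and 3 features per packet, the sketch returns a mean array of shape (4, 3) and a sigma array of shape (4, 3, 3).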
- pytorchfwd.fwd.compute_fwd(paths: List[str], wavelet: str, max_level: int, log_scale: bool, batch_size: int) → float
Compute Frechet Wavelet Distance.
- Parameters:
paths (List[str]) – List containing the paths to the source and generated images.
wavelet (str) – Choice of wavelet.
max_level (int) – Decomposition level.
log_scale (bool) – Apply log scale.
batch_size (int) – Batch size for packet decomposition.
- Raises:
RuntimeError – If a path does not exist.
- Returns:
Frechet Wavelet Distance.
- Return type:
float
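The distance can be sketched in NumPy under two assumptions: as in the Fréchet Inception Distance, each packet's coefficients are modelled as a Gaussian with a per-packet mean and sigma, and the per-packet Fréchet distances are averaged into a single score. The helper names are hypothetical, and the matrix square root uses an eigendecomposition of symmetric positive semi-definite matrices rather than a general `sqrtm`.

```python
import numpy as np


def _sqrtm_psd(a: np.ndarray) -> np.ndarray:
    """Principal square root of a symmetric PSD matrix via eigh."""
    a = 0.5 * (a + a.T)  # remove tiny asymmetries from rounding
    vals, vecs = np.linalg.eigh(a)
    vals = np.clip(vals, 0.0, None)  # clip small negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(
    mu1: np.ndarray, sigma1: np.ndarray, mu2: np.ndarray, sigma2: np.ndarray
) -> float:
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)) for two Gaussians."""
    s1_half = _sqrtm_psd(sigma1)
    # Tr((S1 S2)^(1/2)) equals Tr((S1^(1/2) S2 S1^(1/2))^(1/2)),
    # and the inner matrix is symmetric PSD, so eigh applies.
    cross = _sqrtm_psd(s1_half @ sigma2 @ s1_half)
    diff = mu1 - mu2
    return float(
        diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(cross)
    )


def mean_packet_distance(mus1, sigmas1, mus2, sigmas2) -> float:
    """Average the per-packet Frechet distances (assumed aggregation)."""
    dists = [
        frechet_distance(m1, s1, m2, s2)
        for m1, s1, m2, s2 in zip(mus1, sigmas1, mus2, sigmas2)
    ]
    return float(np.mean(dists))
```

For two 2-D Gaussians with identity covariances and means 0 and (3, 4), `frechet_distance` gives 25.0 up to floating-point rounding; with equal means and covariances I and 4I it gives 2.0.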
- pytorchfwd.fwd.compute_packet_statistics(dataloader: DataLoader, wavelet: str, max_level: int, log_scale: bool) → Tuple[ndarray, ...]
Compute the wavelet packet transform across batches and return per-packet statistics.
- Parameters:
dataloader (th.utils.data.DataLoader) – Torch dataloader.
wavelet (str) – Choice of wavelet.
max_level (int) – Wavelet decomposition level.
log_scale (bool) – Apply log scale.
- Returns:
Mean and sigma for each packet.
- Return type:
Tuple[np.ndarray, …]
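The batch-wise packet transform can be sketched with an orthonormal 2-D Haar wavelet in plain NumPy. This is an illustrative stand-in, not the library's implementation: it assumes single-channel images of shape (N, H, W) with H and W divisible by 2**max_level, and the helper names `haar_step`, `haar_packets`, and `packet_coefficients` are hypothetical. A full packet decomposition splits every subband again at each level, so level L yields 4**L packets.

```python
from typing import Iterable

import numpy as np


def haar_step(x: np.ndarray) -> np.ndarray:
    """One orthonormal 2-D Haar split of the last two axes into 4 subbands.

    Input (..., H, W) with even H and W; output (..., 4, H/2, W/2).
    """
    a = x[..., 0::2, 0::2]
    b = x[..., 0::2, 1::2]
    c = x[..., 1::2, 0::2]
    d = x[..., 1::2, 1::2]
    approx = (a + b + c + d) / 2.0
    detail_1 = (a - b + c - d) / 2.0
    detail_2 = (a + b - c - d) / 2.0
    detail_3 = (a - b - c + d) / 2.0
    return np.stack([approx, detail_1, detail_2, detail_3], axis=-3)


def haar_packets(images: np.ndarray, max_level: int) -> np.ndarray:
    """Full packet tree: (N, H, W) -> (N, 4**L, H / 2**L, W / 2**L)."""
    packets = images[:, None, :, :]  # start with one "packet": the image itself
    for _ in range(max_level):
        n, p, h, w = packets.shape
        # Split every existing packet, not just the approximation band.
        packets = haar_step(packets).reshape(n, p * 4, h // 2, w // 2)
    return packets


def packet_coefficients(batches: Iterable[np.ndarray], max_level: int) -> np.ndarray:
    """Transform batch by batch, flatten each packet, and concatenate over batches."""
    out = []
    for batch in batches:
        packets = haar_packets(batch, max_level)
        out.append(packets.reshape(packets.shape[0], packets.shape[1], -1))
    return np.concatenate(out, axis=0)  # (num_images, 4**max_level, features)
```

Batches of size `batch_size` could be formed with `[images[i:i + batch_size] for i in range(0, len(images), batch_size)]`, and per-packet means and covariances would then be taken over the first axis of the result. Because the Haar filters are orthonormal, the total energy of the coefficients equals that of the input images.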