torcheval.metrics.FrechetAudioDistance

class torcheval.metrics.FrechetAudioDistance(preproc: Callable[[Tensor], Tensor], model: Module, embedding_dim: int, device: Optional[device] = None)

Computes the Fréchet distance between predicted and target audio waveforms.

Original paper: https://arxiv.org/abs/1812.08466
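The distance referenced above is the Fréchet distance between two Gaussians fitted to embedding sets: ||μ_p − μ_t||² + tr(Σ_p + Σ_t − 2(Σ_p Σ_t)^(1/2)). A minimal NumPy sketch of that formula (not torcheval's internal implementation; the trace of the matrix square root is computed here via eigenvalues of the covariance product, which is valid for PSD covariances):

```python
import numpy as np

def frechet_distance(emb_pred, emb_target):
    """Fréchet distance between Gaussians fitted to two embedding sets.

    emb_pred, emb_target: arrays of shape (n_samples, embedding_dim).
    """
    mu_p, mu_t = emb_pred.mean(axis=0), emb_target.mean(axis=0)
    cov_p = np.cov(emb_pred, rowvar=False)
    cov_t = np.cov(emb_target, rowvar=False)
    diff = mu_p - mu_t
    # tr((cov_p @ cov_t)^(1/2)) equals the sum of square roots of the
    # eigenvalues of the product; these are real and non-negative (up to
    # numerical noise) because both covariances are PSD.
    eigvals = np.linalg.eigvals(cov_p @ cov_t)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(cov_p) + np.trace(cov_t) - 2.0 * tr_sqrt)
```

Two identical embedding sets yield a distance of (numerically) zero, and the distance grows with the gap between the fitted means and covariances.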
Parameters:
- preproc (Callable[[torch.Tensor], torch.Tensor]) – Callable that preprocesses waveforms before they are passed to model.
- model (torch.nn.Module) – Model that generates embeddings from the preprocessed waveforms.
- embedding_dim (int) – Size of the embedding produced by model.
- device (torch.device or None, optional) – Device on which computations are performed. If None, the default device is used. (Default: None)
__init__(preproc: Callable[[Tensor], Tensor], model: Module, embedding_dim: int, device: Optional[device] = None) → None

Initialize a metric object and its internal states.

Use self._add_state() to initialize the state variables of your metric class. Each state variable should be a torch.Tensor, a list of torch.Tensor, or a dictionary with torch.Tensor values.
Methods

- __init__(preproc, model, embedding_dim[, device]) – Initialize a metric object and its internal states.
- compute() – Computes the Fréchet distance on the current set of internal states.
- load_state_dict(state_dict[, strict]) – Loads metric state variables from state_dict.
- merge_state(fads) – Merges the states of other FrechetAudioDistance instances into those of the current instance.
- reset() – Reset the metric state variables to their default value.
- state_dict() – Save metric state variables in state_dict.
- to(device, *args, **kwargs) – Move tensors in metric state variables to device.
- update(preds, targets) – Update states with a batch of predicted and target waveforms.
- with_vggish([device]) – Builds an instance of FrechetAudioDistance with TorchAudio's pretrained VGGish model.

Attributes

- device – The last input device of Metric.to().