from typing import Callable, Dict, List, Optional, Sequence, Union

import torch

from lhotse import CutSet, validate
from lhotse.dataset import PrecomputedFeatures
from lhotse.dataset.collation import collate_audio
from lhotse.dataset.input_strategies import BatchIO
from lhotse.utils import ifnone


class SpeechSynthesisDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech synthesis task.

    Each item in this dataset is a dict of:

    .. code-block::

        {
            'audio': (B x NumSamples) float tensor           # when return_audio=True
            'features': (B x NumFrames x NumFeatures) float tensor
            'audio_lens': (B, ) int tensor                   # when return_audio=True
            'features_lens': (B, ) int tensor
            'text': List[str] of len B                       # when return_text=True
            'tokens': List[List[str]]                        # when return_tokens=True
            'speakers': List[str] of len B                   # when return_spk_ids=True
            'cut': List of Cuts                              # when return_cuts=True
        }
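
    A minimal usage sketch (an illustration, not part of the original file; it
    assumes a Lhotse ``CutSet`` named ``cuts`` with precomputed features and
    uses Lhotse's ``DynamicBucketingSampler``):

    .. code-block::

        from torch.utils.data import DataLoader
        from lhotse.dataset import DynamicBucketingSampler

        dataset = SpeechSynthesisDataset(return_audio=True)
        sampler = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=True)
        # batch_size=None because the sampler yields whole CutSets,
        # which this dataset collates itself in __getitem__.
        dataloader = DataLoader(dataset, sampler=sampler, batch_size=None)
        for batch in dataloader:
            features = batch["features"]  # (B, NumFrames, NumFeatures)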
"""

    def __init__(
        self,
        cut_transforms: Optional[List[Callable[[CutSet], CutSet]]] = None,
        feature_input_strategy: BatchIO = PrecomputedFeatures(),
        feature_transforms: Optional[Union[Sequence[Callable], Callable]] = None,
        return_text: bool = True,
        return_tokens: bool = False,
        return_spk_ids: bool = False,
        return_cuts: bool = False,
        return_audio: bool = False,
    ) -> None:
        super().__init__()
        self.cut_transforms = ifnone(cut_transforms, [])
        self.feature_input_strategy = feature_input_strategy
        self.return_text = return_text
        self.return_tokens = return_tokens
        self.return_spk_ids = return_spk_ids
        self.return_cuts = return_cuts
        self.return_audio = return_audio
        # Normalize feature_transforms to a list of callables: accept None,
        # a single callable, or a sequence of callables.
        if feature_transforms is None:
            feature_transforms = []
        elif not isinstance(feature_transforms, Sequence):
            feature_transforms = [feature_transforms]
        assert all(
            isinstance(transform, Callable) for transform in feature_transforms
        ), "Feature transforms must be Callable"
        self.feature_transforms = feature_transforms

    def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
        validate_for_tts(cuts)

        for transform in self.cut_transforms:
            cuts = transform(cuts)

        # Collate features into (B, NumFrames, NumFeatures) with per-cut lengths.
        features, features_lens = self.feature_input_strategy(cuts)

        for transform in self.feature_transforms:
            features = transform(features)

        batch = {
            "features": features,
            "features_lens": features_lens,
        }

        if self.return_audio:
            audio, audio_lens = collate_audio(cuts)
            batch["audio"] = audio
            batch["audio_lens"] = audio_lens

        if self.return_text:
            # Each cut has exactly one supervision (enforced by validate_for_tts).
            text = [cut.supervisions[0].text for cut in cuts]
            batch["text"] = text

        if self.return_tokens:
            tokens = [cut.supervisions[0].tokens for cut in cuts]
            batch["tokens"] = tokens

        if self.return_spk_ids:
            batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]

        if self.return_cuts:
            batch["cut"] = [cut for cut in cuts]

        return batch


def validate_for_tts(cuts: CutSet) -> None:
    validate(cuts)
    for cut in cuts:
        assert (
            len(cut.supervisions) == 1
        ), "Only cuts with a single supervision are supported."