Files
koko210Serve 8ca716029e add: absorb soprano_to_rvc as regular subdirectory
Voice conversion pipeline (Soprano TTS → RVC) with Docker support.
Previously tracked as bare gitlink; removed .git/ directories and
absorbed into main repo for unified tracking.

Includes: Soprano TTS, RVC WebUI integration, Docker configs,
WebSocket API, and benchmark scripts.
Updated .gitignore to exclude large model weights (*.pth, *.pt, *.onnx, *.index).
287 files added; 3.1 GB of ML model weights excluded via .gitignore.
2026-03-04 00:24:53 +02:00

238 lines
10 KiB
Python

from .vocos.decoder import SopranoDecoder
from .utils.text_normalizer import clean_text
from .utils.text_splitter import split_and_recombine_text
from .utils.auto_select import select_device, select_backend
import torch
import re
from unidecode import unidecode
from scipy.io import wavfile
from huggingface_hub import hf_hub_download
import os
import time
class SopranoTTS:
    """
    Soprano Text-to-Speech model.

    Args:
        backend: Backend to use for inference. Options:
            - 'auto' (default): Automatically select best backend. Tries lmdeploy first (fastest),
              falls back to transformers. CPU always uses transformers.
            - 'lmdeploy': Force use of LMDeploy (fastest, CUDA only)
            - 'transformers': Force use of HuggingFace Transformers (slower, all devices)
        device: Device to run inference on ('auto', 'cuda', 'cpu', 'mps')
        cache_size_mb: Cache size in MB for lmdeploy backend
        decoder_batch_size: Batch size for decoder
        model_path: Optional local directory holding model files (including
            'decoder.pth'); when None, decoder weights are fetched from the
            HuggingFace Hub.
    """
    def __init__(self,
                 backend='auto',
                 device='auto',
                 cache_size_mb=100,
                 decoder_batch_size=1,
                 model_path=None):
        # Resolve 'auto' choices first; the backend decision depends on the
        # final device (CPU always falls back to transformers).
        device = select_device(device=device)
        backend = select_backend(backend=backend, device=device)
        # Backends are imported lazily so the unused one (and its heavy
        # dependencies) is never loaded.
        if backend == 'lmdeploy':
            from .backends.lmdeploy import LMDeployModel
            self.pipeline = LMDeployModel(device=device, cache_size_mb=cache_size_mb, model_path=model_path)
        elif backend == 'transformers':
            from .backends.transformers import TransformersModel
            self.pipeline = TransformersModel(device=device, model_path=model_path)
        self.device = device
        self.backend = backend
        # Vocoder that converts LM hidden states into waveform samples.
        self.decoder = SopranoDecoder().to(device)
        if model_path:
            decoder_path = os.path.join(model_path, 'decoder.pth')
        else:
            # No local model dir given: download (or reuse cached) weights.
            decoder_path = hf_hub_download(repo_id='ekwek/Soprano-1.1-80M', filename='decoder.pth')
        self.decoder.load_state_dict(torch.load(decoder_path, map_location=device))
        self.decoder_batch_size = decoder_batch_size
        self.RECEPTIVE_FIELD = 4  # Decoder receptive field
        self.TOKEN_SIZE = 2048  # Number of samples per audio token
        self.infer("Hello world!")  # warmup
    def _preprocess_text(self, texts, min_length=30):
        '''
        adds prompt format and sentence/part index
        Enforces a minimum sentence length by merging short sentences.

        Returns a list of (prompt, text_idx, sentence_idx) tuples where
        text_idx is the position of the source text in `texts` and
        sentence_idx numbers that text's sentences in order.
        '''
        res = []
        for text_idx, text in enumerate(texts):
            text = text.strip()
            cleaned_text = clean_text(text)
            sentences = split_and_recombine_text(cleaned_text)
            processed = []
            for sentence in sentences:
                processed.append({
                    "text": sentence,
                    "text_idx": text_idx,
                })
            # Merge sentences shorter than min_length into a neighbour so the
            # model never receives a degenerately short prompt.
            if min_length > 0 and len(processed) > 1:
                merged = []
                i = 0
                while i < len(processed):
                    cur = processed[i]
                    if len(cur["text"]) < min_length:
                        if merged:
                            # Append the short sentence onto the previous one.
                            merged[-1]["text"] = (merged[-1]["text"] + " " + cur["text"]).strip()
                        else:
                            if i + 1 < len(processed):
                                # No previous sentence yet: prepend to the next one.
                                processed[i + 1]["text"] = (cur["text"] + " " + processed[i + 1]["text"]).strip()
                            else:
                                # Nothing to merge with; keep it as-is.
                                merged.append(cur)
                    else:
                        merged.append(cur)
                    i += 1
                processed = merged
            # Number the sentences per source text and wrap each in the
            # generation prompt format.
            sentence_idxes = {}
            for item in processed:
                if item['text_idx'] not in sentence_idxes:
                    sentence_idxes[item['text_idx']] = 0
                res.append((f'[STOP][TEXT]{item["text"]}[START]', item["text_idx"], sentence_idxes[item['text_idx']]))
                sentence_idxes[item['text_idx']] += 1
        return res
def hallucination_detector(self, hidden_state):
'''
Analyzes hidden states to find long runs of similar sequences.
'''
DIFF_THRESHOLD = 300 # minimal difference between sequences
MAX_RUNLENGTH = 16 # maximum number of recent similar sequences
if len(hidden_state) <= MAX_RUNLENGTH: # hidden state not long enough
return False
aah_runlength = 0
for i in range(len(hidden_state) - 1):
current_sequences = hidden_state[i]
next_sequences = hidden_state[i + 1]
diffs = torch.abs(current_sequences - next_sequences)
total_diff = diffs.sum(dim=0)
if total_diff < DIFF_THRESHOLD:
aah_runlength += 1
elif aah_runlength > 0:
aah_runlength -= 1
if aah_runlength > MAX_RUNLENGTH:
return True
return False
def infer(self,
text,
out_path=None,
top_p=0.95,
temperature=0.0,
repetition_penalty=1.2,
retries=0):
results = self.infer_batch([text],
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
out_dir=None,
retries=retries)[0]
if out_path:
wavfile.write(out_path, 32000, results.cpu().numpy())
return results
    def infer_batch(self,
                    texts,
                    out_dir=None,
                    top_p=0.95,
                    temperature=0.0,
                    repetition_penalty=1.2,
                    retries=0):
        '''
        Synthesize speech for a list of texts.

        Each text is split into prompt-formatted sentences, generated by the
        LM backend (with optional hallucination-retry passes), decoded to
        audio in length-sorted batches, and re-assembled per input text.

        Args:
            texts: List of input strings.
            out_dir: If given, also writes one 32 kHz WAV per text
                (named "<index>.wav") into this directory.
            top_p, temperature, repetition_penalty: Sampling parameters
                forwarded to the LM backend.
            retries: Extra generation passes for sentences flagged by the
                hallucination detector (0 disables detection).

        Returns:
            List of 1-D CPU audio tensors, one per input text.
        '''
        sentence_data = self._preprocess_text(texts)
        prompts = list(map(lambda x: x[0], sentence_data))
        # hidden_states[i] holds the LM hidden states for prompts[i].
        hidden_states = [None] * len(prompts)
        pending_indices = list(range(0, len(prompts)))
        tries_left = 1 + max(0, retries)
        # Regenerate only the sentences flagged as hallucinated, up to
        # `retries` additional passes.
        while tries_left > 0 and pending_indices:
            current_prompts = [prompts[i] for i in pending_indices]
            responses = self.pipeline.infer(current_prompts,
                                            top_p=top_p,
                                            temperature=temperature,
                                            repetition_penalty=repetition_penalty)
            bad_indices = []
            for idx, response in enumerate(responses):
                hidden_state = response['hidden_state']
                hidden_states[pending_indices[idx]] = hidden_state
                if response['finish_reason'] != 'stop':
                    print(f"Warning: A sentence did not complete generation, likely due to hallucination.")
                if retries > 0 and self.hallucination_detector(hidden_state):
                    print(f"Warning: A sentence contained a hallucination.")
                    bad_indices.append(pending_indices[idx])
            if not bad_indices:
                break
            else:
                pending_indices = bad_indices
                tries_left -= 1
                if tries_left > 0:
                    print(f"Warning: {len(pending_indices)} sentence(s) will be regenerated.")
        # Sort sentences by hidden-state length (descending) so each decoder
        # batch needs minimal zero-padding.
        combined = list(zip(hidden_states, sentence_data))
        combined.sort(key=lambda x: -x[0].size(0))
        hidden_states, sentence_data = zip(*combined)
        num_texts = len(texts)
        # Pre-allocate one slot per sentence so decoded chunks can be placed
        # back in original (text, sentence) order despite the sort above.
        audio_concat = [[] for _ in range(num_texts)]
        for sentence in sentence_data:
            audio_concat[sentence[1]].append(None)
        for idx in range(0, len(hidden_states), self.decoder_batch_size):
            batch_hidden_states = []
            lengths = list(map(lambda x: x.size(0), hidden_states[idx:idx+self.decoder_batch_size]))
            N = len(lengths)
            for i in range(N):
                # Left-pad each item to the batch maximum (lengths[0], since
                # items are sorted by descending length); 512 appears to be
                # the hidden-state feature size — confirm against the decoder.
                batch_hidden_states.append(torch.cat([
                    torch.zeros((1, 512, lengths[0]-lengths[i]), device=self.device),
                    hidden_states[idx+i].unsqueeze(0).transpose(1,2).to(self.device).to(torch.float32),
                ], dim=2))
            batch_hidden_states = torch.cat(batch_hidden_states)
            with torch.no_grad():
                audio = self.decoder(batch_hidden_states)
            for i in range(N):
                text_id = sentence_data[idx+i][1]
                sentence_id = sentence_data[idx+i][2]
                # Keep only the tail belonging to this item's unpadded tokens
                # (lengths[i] tokens minus one token's worth of samples).
                audio_concat[text_id][sentence_id] = audio[i].squeeze()[-(lengths[i]*self.TOKEN_SIZE-self.TOKEN_SIZE):]
        # Concatenate each text's sentences in order and move to CPU.
        audio_concat = [torch.cat(x).cpu() for x in audio_concat]
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            for i in range(len(audio_concat)):
                wavfile.write(f"{out_dir}/{i}.wav", 32000, audio_concat[i].cpu().numpy())
        return audio_concat
    def infer_stream(self,
                     text,
                     chunk_size=1,
                     top_p=0.95,
                     temperature=0.0,
                     repetition_penalty=1.2):
        '''
        Stream synthesized audio chunk by chunk.

        Generates each sentence token-by-token and yields decoded audio once
        `chunk_size` new tokens are available, keeping a rolling buffer of
        recent hidden states so the decoder has enough left context for its
        receptive field. Prints the time-to-first-chunk latency.

        Args:
            text: Input string (may contain multiple sentences).
            chunk_size: Number of LM tokens decoded per yielded audio chunk.
            top_p, temperature, repetition_penalty: Sampling parameters
                forwarded to the LM backend.

        Yields:
            1-D CPU audio tensors; middle chunks are chunk_size * TOKEN_SIZE
            samples, the final chunk of a sentence may differ in length.
        '''
        start_time = time.time()
        sentence_data = self._preprocess_text([text])
        first_chunk = True
        for sentence, _, _ in sentence_data:
            responses = self.pipeline.stream_infer(sentence,
                                                   top_p=top_p,
                                                   temperature=temperature,
                                                   repetition_penalty=repetition_penalty)
            hidden_states_buffer = []
            # Start at chunk_size so the first full buffer is decoded
            # immediately instead of waiting another chunk_size tokens.
            chunk_counter = chunk_size
            for token in responses:
                finished = token['finish_reason'] is not None
                if not finished:
                    hidden_states_buffer.append(token['hidden_state'][-1])
                # Rolling window: enough left context for the receptive field
                # plus one chunk of fresh tokens.
                hidden_states_buffer = hidden_states_buffer[-(2*self.RECEPTIVE_FIELD+chunk_size):]
                if finished or len(hidden_states_buffer) >= self.RECEPTIVE_FIELD + chunk_size:
                    if finished or chunk_counter == chunk_size:
                        batch_hidden_states = torch.stack(hidden_states_buffer)
                        inp = batch_hidden_states.unsqueeze(0).transpose(1, 2).to(self.device).to(torch.float32)
                        with torch.no_grad():
                            audio = self.decoder(inp)[0]
                        if finished:
                            # Final chunk: emit everything generated since the
                            # last emitted chunk (chunk_counter tokens plus
                            # receptive-field context, minus one warmup token).
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD+chunk_counter-1)*self.TOKEN_SIZE-self.TOKEN_SIZE):]
                        else:
                            # Middle chunk: exactly chunk_size tokens' worth of
                            # samples, trimming receptive-field context at the
                            # right edge (still warming up) from the output.
                            audio_chunk = audio[-((self.RECEPTIVE_FIELD+chunk_size)*self.TOKEN_SIZE-self.TOKEN_SIZE):-(self.RECEPTIVE_FIELD*self.TOKEN_SIZE-self.TOKEN_SIZE)]
                        chunk_counter = 0
                        if first_chunk:
                            print(f"Streaming latency: {1000*(time.time()-start_time):.2f} ms")
                            first_chunk = False
                        yield audio_chunk.cpu()
                    # Count tokens seen since the last emitted chunk; only
                    # counted once the buffer is full, otherwise the
                    # `== chunk_size` trigger could be skipped over.
                    chunk_counter += 1