Fix the bug in non-streaming long-text inference
v3ucn committed Oct 3, 2024 · 1 parent ccbba52 · commit 3f139d7
Showing 1 changed file (webui.py) with 52 additions and 8 deletions.
@@ -167,25 +167,69 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group == '预训练音色':
         logging.info('get sft inference request')
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed,new_dropdown=new_dropdown):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+        if stream:
+            for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed,new_dropdown=new_dropdown):
+                yield (target_sr, i['tts_speech'].numpy().flatten())
+        else:
+            tts_speeches = []
+            for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed,new_dropdown=new_dropdown):
+                tts_speeches.append(i['tts_speech'])
+            audio_data = torch.concat(tts_speeches, dim=1)
+            yield (target_sr, audio_data.numpy().flatten())
+
     elif mode_checkbox_group == '3s极速复刻':
         logging.info('get zero_shot inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+        if stream:
+            for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
+                yield (target_sr, i['tts_speech'].numpy().flatten())
+        else:
+            tts_speeches = []
+            for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
+                tts_speeches.append(i['tts_speech'])
+            audio_data = torch.concat(tts_speeches, dim=1)
+            yield (target_sr, audio_data.numpy().flatten())
+
     elif mode_checkbox_group == '跨语种复刻':
         logging.info('get cross_lingual inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+        if stream:
+            for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
+                yield (target_sr, i['tts_speech'].numpy().flatten())
+        else:
+            tts_speeches = []
+            for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
+                tts_speeches.append(i['tts_speech'])
+            audio_data = torch.concat(tts_speeches, dim=1)
+            yield (target_sr, audio_data.numpy().flatten())
+
     else:
         logging.info('get instruct inference request')
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed,new_dropdown=new_dropdown):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+        if stream:
+            for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed,new_dropdown=new_dropdown):
+                yield (target_sr, i['tts_speech'].numpy().flatten())
+        else:
+            tts_speeches = []
+            for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed,new_dropdown=new_dropdown):
+                tts_speeches.append(i['tts_speech'])
+            audio_data = torch.concat(tts_speeches, dim=1)
+            yield (target_sr, audio_data.numpy().flatten())


 def main():
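The fix applies the same pattern to all four modes: in streaming mode, pass each chunk to the caller as it arrives; in non-streaming mode, buffer every chunk the generator yields and concatenate them along the time axis (dim=1) before yielding a single waveform, so long text is returned whole. A minimal, self-contained sketch of that pattern, with a hypothetical fake_inference generator standing in for the cosyvoice.inference_* calls and arbitrary chunk sizes:

    import torch

    def fake_inference(n_chunks=3, chunk_len=24000):
        # Hypothetical stand-in for cosyvoice.inference_*: yields one dict per
        # chunk with a [1, T] 'tts_speech' tensor, mirroring the real generators.
        for _ in range(n_chunks):
            yield {'tts_speech': torch.zeros(1, chunk_len)}

    def synthesize(stream):
        if stream:
            # Streaming: hand each chunk over as soon as it is produced.
            for chunk in fake_inference():
                yield chunk['tts_speech'].numpy().flatten()
        else:
            # Non-streaming: buffer all chunks, then concatenate along the time
            # axis (dim=1) so long text comes back as one continuous waveform.
            tts_speeches = [chunk['tts_speech'] for chunk in fake_inference()]
            audio_data = torch.concat(tts_speeches, dim=1)
            yield audio_data.numpy().flatten()

    for audio in synthesize(stream=False):
        print(audio.shape)  # one (72000,) array rather than three partial ones

The commit repeats this buffering-then-concatenation logic verbatim in the sft, zero_shot, cross_lingual, and instruct branches of generate_audio.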
