diff --git a/webui.py b/webui.py
index 969111b..595f579 100644
--- a/webui.py
+++ b/webui.py
@@ -167,25 +167,69 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group == '预训练音色':
         logging.info('get sft inference request')
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed,new_dropdown=new_dropdown):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+
+        if stream:
+            # streaming: yield each audio chunk as soon as it is generated
+            for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed,new_dropdown=new_dropdown):
+                yield (target_sr, i['tts_speech'].numpy().flatten())
+        else:
+            # non-streaming: collect every chunk, then emit one full waveform
+            tts_speeches = []
+            for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed,new_dropdown=new_dropdown):
+                tts_speeches.append(i['tts_speech'])
+            audio_data = torch.concat(tts_speeches, dim=1)
+            yield (target_sr, audio_data.numpy().flatten())
+
+
     elif mode_checkbox_group == '3s极速复刻':
         logging.info('get zero_shot inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+
+        if stream:
+            # streaming: yield each audio chunk as soon as it is generated
+            for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
+                yield (target_sr, i['tts_speech'].numpy().flatten())
+        else:
+            # non-streaming: collect every chunk, then emit one full waveform
+            tts_speeches = []
+            for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
+                tts_speeches.append(i['tts_speech'])
+            audio_data = torch.concat(tts_speeches, dim=1)
+            yield (target_sr, audio_data.numpy().flatten())
+
     elif mode_checkbox_group == '跨语种复刻':
         logging.info('get cross_lingual inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+
+        if stream:
+            # streaming: yield each audio chunk as soon as it is generated
+            for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
+                yield (target_sr, i['tts_speech'].numpy().flatten())
+        else:
+            # non-streaming: collect every chunk, then emit one full waveform
+            tts_speeches = []
+            for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
+                tts_speeches.append(i['tts_speech'])
+            audio_data = torch.concat(tts_speeches, dim=1)
+            yield (target_sr, audio_data.numpy().flatten())
+
     else:
         logging.info('get instruct inference request')
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed,new_dropdown=new_dropdown):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+        if stream:
+            # streaming: yield each audio chunk as soon as it is generated
+            for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed,new_dropdown=new_dropdown):
+                yield (target_sr, i['tts_speech'].numpy().flatten())
+        else:
+            # non-streaming: collect every chunk, then emit one full waveform
+            tts_speeches = []
+            for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed,new_dropdown=new_dropdown):
+                tts_speeches.append(i['tts_speech'])
+            audio_data = torch.concat(tts_speeches, dim=1)
+            yield (target_sr, audio_data.numpy().flatten())
+


 def main():
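
Note on the change above: all four branches now duplicate the same stream/non-stream handling. A possible follow-up refactor is sketched below; it is not part of this diff, and the names `yield_audio` and `inference_fn` are hypothetical. The sketch only assumes what the diff already shows: each `cosyvoice.inference_*` method is a generator yielding dicts whose 'tts_speech' entry is a (1, T) tensor.

import torch

def yield_audio(inference_fn, target_sr, stream, *args, **kwargs):
    # Hypothetical helper (not in the diff): the shared stream/non-stream
    # logic repeated in every branch of generate_audio.
    if stream:
        # streaming: forward each chunk as soon as it is produced
        for i in inference_fn(*args, stream=stream, **kwargs):
            yield (target_sr, i['tts_speech'].numpy().flatten())
    else:
        # non-streaming: gather all chunks, then emit one concatenated waveform
        tts_speeches = [i['tts_speech'] for i in inference_fn(*args, stream=stream, **kwargs)]
        audio_data = torch.concat(tts_speeches, dim=1)
        yield (target_sr, audio_data.numpy().flatten())

Each branch of generate_audio would then reduce to a single delegation, for example:

    yield from yield_audio(cosyvoice.inference_sft, target_sr, stream,
                           tts_text, sft_dropdown, speed=speed, new_dropdown=new_dropdown)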