Skip to content

Commit

Permalink
Fix some badcase for sentence determine.
Browse files Browse the repository at this point in the history
  • Loading branch information
winlinvip committed Jan 27, 2024
1 parent 6ae4512 commit 3a540cc
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 64 deletions.
23 changes: 17 additions & 6 deletions backend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type ASRResult struct {
}

type ASRService interface {
RequestASR(ctx context.Context, filepath, language, prompt string) (*ASRResult, error)
RequestASR(ctx context.Context, filepath, language, prompt string, onBeforeRequest func()) (*ASRResult, error)
}

type TTSService interface {
Expand Down Expand Up @@ -117,6 +117,8 @@ type Stage struct {
lastSentence time.Time
// The time for last upload audio.
lastUploadAudio time.Time
// The time for last extract audio for ASR.
lastExtractAudio time.Time
// The time for last request ASR result.
lastRequestASR time.Time
// The last request ASR text.
Expand Down Expand Up @@ -183,9 +185,16 @@ func (v *Stage) upload() float64 {
return 0
}

// exta returns the elapsed time, in seconds, between the last audio upload
// (lastUploadAudio) and the last audio extraction for ASR (lastExtractAudio).
// It returns 0 when lastExtractAudio is not after lastUploadAudio.
func (v *Stage) exta() float64 {
if v.lastExtractAudio.After(v.lastUploadAudio) {
// Convert the time.Duration to fractional seconds.
return float64(v.lastExtractAudio.Sub(v.lastUploadAudio)) / float64(time.Second)
}
return 0
}

func (v *Stage) asr() float64 {
if v.lastRequestASR.After(v.lastUploadAudio) {
return float64(v.lastRequestASR.Sub(v.lastUploadAudio)) / float64(time.Second)
if v.lastRequestASR.After(v.lastExtractAudio) {
return float64(v.lastRequestASR.Sub(v.lastExtractAudio)) / float64(time.Second)
}
return 0
}
Expand Down Expand Up @@ -591,7 +600,9 @@ func handleUploadQuestionAudio(ctx context.Context, w http.ResponseWriter, r *ht

// Do ASR, convert to text.
var asrText string
if resp, err := asrService.RequestASR(ctx, inputFile, robot.asrLanguage, stage.previousAsrText); err != nil {
if resp, err := asrService.RequestASR(ctx, inputFile, robot.asrLanguage, stage.previousAsrText, func() {
stage.lastExtractAudio = time.Now()
}); err != nil {
return errors.Wrapf(err, "transcription")
} else {
asrText = strings.TrimSpace(resp.Text)
Expand Down Expand Up @@ -768,8 +779,8 @@ func handleDownloadAnswerTTS(ctx context.Context, w http.ResponseWriter, r *http
if !segment.logged && segment.first {
stage.lastDownloadAudio = time.Now()
speech := float64(stage.lastAsrDuration) / float64(time.Second)
logger.Tf(ctx, "Report cost total=%.1fs, steps=[upload=%.1fs,asr=%.1fs,chat=%.1fs,tts=%.1fs,download=%.1fs], ask=%v, speech=%.1fs, answer=%v",
stage.total(), stage.upload(), stage.asr(), stage.chat(), stage.tts(), stage.download(),
logger.Tf(ctx, "Elapsed cost total=%.1fs, steps=[upload=%.1fs,exta=%.1fs,asr=%.1fs,chat=%.1fs,tts=%.1fs,download=%.1fs], ask=%v, speech=%.1fs, answer=%v",
stage.total(), stage.upload(), stage.exta(), stage.asr(), stage.chat(), stage.tts(), stage.download(),
stage.lastRequestAsrText, speech, stage.lastRobotFirstText)
}

Expand Down
160 changes: 103 additions & 57 deletions backend/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -70,11 +71,15 @@ func openaiInit(ctx context.Context) {
type openaiASRService struct {
}

func NewOpenAIASRService() ASRService {
return &openaiASRService{}
func NewOpenAIASRService(opts ...func(service *openaiASRService)) ASRService {
v := &openaiASRService{}
for _, opt := range opts {
opt(v)
}
return v
}

func (v *openaiASRService) RequestASR(ctx context.Context, inputFile, language, prompt string) (*ASRResult, error) {
func (v *openaiASRService) RequestASR(ctx context.Context, inputFile, language, prompt string, onBeforeRequest func()) (*ASRResult, error) {
outputFile := fmt.Sprintf("%v.m4a", inputFile)

// Transcode input audio in opus or aac, to aac in m4a format.
Expand All @@ -99,6 +104,10 @@ func (v *openaiASRService) RequestASR(ctx context.Context, inputFile, language,
return nil, errors.Wrapf(err, "ffprobe")
}

if onBeforeRequest != nil {
onBeforeRequest()
}

// Request ASR.
client := openai.NewClientWithConfig(asrAIConfig)
resp, err := client.CreateTranscription(
Expand Down Expand Up @@ -252,44 +261,31 @@ func (v *openaiChatService) handle(ctx context.Context, stage *Stage, robot *Rob
stage.generating = false
}()

var sentence string
var finished bool
firstSentense := true
for !finished && ctx.Err() == nil {
response, err := gptChatStream.Recv()
finished = errors_std.Is(err, io.EOF)
filterAIResponse := func(response *openai.ChatCompletionStreamResponse, err error) (bool, string, error) {
finished := errors_std.Is(err, io.EOF)
if err != nil && !finished {
return errors.Wrapf(err, "recv chat")
return finished, "", errors.Wrapf(err, "recv chat")
}

newSentence := false
if len(response.Choices) > 0 {
choice := response.Choices[0]
if dc := choice.Delta.Content; dc != "" {
filteredStencese := strings.ReplaceAll(dc, "\n\n", "\n")
filteredStencese = strings.ReplaceAll(filteredStencese, "\n", " ")
sentence += filteredStencese

// Any ASCII character to split sentence.
if strings.ContainsAny(dc, ",.?!\n") {
newSentence = true
}

// Any Chinese character to split sentence.
if strings.ContainsRune(dc, '。') ||
strings.ContainsRune(dc, '?') ||
strings.ContainsRune(dc, '!') ||
strings.ContainsRune(dc, ',') {
newSentence = true
}
//logger.Tf(ctx, "AI response: text=%v, new=%v", dc, newSentence)
}
if len(response.Choices) == 0 {
return finished, "", nil
}

if sentence == "" {
continue
choice := response.Choices[0]
dc := choice.Delta.Content
if dc == "" {
return finished, "", nil
}

filteredStencese := strings.ReplaceAll(dc, "\n\n", "\n")
filteredStencese = strings.ReplaceAll(filteredStencese, "\n", " ")

return finished, filteredStencese, nil
}

gotNewSentence := func(sentence, lastWords string, firstSentense bool) bool {
newSentence := false

isEnglish := func(s string) bool {
for _, r := range s {
if r > unicode.MaxASCII {
Expand All @@ -299,10 +295,34 @@ func (v *openaiChatService) handle(ctx context.Context, stage *Stage, robot *Rob
return true
}

// Ignore empty.
if sentence == "" {
return newSentence
}

// Any ASCII character to split sentence.
if strings.ContainsAny(lastWords, ",.?!\n") {
newSentence = true
}

// Any Chinese character to split sentence.
if strings.ContainsRune(lastWords, '。') ||
strings.ContainsRune(lastWords, '?') ||
strings.ContainsRune(lastWords, '!') ||
strings.ContainsRune(lastWords, ',') {
newSentence = true
}

// Bad case: a number such as 1.3 or 1,300,000, where the '.' or ',' is part of the number and not a sentence boundary.
var badcase bool
if match, _ := regexp.MatchString(`\d+(\.|,)\d*$`, sentence); match {
badcase, newSentence = true, false
}

// Determine whether new sentence by length.
if isEnglish(sentence) {
maxWords, minWords := 30, 3
if !firstSentense {
if !firstSentense || badcase {
maxWords, minWords = 50, 5
}

Expand All @@ -313,7 +333,7 @@ func (v *openaiChatService) handle(ctx context.Context, stage *Stage, robot *Rob
}
} else {
maxWords, minWords := 50, 3
if !firstSentense {
if !firstSentense || badcase {
maxWords, minWords = 100, 5
}

Expand All @@ -324,30 +344,56 @@ func (v *openaiChatService) handle(ctx context.Context, stage *Stage, robot *Rob
}
}

if finished || newSentence {
stage.previousAssitant += sentence + " "
// We utilize user ASR and AI responses as prompts for the subsequent ASR, given that this is
// a chat-based scenario where the user converses with the AI, and the following audio should pertain to both user and AI text.
stage.previousAsrText += " " + sentence

isFirstSentence := firstSentense
if firstSentense {
firstSentense = false
if robot.prefix != "" {
sentence = fmt.Sprintf("%v %v", robot.prefix, sentence)
}
if v.onFirstResponse != nil {
v.onFirstResponse(ctx, sentence)
}
return newSentence
}

commitAISentence := func(sentence string, firstSentense bool) {
if sentence == "" {
return
}

if firstSentense {
if robot.prefix != "" {
sentence = fmt.Sprintf("%v %v", robot.prefix, sentence)
}
if v.onFirstResponse != nil {
v.onFirstResponse(ctx, sentence)
}
}

stage.ttsWorker.SubmitSegment(ctx, stage, NewAnswerSegment(func(segment *AnswerSegment) {
segment.rid = rid
segment.text = sentence
segment.first = firstSentense
}))
return
}

var sentence, lastWords string
isFinished, firstSentense := false, true
for !isFinished && ctx.Err() == nil {
response, err := gptChatStream.Recv()
if finished, words, err := filterAIResponse(&response, err); err != nil {
return errors.Wrapf(err, "filter")
} else {
isFinished, sentence, lastWords = finished, sentence+words, words
}
logger.Tf(ctx, "AI response: text=%v plus %v", lastWords, sentence)

stage.ttsWorker.SubmitSegment(ctx, stage, NewAnswerSegment(func(segment *AnswerSegment) {
segment.rid = rid
segment.text = sentence
segment.first = isFirstSentence
}))
sentence = ""
newSentence := gotNewSentence(sentence, lastWords, firstSentense)
if !isFinished && !newSentence {
continue
}

// Use the sentence for prompt and logging.
stage.previousAssitant += sentence + " "
// We utilize user ASR and AI responses as prompts for the subsequent ASR, given that this is
// a chat-based scenario where the user converses with the AI, and the following audio should pertain to both user and AI text.
stage.previousAsrText += " " + sentence
// Commit the sentence to TTS worker and callbacks.
commitAISentence(sentence, firstSentense)
// Reset the sentence, because we have committed it.
sentence, firstSentense = "", false
}

return nil
Expand Down
6 changes: 5 additions & 1 deletion backend/tencent.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func NewTencentASRService() ASRService {
return &tencentASRService{}
}

func (v *tencentASRService) RequestASR(ctx context.Context, inputFile, language, prompt string) (*ASRResult, error) {
func (v *tencentASRService) RequestASR(ctx context.Context, inputFile, language, prompt string, onBeforeRequest func()) (*ASRResult, error) {
outputFile := fmt.Sprintf("%v.wav", inputFile)

// Transcode input audio in opus or aac, to wav format.
Expand All @@ -70,6 +70,10 @@ func (v *tencentASRService) RequestASR(ctx context.Context, inputFile, language,
return nil, errors.Wrapf(err, "ffprobe")
}

if onBeforeRequest != nil {
onBeforeRequest()
}

// Request ASR.
EngineModelType := "16k_zh"
if language == "en" {
Expand Down

0 comments on commit 3a540cc

Please sign in to comment.