//go:build xlm package punctuation import ( "strings" "go-whisper-api/punctuation/internal/spwrap" ) func decodeXLMSegment( sp *spwrap.Processor, cfg xlmModelConfig, inputIDs []int, prePred, postPred []int64, capPred [][]bool, sbdPred []bool, joinSentences bool, ) (string, error) { var outputTexts []string current := make([]string, 0, len(inputIDs)*4) for tokenIdx := 1; tokenIdx < len(inputIDs)-1; tokenIdx++ { piece, err := sp.IDToPiece(inputIDs[tokenIdx]) if err != nil { return "", err } if strings.HasPrefix(piece, "▁") && len(current) > 0 { current = append(current, " ") } preLabel := labelAt(cfg.PreLabels, prePred, tokenIdx) postLabel := labelAt(cfg.PostLabels, postPred, tokenIdx) if preLabel != cfg.NullToken { current = append(current, preLabel) } charStart := 0 if strings.HasPrefix(piece, "▁") { charStart = 1 } runes := []rune(piece) for tokenCharIdx := charStart; tokenCharIdx < len(runes); tokenCharIdx++ { ch := string(runes[tokenCharIdx]) if capAt(capPred, tokenIdx, tokenCharIdx) { ch = strings.ToUpper(ch) } current = append(current, ch) if postLabel == cfg.Acronym { current = append(current, ".") } } if postLabel != cfg.NullToken && postLabel != cfg.Acronym { current = append(current, postLabel) } if sbdAt(sbdPred, tokenIdx) { outputTexts = append(outputTexts, strings.Join(current, "")) current = current[:0] } } if len(current) > 0 { outputTexts = append(outputTexts, strings.Join(current, "")) } if len(outputTexts) == 0 { return "", nil } if joinSentences { return strings.Join(outputTexts, " "), nil } return outputTexts[0], nil } func labelAt(labels []string, preds []int64, idx int) string { if idx < 0 || idx >= len(preds) { return labels[0] } pi := int(preds[idx]) if pi < 0 || pi >= len(labels) { return labels[0] } return labels[pi] } func capAt(capPred [][]bool, tokenIdx, charIdx int) bool { if tokenIdx < 0 || tokenIdx >= len(capPred) { return false } row := capPred[tokenIdx] if charIdx < 0 || charIdx >= len(row) { return false } return row[charIdx] } 
// sbdAt reports whether the sentence-boundary prediction for token tokenIdx
// is set. Any index outside sbd is treated as "no boundary".
func sbdAt(sbd []bool, tokenIdx int) bool {
	inRange := tokenIdx >= 0 && tokenIdx < len(sbd)
	return inRange && sbd[tokenIdx]
}