go-whisper-api/punctuation/xlm_decode.go
admin b5c083e06f
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
first commit
2026-06-04 18:10:52 +07:00

98 lines
2.6 KiB
Go

//go:build xlm
package punctuation
import (
"strings"
"go-whisper-api/punctuation/internal/spwrap"
)
// decodeXLMSegment rebuilds punctuated, true-cased text for one encoded
// segment from the XLM model's per-token predictions.
//
// The first and last entries of inputIDs are skipped (presumably the
// special BOS/EOS ids added by the encoder — TODO confirm). For every other
// token position tokenIdx:
//   - prePred[tokenIdx] and postPred[tokenIdx] select labels from
//     cfg.PreLabels and cfg.PostLabels via labelAt. A pre label other than
//     cfg.NullToken is written before the token. A post label other than
//     cfg.NullToken and cfg.Acronym is written after it. A post label equal to
//     cfg.Acronym instead appends "." after every character of the token.
//   - capPred[tokenIdx][charIdx] upper-cases the rune at index charIdx of the
//     SentencePiece piece. Index 0 is the "▁" word-start marker when present,
//     so visible characters then start at index 1.
//   - sbdPred[tokenIdx] closes the current sentence after the token.
//
// Pieces that start with "▁" are separated from preceding text in the same
// sentence by a single space. Empty sentences are discarded.
//
// When joinSentences is true, all sentences are joined with one space.
// Otherwise only the first sentence is returned. The result is "" when
// nothing was decoded. The only error source is sp.IDToPiece, whose error is
// returned unchanged.
func decodeXLMSegment(
	sp *spwrap.Processor,
	cfg xlmModelConfig,
	inputIDs []int,
	prePred, postPred []int64,
	capPred [][]bool,
	sbdPred []bool,
	joinSentences bool,
) (string, error) {
	// SentencePiece marks the start of a word with U+2581 ("▁").
	const wordStart = "▁"

	var (
		sentences []string
		sb        strings.Builder
	)
	// flush closes the sentence being built. An empty sentence (e.g. a
	// boundary predicted right after another boundary) is dropped so it
	// cannot cause doubled spaces or an empty first sentence.
	flush := func() {
		if sb.Len() > 0 {
			sentences = append(sentences, sb.String())
		}
		sb.Reset()
	}

	for tokenIdx := 1; tokenIdx < len(inputIDs)-1; tokenIdx++ {
		piece, err := sp.IDToPiece(inputIDs[tokenIdx])
		if err != nil {
			return "", err
		}

		// body is the piece without its word-start marker. charIdx is the rune
		// index within the full piece, which is how capPred is indexed.
		body := piece
		charIdx := 0
		if strings.HasPrefix(piece, wordStart) {
			body = piece[len(wordStart):]
			charIdx = 1
			if sb.Len() > 0 {
				sb.WriteByte(' ')
			}
		}

		preLabel := labelAt(cfg.PreLabels, prePred, tokenIdx)
		postLabel := labelAt(cfg.PostLabels, postPred, tokenIdx)
		isAcronym := postLabel == cfg.Acronym

		if preLabel != cfg.NullToken {
			sb.WriteString(preLabel)
		}
		for _, r := range body {
			if capAt(capPred, tokenIdx, charIdx) {
				sb.WriteString(strings.ToUpper(string(r)))
			} else {
				sb.WriteRune(r)
			}
			// Acronyms get a period after every character ("p.m.", "u.s.").
			if isAcronym {
				sb.WriteByte('.')
			}
			charIdx++
		}
		if postLabel != cfg.NullToken && !isAcronym {
			sb.WriteString(postLabel)
		}

		if sbdAt(sbdPred, tokenIdx) {
			flush()
		}
	}
	// The last sentence may not end on a predicted boundary.
	flush()

	if len(sentences) == 0 {
		return "", nil
	}
	if joinSentences {
		return strings.Join(sentences, " "), nil
	}
	// NOTE(review): only the first sentence is returned here and any later
	// sentences are discarded. Confirm that the caller intends this.
	return sentences[0], nil
}
// labelAt returns the label chosen by the class prediction preds[idx].
//
// It falls back to labels[0] when idx is outside preds or the predicted class
// is outside labels. labels[0] is presumably the null/no-op label; confirm
// against the model config. An empty labels slice yields "" instead of
// panicking.
func labelAt(labels []string, preds []int64, idx int) string {
	if len(labels) == 0 {
		return ""
	}
	if idx < 0 || idx >= len(preds) {
		return labels[0]
	}
	// Range-check in int64 before indexing so a huge prediction cannot be
	// truncated into a valid index on 32-bit platforms.
	if p := preds[idx]; p >= 0 && p < int64(len(labels)) {
		return labels[p]
	}
	return labels[0]
}
// capAt reports whether the character at charIdx of token tokenIdx is
// predicted to be upper-cased. Any out-of-range token or character index
// yields false.
func capAt(capPred [][]bool, tokenIdx, charIdx int) bool {
	if tokenIdx >= 0 && tokenIdx < len(capPred) {
		if chars := capPred[tokenIdx]; charIdx >= 0 && charIdx < len(chars) {
			return chars[charIdx]
		}
	}
	return false
}
// sbdAt reports whether token tokenIdx is predicted to end a sentence.
// An index outside sbd yields false.
func sbdAt(sbd []bool, tokenIdx int) bool {
	inRange := tokenIdx >= 0 && tokenIdx < len(sbd)
	return inRange && sbd[tokenIdx]
}