Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
98 lines
2.6 KiB
Go
98 lines
2.6 KiB
Go
//go:build xlm
|
|
|
|
package punctuation
|
|
|
|
import (
|
|
"strings"
|
|
|
|
"go-whisper-api/punctuation/internal/spwrap"
|
|
)
|
|
|
|
func decodeXLMSegment(
|
|
sp *spwrap.Processor,
|
|
cfg xlmModelConfig,
|
|
inputIDs []int,
|
|
prePred, postPred []int64,
|
|
capPred [][]bool,
|
|
sbdPred []bool,
|
|
joinSentences bool,
|
|
) (string, error) {
|
|
var outputTexts []string
|
|
current := make([]string, 0, len(inputIDs)*4)
|
|
for tokenIdx := 1; tokenIdx < len(inputIDs)-1; tokenIdx++ {
|
|
piece, err := sp.IDToPiece(inputIDs[tokenIdx])
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if strings.HasPrefix(piece, "▁") && len(current) > 0 {
|
|
current = append(current, " ")
|
|
}
|
|
preLabel := labelAt(cfg.PreLabels, prePred, tokenIdx)
|
|
postLabel := labelAt(cfg.PostLabels, postPred, tokenIdx)
|
|
if preLabel != cfg.NullToken {
|
|
current = append(current, preLabel)
|
|
}
|
|
charStart := 0
|
|
if strings.HasPrefix(piece, "▁") {
|
|
charStart = 1
|
|
}
|
|
runes := []rune(piece)
|
|
for tokenCharIdx := charStart; tokenCharIdx < len(runes); tokenCharIdx++ {
|
|
ch := string(runes[tokenCharIdx])
|
|
if capAt(capPred, tokenIdx, tokenCharIdx) {
|
|
ch = strings.ToUpper(ch)
|
|
}
|
|
current = append(current, ch)
|
|
if postLabel == cfg.Acronym {
|
|
current = append(current, ".")
|
|
}
|
|
}
|
|
if postLabel != cfg.NullToken && postLabel != cfg.Acronym {
|
|
current = append(current, postLabel)
|
|
}
|
|
if sbdAt(sbdPred, tokenIdx) {
|
|
outputTexts = append(outputTexts, strings.Join(current, ""))
|
|
current = current[:0]
|
|
}
|
|
}
|
|
if len(current) > 0 {
|
|
outputTexts = append(outputTexts, strings.Join(current, ""))
|
|
}
|
|
if len(outputTexts) == 0 {
|
|
return "", nil
|
|
}
|
|
if joinSentences {
|
|
return strings.Join(outputTexts, " "), nil
|
|
}
|
|
return outputTexts[0], nil
|
|
}
|
|
|
|
// labelAt returns the label chosen by the class prediction preds[idx].
//
// If idx is outside preds, or the predicted class is outside labels, labels[0]
// is returned as the fallback (presumably the null label — TODO confirm
// against the model config). An empty labels slice yields "" instead of
// panicking.
func labelAt(labels []string, preds []int64, idx int) string {
	if len(labels) == 0 {
		return ""
	}
	if idx < 0 || idx >= len(preds) {
		return labels[0]
	}
	// Range-check in int64 before indexing: narrowing to int first could wrap
	// a huge prediction into a valid (wrong) index on 32-bit platforms.
	p := preds[idx]
	if p < 0 || p >= int64(len(labels)) {
		return labels[0]
	}
	return labels[p]
}
|
|
|
|
// capAt reports whether character charIdx of token tokenIdx is predicted to
// be upper-cased. Any index outside capPred (on either axis) reports false.
func capAt(capPred [][]bool, tokenIdx, charIdx int) bool {
	if tokenIdx >= 0 && tokenIdx < len(capPred) {
		if flags := capPred[tokenIdx]; charIdx >= 0 && charIdx < len(flags) {
			return flags[charIdx]
		}
	}
	return false
}
|
|
|
|
// sbdAt reports whether a sentence boundary is predicted after token
// tokenIdx. Indices outside sbd report false.
func sbdAt(sbd []bool, tokenIdx int) bool {
	inRange := tokenIdx >= 0 && tokenIdx < len(sbd)
	return inRange && sbd[tokenIdx]
}
|