admin b5c083e06f
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
first commit
2026-06-04 18:10:52 +07:00

320 lines
8.6 KiB
Go

//go:build xlm
package punctuation
import (
"context"
"fmt"
"os"
"strings"
"unicode"
"go-whisper-api/config"
"go-whisper-api/punctuation/internal/spwrap"
ort "github.com/yalue/onnxruntime_go"
)
// XLM is the punctuation backend built around an XLM ONNX model and a
// SentencePiece processor. It is constructed by newXLM; Close releases the
// ONNX session and the SentencePiece processor.
type XLM struct {
	// cfg is the punctuation configuration newXLM was called with.
	cfg config.Punctuation
	// modelCfg is loaded by loadXLMConfig from cfg.XLMConfigPath(); Restore
	// reads its MaxLength as the maximum number of token IDs per model run
	// (BOS and EOS included).
	modelCfg xlmModelConfig
	// sp is the SentencePiece processor loaded from cfg.SPModelPath(); it
	// encodes text and supplies the BOS/EOS IDs. nil after Close.
	sp *spwrap.Processor
	// session is the ONNX Runtime session over cfg.ModelPath(). nil after Close.
	session *ort.DynamicAdvancedSession
	// inputName is the name of the model's first input.
	inputName string
	// outputNames lists every model output name in model order; inferIDs
	// sizes its output slice from it and reads the first four outputs.
	outputNames []string
	// joinSBD comes from cfg.XLMJoinSentences() and is forwarded to
	// decodeXLMSegment.
	joinSBD bool
}
// newXLM builds the XLM punctuation backend described by cfg.
//
// It initializes the ONNX Runtime shared library (ensureORT) and checks that
// the ONNX model (cfg.ModelPath) and SentencePiece model (cfg.SPModelPath)
// files are accessible. It then loads the model config (cfg.XLMConfigPath)
// and the SentencePiece processor, and opens an ONNX session.
//
// The session is bound to the model's first input only, because inferIDs
// feeds exactly one tensor (the token IDs) per Run. It is bound to every
// model output; at least four are required, since inferIDs reads outputs
// 0..3.
//
// Any resource acquired before a failure is released before returning the
// error.
func newXLM(cfg config.Punctuation) (*XLM, error) {
	if err := ensureORT(); err != nil {
		return nil, fmt.Errorf("onnxruntime: %w (set ONNXRUNTIME_SHARED_LIBRARY_PATH or install sherpa-onnx libs)", err)
	}

	onnxPath := cfg.ModelPath()
	if _, err := os.Stat(onnxPath); err != nil {
		// Wrap the Stat error: it may be a permission problem, not a missing file.
		return nil, fmt.Errorf("xlm onnx model not found at %s (run: make download-xlm-punctuation-model): %w", onnxPath, err)
	}
	spPath := cfg.SPModelPath()
	if _, err := os.Stat(spPath); err != nil {
		return nil, fmt.Errorf("xlm sp.model not found at %s (run: make download-xlm-punctuation-model): %w", spPath, err)
	}

	modelCfg, err := loadXLMConfig(cfg.XLMConfigPath())
	if err != nil {
		return nil, err
	}

	sp, err := spwrap.Load(spPath)
	if err != nil {
		return nil, fmt.Errorf("xlm sp.model load: %w", err)
	}
	// Release the SentencePiece processor on every failure path below.
	success := false
	defer func() {
		if !success {
			sp.Close()
		}
	}()

	inputs, outputs, err := ort.GetInputOutputInfo(onnxPath)
	if err != nil {
		return nil, fmt.Errorf("xlm model io: %w", err)
	}
	if len(inputs) == 0 || len(outputs) < 4 {
		return nil, fmt.Errorf("xlm model: unexpected inputs/outputs (got %d inputs, %d outputs; want >= 1 and >= 4)", len(inputs), len(outputs))
	}

	// onnxruntime_go requires one value per bound input name on every Run, and
	// inferIDs passes a single tensor, so bind only the first input.
	inputName := inputs[0].Name
	outNames := make([]string, len(outputs))
	for i, out := range outputs {
		outNames[i] = out.Name
	}
	session, err := ort.NewDynamicAdvancedSession(onnxPath, []string{inputName}, outNames, nil)
	if err != nil {
		return nil, fmt.Errorf("xlm onnx session: %w", err)
	}

	success = true
	return &XLM{
		cfg:         cfg,
		modelCfg:    modelCfg,
		sp:          sp,
		session:     session,
		inputName:   inputName,
		outputNames: outNames,
		joinSBD:     cfg.XLMJoinSentences(),
	}, nil
}
// Active always returns true for the XLM backend.
func (x *XLM) Active() bool { return true }

// Close destroys the ONNX session and closes the SentencePiece processor,
// setting both fields to nil. Because each field is nil-checked first,
// calling Close again is a no-op. The session's Destroy error is
// deliberately dropped: cleanup here is best-effort.
func (x *XLM) Close() {
	if s := x.session; s != nil {
		x.session = nil
		_ = s.Destroy()
	}
	if p := x.sp; p != nil {
		x.sp = nil
		p.Close()
	}
}
// Restore runs the XLM punctuation model over text and returns the text
// produced by decodeXLMSegment.
//
// Runs of whitespace in text are collapsed to single spaces and the result is
// trimmed. If nothing remains, Restore returns "" without touching the model.
// The language argument is not used by this implementation (presumably it is
// required by a shared punctuator interface — confirm).
//
// The text is encoded to SentencePiece IDs. If the IDs plus BOS/EOS fit in
// the model's MaxLength (256 when MaxLength <= 2), they are inferred in one
// run. Otherwise they are split into consecutive windows of MaxLength-2 IDs.
// Each window is wrapped in BOS/EOS and inferred on its own, and the
// non-empty results are joined with single spaces.
// NOTE(review): windows are cut at arbitrary token IDs, not word boundaries,
// so a word whose sub-word pieces straddle two windows may come back split by
// a space.
//
// ctx is checked before any work and again before every window; on
// cancellation its error is returned. Calling Restore with non-empty text
// after Close returns an error instead of dereferencing a nil processor.
func (x *XLM) Restore(ctx context.Context, text, language string) (string, error) {
	if err := ctx.Err(); err != nil {
		return "", err
	}
	text = strings.TrimSpace(normalizeXLMSpaces(text))
	if text == "" {
		return "", nil
	}
	if x.sp == nil || x.session == nil {
		return "", fmt.Errorf("xlm punctuation: Restore called after Close")
	}

	ids, err := x.sp.EncodeAsIDs(text)
	if err != nil {
		return "", fmt.Errorf("xlm sentencepiece encode: %w", err)
	}

	maxLen := x.modelCfg.MaxLength
	if maxLen <= 2 {
		maxLen = 256
	}
	window := maxLen - 2 // every run also carries BOS and EOS
	bos, eos := x.sp.BOSID(), x.sp.EOSID()

	if len(ids) <= window {
		return x.inferIDs(xlmWrapIDs(bos, ids, eos))
	}

	parts := make([]string, 0, (len(ids)+window-1)/window)
	for start := 0; start < len(ids); start += window {
		// Each window is a full model run; honour cancellation between them.
		if err := ctx.Err(); err != nil {
			return "", err
		}
		end := start + window
		if end > len(ids) {
			end = len(ids)
		}
		out, err := x.inferIDs(xlmWrapIDs(bos, ids[start:end], eos))
		if err != nil {
			return "", err
		}
		if out != "" {
			parts = append(parts, out)
		}
	}
	return strings.TrimSpace(strings.Join(parts, " ")), nil
}

// xlmWrapIDs returns a new slice holding bos, then ids, then eos.
func xlmWrapIDs(bos int, ids []int, eos int) []int {
	out := make([]int, 0, len(ids)+2)
	out = append(out, bos)
	out = append(out, ids...)
	return append(out, eos)
}
// inferIDs runs the ONNX session on one sequence of token IDs and decodes the
// model outputs with decodeXLMSegment. The IDs are expected to be wrapped in
// BOS/EOS by the caller.
//
// The IDs are sent as a [1, len(inputIDs)] int64 tensor. The first four
// session outputs are read, in order, as pre (int64Row), post (int64Row),
// capitalization mask (boolMatrix) and sbd (boolRow), each sized to
// len(inputIDs). All ONNX values created here are destroyed before
// returning.
func (x *XLM) inferIDs(inputIDs []int) (string, error) {
	seqLen := len(inputIDs)
	data := make([]int64, seqLen)
	for i, id := range inputIDs {
		data[i] = int64(id)
	}
	inputTensor, err := ort.NewTensor(ort.NewShape(1, int64(seqLen)), data)
	if err != nil {
		return "", fmt.Errorf("xlm input tensor: %w", err)
	}
	defer inputTensor.Destroy()

	// nil entries are allocated by the session during Run. Deferring cleanup
	// before Run means anything it allocated is freed even on failure;
	// destroyValues skips entries that are still nil.
	outputs := make([]ort.Value, len(x.outputNames))
	defer destroyValues(outputs)
	if err := x.session.Run([]ort.Value{inputTensor}, outputs); err != nil {
		return "", fmt.Errorf("xlm onnx run: %w", err)
	}

	pre, err := int64Row(outputs[0], seqLen)
	if err != nil {
		return "", fmt.Errorf("xlm pre output: %w", err)
	}
	post, err := int64Row(outputs[1], seqLen)
	if err != nil {
		return "", fmt.Errorf("xlm post output: %w", err)
	}
	capMask, err := boolMatrix(outputs[2], seqLen)
	if err != nil {
		return "", fmt.Errorf("xlm cap output: %w", err)
	}
	sbd, err := boolRow(outputs[3], seqLen)
	if err != nil {
		return "", fmt.Errorf("xlm sbd output: %w", err)
	}
	// Decode before the deferred destroyValues runs: int64Row and boolRow may
	// return the tensors' own GetData slices rather than copies.
	return decodeXLMSegment(x.sp, x.modelCfg, inputIDs, pre, post, capMask, sbd, x.joinSBD)
}
// destroyValues calls Destroy on every non-nil ONNX value in vals. Nil
// entries are skipped, and Destroy errors are dropped because this is
// best-effort cleanup.
func destroyValues(vals []ort.Value) {
	for i := range vals {
		if vals[i] == nil {
			continue
		}
		_ = vals[i].Destroy()
	}
}
// int64Row extracts wantLen integer values from an ONNX output tensor as
// int64. Both int64 and int32 tensors are accepted.
//
// If the tensor holds more than wantLen values, the last wantLen are used.
// Fewer than wantLen is an error for both element types; the int32 branch
// previously zero-padded silently.
//
// For int64 tensors the returned slice is a sub-slice of the tensor's
// GetData result, not a copy. For int32 tensors a fresh slice is returned.
func int64Row(v ort.Value, wantLen int) ([]int64, error) {
	switch t := v.(type) {
	case *ort.Tensor[int64]:
		d := t.GetData()
		if len(d) < wantLen {
			return nil, fmt.Errorf("int64 output short: %d < %d", len(d), wantLen)
		}
		return d[len(d)-wantLen:], nil
	case *ort.Tensor[int32]:
		d := t.GetData()
		if len(d) < wantLen {
			return nil, fmt.Errorf("int32 output short: %d < %d", len(d), wantLen)
		}
		d = d[len(d)-wantLen:]
		out := make([]int64, wantLen)
		for i, n := range d {
			out[i] = int64(n)
		}
		return out, nil
	default:
		return nil, fmt.Errorf("unexpected int output type %T", v)
	}
}
// boolRow extracts wantLen flags from an ONNX output tensor. Bool tensors
// are used as-is; float32 tensors are thresholded at > 0.5.
//
// Both element types follow the same rules. If the tensor holds more than
// wantLen values, the last wantLen are used. Fewer than wantLen is an error.
// Previously the float32 branch read from the head and zero-padded short
// outputs, unlike the bool branch.
//
// For bool tensors the returned slice is a sub-slice of the tensor's GetData
// result, not a copy.
func boolRow(v ort.Value, wantLen int) ([]bool, error) {
	switch t := v.(type) {
	case *ort.Tensor[bool]:
		d := t.GetData()
		if len(d) < wantLen {
			return nil, fmt.Errorf("bool output short: %d < %d", len(d), wantLen)
		}
		return d[len(d)-wantLen:], nil
	case *ort.Tensor[float32]:
		d := t.GetData()
		if len(d) < wantLen {
			return nil, fmt.Errorf("float32 output short: %d < %d", len(d), wantLen)
		}
		d = d[len(d)-wantLen:]
		out := make([]bool, wantLen)
		for i, p := range d {
			out[i] = p > 0.5
		}
		return out, nil
	default:
		return nil, fmt.Errorf("unexpected bool output type %T", v)
	}
}
// boolMatrix turns the capitalization output tensor into one []bool row per
// token. Bool tensors are used as-is; float32 tensors are thresholded at > 0.5
// first. The layout is then resolved by xlmCapGrid. Every row is a fresh
// copy, so the result does not alias tensor memory.
func boolMatrix(v ort.Value, seqLen int) ([][]bool, error) {
	switch t := v.(type) {
	case *ort.Tensor[bool]:
		return xlmCapGrid(t.GetShape(), t.GetData(), seqLen)
	case *ort.Tensor[float32]:
		scores := t.GetData()
		flags := make([]bool, len(scores))
		for i, p := range scores {
			flags[i] = p > 0.5
		}
		return xlmCapGrid(t.GetShape(), flags, seqLen)
	default:
		return nil, fmt.Errorf("unexpected cap output type %T", v)
	}
}

// xlmCapGrid reshapes flat flags d into rows of copied []bool.
//
// For a rank-3 shape [batch, rows, width], the first rows*width values become
// rows x width. Only the first batch entry is read, and seqLen is not used on
// this path. It is an error if d holds fewer than rows*width values.
//
// For any other rank, d is split into seqLen rows of width len(d)/seqLen (at
// least 1). A row that does not fit entirely inside d is left all false.
// seqLen must be positive on this path.
func xlmCapGrid(shape []int64, d []bool, seqLen int) ([][]bool, error) {
	if len(shape) == 3 {
		rows, width := int(shape[1]), int(shape[2])
		if rows < 0 || width < 0 || rows*width > len(d) {
			return nil, fmt.Errorf("cap output: shape %v needs %d values, have %d", shape, rows*width, len(d))
		}
		out := make([][]bool, rows)
		for i := range out {
			row := make([]bool, width)
			copy(row, d[i*width:(i+1)*width])
			out[i] = row
		}
		return out, nil
	}

	if seqLen <= 0 {
		return nil, fmt.Errorf("cap output: invalid sequence length %d", seqLen)
	}
	width := len(d) / seqLen
	if width < 1 {
		width = 1
	}
	out := make([][]bool, seqLen)
	for i := range out {
		row := make([]bool, width)
		if base := i * width; base+width <= len(d) {
			copy(row, d[base:base+width])
		}
		out[i] = row
	}
	return out, nil
}
// normalizeXLMSpaces replaces every run of Unicode whitespace (as defined by
// unicode.IsSpace) in s with a single ASCII space. Leading and trailing runs
// are collapsed but not removed; callers trim separately. The string is
// decoded rune by rune, so invalid UTF-8 bytes come out as U+FFFD.
func normalizeXLMSpaces(s string) string {
	var out strings.Builder
	out.Grow(len(s))
	inRun := false
	for _, r := range s {
		switch {
		case !unicode.IsSpace(r):
			inRun = false
			out.WriteRune(r)
		case !inRun:
			inRun = true
			out.WriteRune(' ')
		}
	}
	return out.String()
}