Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
320 lines
8.6 KiB
Go
320 lines
8.6 KiB
Go
//go:build xlm
|
|
|
|
package punctuation
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"go-whisper-api/config"
|
|
"go-whisper-api/punctuation/internal/spwrap"
|
|
|
|
ort "github.com/yalue/onnxruntime_go"
|
|
)
|
|
|
|
// XLM is a punctuation restorer backed by an ONNX model (run through
// onnxruntime) and a SentencePiece tokenizer (spwrap). Instances are built by
// newXLM and must be released with Close.
//
// NOTE(review): there is no synchronization here; calling Close concurrently
// with Restore is not guarded — confirm callers serialize these.
type XLM struct {
	// cfg is the punctuation configuration the restorer was built from.
	cfg config.Punctuation
	// modelCfg is the model metadata loaded from cfg.XLMConfigPath(); Restore
	// reads MaxLength from it and it is forwarded to decodeXLMSegment.
	modelCfg xlmModelConfig
	// sp is the SentencePiece processor used to encode text into token ids;
	// nil after Close.
	sp *spwrap.Processor
	// session is the onnxruntime session for the model; nil after Close.
	session *ort.DynamicAdvancedSession
	// inputName is the name of the model's first input (the token-id tensor).
	inputName string
	// outputNames lists every model output; inferIDs reads the first four
	// positionally (pre, post, cap, sbd).
	outputNames []string
	// joinSBD is cfg.XLMJoinSentences(), forwarded to decodeXLMSegment.
	joinSBD bool
}
|
|
|
|
func newXLM(cfg config.Punctuation) (*XLM, error) {
|
|
if err := ensureORT(); err != nil {
|
|
return nil, fmt.Errorf("onnxruntime: %w (set ONNXRUNTIME_SHARED_LIBRARY_PATH or install sherpa-onnx libs)", err)
|
|
}
|
|
onnxPath := cfg.ModelPath()
|
|
if _, err := os.Stat(onnxPath); err != nil {
|
|
return nil, fmt.Errorf("xlm onnx model not found at %s (run: make download-xlm-punctuation-model)", onnxPath)
|
|
}
|
|
spPath := cfg.SPModelPath()
|
|
if _, err := os.Stat(spPath); err != nil {
|
|
return nil, fmt.Errorf("xlm sp.model not found at %s (run: make download-xlm-punctuation-model)", spPath)
|
|
}
|
|
cfgPath := cfg.XLMConfigPath()
|
|
modelCfg, err := loadXLMConfig(cfgPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sp, err := spwrap.Load(spPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
inputs, outputs, err := ort.GetInputOutputInfo(onnxPath)
|
|
if err != nil {
|
|
sp.Close()
|
|
return nil, fmt.Errorf("xlm model io: %w", err)
|
|
}
|
|
if len(inputs) == 0 || len(outputs) < 4 {
|
|
sp.Close()
|
|
return nil, fmt.Errorf("xlm model: unexpected inputs/outputs")
|
|
}
|
|
inNames := make([]string, len(inputs))
|
|
for i, in := range inputs {
|
|
inNames[i] = in.Name
|
|
}
|
|
outNames := make([]string, len(outputs))
|
|
for i, out := range outputs {
|
|
outNames[i] = out.Name
|
|
}
|
|
session, err := ort.NewDynamicAdvancedSession(onnxPath, inNames, outNames, nil)
|
|
if err != nil {
|
|
sp.Close()
|
|
return nil, fmt.Errorf("xlm onnx session: %w", err)
|
|
}
|
|
return &XLM{
|
|
cfg: cfg,
|
|
modelCfg: modelCfg,
|
|
sp: sp,
|
|
session: session,
|
|
inputName: inNames[0],
|
|
outputNames: outNames,
|
|
joinSBD: cfg.XLMJoinSentences(),
|
|
}, nil
|
|
}
|
|
|
|
func (x *XLM) Active() bool {
|
|
return true
|
|
}
|
|
|
|
func (x *XLM) Close() {
|
|
if x.session != nil {
|
|
_ = x.session.Destroy()
|
|
x.session = nil
|
|
}
|
|
if x.sp != nil {
|
|
x.sp.Close()
|
|
x.sp = nil
|
|
}
|
|
}
|
|
|
|
func (x *XLM) Restore(ctx context.Context, text, language string) (string, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return "", err
|
|
}
|
|
text = strings.TrimSpace(normalizeXLMSpaces(text))
|
|
if text == "" {
|
|
return text, nil
|
|
}
|
|
|
|
ids, err := x.sp.EncodeAsIDs(text)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
full := make([]int, 0, len(ids)+2)
|
|
full = append(full, x.sp.BOSID())
|
|
full = append(full, ids...)
|
|
full = append(full, x.sp.EOSID())
|
|
maxLen := x.modelCfg.MaxLength
|
|
if maxLen <= 2 {
|
|
maxLen = 256
|
|
}
|
|
if len(full) <= maxLen {
|
|
return x.inferIDs(full)
|
|
}
|
|
var parts []string
|
|
content := full[1 : len(full)-1]
|
|
step := maxLen - 2
|
|
for start := 0; start < len(content); {
|
|
end := start + step
|
|
if end > len(content) {
|
|
end = len(content)
|
|
}
|
|
chunk := make([]int, 0, end-start+2)
|
|
chunk = append(chunk, x.sp.BOSID())
|
|
chunk = append(chunk, content[start:end]...)
|
|
chunk = append(chunk, x.sp.EOSID())
|
|
out, err := x.inferIDs(chunk)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if out != "" {
|
|
parts = append(parts, out)
|
|
}
|
|
if end >= len(content) {
|
|
break
|
|
}
|
|
start = end
|
|
}
|
|
return strings.TrimSpace(strings.Join(parts, " ")), nil
|
|
}
|
|
|
|
func (x *XLM) inferIDs(inputIDs []int) (string, error) {
|
|
data := make([]int64, len(inputIDs))
|
|
for i, id := range inputIDs {
|
|
data[i] = int64(id)
|
|
}
|
|
inputTensor, err := ort.NewTensor(ort.NewShape(1, int64(len(inputIDs))), data)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer inputTensor.Destroy()
|
|
outputs := make([]ort.Value, len(x.outputNames))
|
|
if err := x.session.Run([]ort.Value{inputTensor}, outputs); err != nil {
|
|
return "", err
|
|
}
|
|
defer destroyValues(outputs)
|
|
pre, err := int64Row(outputs[0], len(inputIDs))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
post, err := int64Row(outputs[1], len(inputIDs))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
cap, err := boolMatrix(outputs[2], len(inputIDs))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
sbd, err := boolRow(outputs[3], len(inputIDs))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return decodeXLMSegment(x.sp, x.modelCfg, inputIDs, pre, post, cap, sbd, x.joinSBD)
|
|
}
|
|
|
|
func destroyValues(vals []ort.Value) {
|
|
for _, v := range vals {
|
|
if v != nil {
|
|
_ = v.Destroy()
|
|
}
|
|
}
|
|
}
|
|
|
|
func int64Row(v ort.Value, wantLen int) ([]int64, error) {
|
|
switch t := v.(type) {
|
|
case *ort.Tensor[int64]:
|
|
d := t.GetData()
|
|
if len(d) == wantLen {
|
|
return d, nil
|
|
}
|
|
if len(d) > wantLen {
|
|
return d[len(d)-wantLen:], nil
|
|
}
|
|
return nil, fmt.Errorf("int64 output short: %d < %d", len(d), wantLen)
|
|
case *ort.Tensor[int32]:
|
|
d := t.GetData()
|
|
if len(d) > wantLen {
|
|
d = d[len(d)-wantLen:]
|
|
}
|
|
out := make([]int64, wantLen)
|
|
for i := 0; i < wantLen && i < len(d); i++ {
|
|
out[i] = int64(d[i])
|
|
}
|
|
return out, nil
|
|
default:
|
|
return nil, fmt.Errorf("unexpected int output type %T", v)
|
|
}
|
|
}
|
|
|
|
func boolRow(v ort.Value, wantLen int) ([]bool, error) {
|
|
switch t := v.(type) {
|
|
case *ort.Tensor[bool]:
|
|
d := t.GetData()
|
|
if len(d) == wantLen {
|
|
return d, nil
|
|
}
|
|
if len(d) > wantLen {
|
|
return d[len(d)-wantLen:], nil
|
|
}
|
|
return nil, fmt.Errorf("bool output short")
|
|
case *ort.Tensor[float32]:
|
|
d := t.GetData()
|
|
out := make([]bool, wantLen)
|
|
for i := 0; i < wantLen && i < len(d); i++ {
|
|
out[i] = d[i] > 0.5
|
|
}
|
|
return out, nil
|
|
default:
|
|
return nil, fmt.Errorf("unexpected bool output type %T", v)
|
|
}
|
|
}
|
|
|
|
func boolMatrix(v ort.Value, seqLen int) ([][]bool, error) {
|
|
switch t := v.(type) {
|
|
case *ort.Tensor[bool]:
|
|
shape := t.GetShape()
|
|
d := t.GetData()
|
|
if len(shape) == 3 {
|
|
_, sl, width := shape[0], shape[1], shape[2]
|
|
out := make([][]bool, sl)
|
|
for i := 0; i < int(sl); i++ {
|
|
row := make([]bool, width)
|
|
base := int(i) * int(width)
|
|
copy(row, d[base:base+int(width)])
|
|
out[i] = row
|
|
}
|
|
return out, nil
|
|
}
|
|
width := len(d) / seqLen
|
|
if width < 1 {
|
|
width = 1
|
|
}
|
|
out := make([][]bool, seqLen)
|
|
for i := 0; i < seqLen; i++ {
|
|
row := make([]bool, width)
|
|
base := i * width
|
|
if base+width <= len(d) {
|
|
copy(row, d[base:base+width])
|
|
}
|
|
out[i] = row
|
|
}
|
|
return out, nil
|
|
case *ort.Tensor[float32]:
|
|
shape := t.GetShape()
|
|
d := t.GetData()
|
|
if len(shape) == 3 {
|
|
_, sl, width := shape[0], shape[1], shape[2]
|
|
out := make([][]bool, sl)
|
|
for i := 0; i < int(sl); i++ {
|
|
row := make([]bool, width)
|
|
base := int(i) * int(width)
|
|
for j := 0; j < int(width); j++ {
|
|
row[j] = d[base+j] > 0.5
|
|
}
|
|
out[i] = row
|
|
}
|
|
return out, nil
|
|
}
|
|
width := len(d) / seqLen
|
|
if width < 1 {
|
|
width = 1
|
|
}
|
|
out := make([][]bool, seqLen)
|
|
for i := 0; i < seqLen; i++ {
|
|
row := make([]bool, width)
|
|
base := i * width
|
|
for j := 0; j < width && base+j < len(d); j++ {
|
|
row[j] = d[base+j] > 0.5
|
|
}
|
|
out[i] = row
|
|
}
|
|
return out, nil
|
|
default:
|
|
return nil, fmt.Errorf("unexpected cap output type %T", v)
|
|
}
|
|
}
|
|
|
|
// normalizeXLMSpaces collapses every run of Unicode whitespace in s into a
// single ASCII space. Leading and trailing runs are collapsed, not removed;
// callers trim separately. Invalid UTF-8 bytes come out as U+FFFD.
func normalizeXLMSpaces(s string) string {
	var sb strings.Builder
	inRun := false
	for _, r := range s {
		switch {
		case !unicode.IsSpace(r):
			inRun = false
			sb.WriteRune(r)
		case !inRun:
			inRun = true
			sb.WriteRune(' ')
		}
	}
	return sb.String()
}
|