//go:build xlm package punctuation import ( "context" "fmt" "os" "strings" "unicode" "go-whisper-api/config" "go-whisper-api/punctuation/internal/spwrap" ort "github.com/yalue/onnxruntime_go" ) type XLM struct { cfg config.Punctuation modelCfg xlmModelConfig sp *spwrap.Processor session *ort.DynamicAdvancedSession inputName string outputNames []string joinSBD bool } func newXLM(cfg config.Punctuation) (*XLM, error) { if err := ensureORT(); err != nil { return nil, fmt.Errorf("onnxruntime: %w (set ONNXRUNTIME_SHARED_LIBRARY_PATH or install sherpa-onnx libs)", err) } onnxPath := cfg.ModelPath() if _, err := os.Stat(onnxPath); err != nil { return nil, fmt.Errorf("xlm onnx model not found at %s (run: make download-xlm-punctuation-model)", onnxPath) } spPath := cfg.SPModelPath() if _, err := os.Stat(spPath); err != nil { return nil, fmt.Errorf("xlm sp.model not found at %s (run: make download-xlm-punctuation-model)", spPath) } cfgPath := cfg.XLMConfigPath() modelCfg, err := loadXLMConfig(cfgPath) if err != nil { return nil, err } sp, err := spwrap.Load(spPath) if err != nil { return nil, err } inputs, outputs, err := ort.GetInputOutputInfo(onnxPath) if err != nil { sp.Close() return nil, fmt.Errorf("xlm model io: %w", err) } if len(inputs) == 0 || len(outputs) < 4 { sp.Close() return nil, fmt.Errorf("xlm model: unexpected inputs/outputs") } inNames := make([]string, len(inputs)) for i, in := range inputs { inNames[i] = in.Name } outNames := make([]string, len(outputs)) for i, out := range outputs { outNames[i] = out.Name } session, err := ort.NewDynamicAdvancedSession(onnxPath, inNames, outNames, nil) if err != nil { sp.Close() return nil, fmt.Errorf("xlm onnx session: %w", err) } return &XLM{ cfg: cfg, modelCfg: modelCfg, sp: sp, session: session, inputName: inNames[0], outputNames: outNames, joinSBD: cfg.XLMJoinSentences(), }, nil } func (x *XLM) Active() bool { return true } func (x *XLM) Close() { if x.session != nil { _ = x.session.Destroy() x.session = nil } if x.sp != nil { x.sp.Close() x.sp = nil } } func (x *XLM) Restore(ctx context.Context, text, language string) (string, error) { if err := ctx.Err(); err != nil { return "", err } text = strings.TrimSpace(normalizeXLMSpaces(text)) if text == "" { return text, nil } ids, err := x.sp.EncodeAsIDs(text) if err != nil { return "", err } full := make([]int, 0, len(ids)+2) full = append(full, x.sp.BOSID()) full = append(full, ids...) full = append(full, x.sp.EOSID()) maxLen := x.modelCfg.MaxLength if maxLen <= 2 { maxLen = 256 } if len(full) <= maxLen { return x.inferIDs(full) } var parts []string content := full[1 : len(full)-1] step := maxLen - 2 for start := 0; start < len(content); { end := start + step if end > len(content) { end = len(content) } chunk := make([]int, 0, end-start+2) chunk = append(chunk, x.sp.BOSID()) chunk = append(chunk, content[start:end]...) chunk = append(chunk, x.sp.EOSID()) out, err := x.inferIDs(chunk) if err != nil { return "", err } if out != "" { parts = append(parts, out) } if end >= len(content) { break } start = end } return strings.TrimSpace(strings.Join(parts, " ")), nil } func (x *XLM) inferIDs(inputIDs []int) (string, error) { data := make([]int64, len(inputIDs)) for i, id := range inputIDs { data[i] = int64(id) } inputTensor, err := ort.NewTensor(ort.NewShape(1, int64(len(inputIDs))), data) if err != nil { return "", err } defer inputTensor.Destroy() outputs := make([]ort.Value, len(x.outputNames)) if err := x.session.Run([]ort.Value{inputTensor}, outputs); err != nil { return "", err } defer destroyValues(outputs) pre, err := int64Row(outputs[0], len(inputIDs)) if err != nil { return "", err } post, err := int64Row(outputs[1], len(inputIDs)) if err != nil { return "", err } cap, err := boolMatrix(outputs[2], len(inputIDs)) if err != nil { return "", err } sbd, err := boolRow(outputs[3], len(inputIDs)) if err != nil { return "", err } return decodeXLMSegment(x.sp, x.modelCfg, inputIDs, pre, post, cap, sbd, x.joinSBD) } func destroyValues(vals []ort.Value) { for _, v := range vals { if v != nil { _ = v.Destroy() } } } func int64Row(v ort.Value, wantLen int) ([]int64, error) { switch t := v.(type) { case *ort.Tensor[int64]: d := t.GetData() if len(d) == wantLen { return d, nil } if len(d) > wantLen { return d[len(d)-wantLen:], nil } return nil, fmt.Errorf("int64 output short: %d < %d", len(d), wantLen) case *ort.Tensor[int32]: d := t.GetData() if len(d) > wantLen { d = d[len(d)-wantLen:] } out := make([]int64, wantLen) for i := 0; i < wantLen && i < len(d); i++ { out[i] = int64(d[i]) } return out, nil default: return nil, fmt.Errorf("unexpected int output type %T", v) } } func boolRow(v ort.Value, wantLen int) ([]bool, error) { switch t := v.(type) { case *ort.Tensor[bool]: d := t.GetData() if len(d) == wantLen { return d, nil } if len(d) > wantLen { return d[len(d)-wantLen:], nil } return nil, fmt.Errorf("bool output short") case *ort.Tensor[float32]: d := t.GetData() out := make([]bool, wantLen) for i := 0; i < wantLen && i < len(d); i++ { out[i] = d[i] > 0.5 } return out, nil default: return nil, fmt.Errorf("unexpected bool output type %T", v) } } func boolMatrix(v ort.Value, seqLen int) ([][]bool, error) { switch t := v.(type) { case *ort.Tensor[bool]: shape := t.GetShape() d := t.GetData() if len(shape) == 3 { _, sl, width := shape[0], shape[1], shape[2] out := make([][]bool, sl) for i := 0; i < int(sl); i++ { row := make([]bool, width) base := int(i) * int(width) copy(row, d[base:base+int(width)]) out[i] = row } return out, nil } width := len(d) / seqLen if width < 1 { width = 1 } out := make([][]bool, seqLen) for i := 0; i < seqLen; i++ { row := make([]bool, width) base := i * width if base+width <= len(d) { copy(row, d[base:base+width]) } out[i] = row } return out, nil case *ort.Tensor[float32]: shape := t.GetShape() d := t.GetData() if len(shape) == 3 { _, sl, width := shape[0], shape[1], shape[2] out := make([][]bool, sl) for i := 0; i < int(sl); i++ { row := make([]bool, width) base := int(i) * int(width) for j := 0; j < int(width); j++ { row[j] = d[base+j] > 0.5 } out[i] = row } return out, nil } width := len(d) / seqLen if width < 1 { width = 1 } out := make([][]bool, seqLen) for i := 0; i < seqLen; i++ { row := make([]bool, width) base := i * width for j := 0; j < width && base+j < len(d); j++ { row[j] = d[base+j] > 0.5 } out[i] = row } return out, nil default: return nil, fmt.Errorf("unexpected cap output type %T", v) } } func normalizeXLMSpaces(s string) string { var b strings.Builder prevSpace := false for _, r := range s { if unicode.IsSpace(r) { if !prevSpace { b.WriteRune(' ') prevSpace = true } continue } prevSpace = false b.WriteRune(r) } return b.String() }