245 lines
6.2 KiB
Go
245 lines
6.2 KiB
Go
package whisper
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"go-whisper-api/config"
|
|
|
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// OutputFormat names a transcript file format understood by Engine.Save.
type OutputFormat string

// String returns the format name, e.g. "srt".
func (f OutputFormat) String() string {
	return string(f)
}

// Supported output formats. They are typed constants rather than package
// variables so that no code can reassign them at runtime.
const (
	FormatTxt OutputFormat = "txt" // plain text: segment texts concatenated
	FormatSrt OutputFormat = "srt" // SubRip subtitles
	FormatCSV OutputFormat = "csv" // "start,end,text" rows
)
|
|
|
|
type Engine struct {
|
|
cfg *config.Whisper
|
|
ctx whisper.Context
|
|
model whisper.Model
|
|
segments []whisper.Segment
|
|
progress int
|
|
runOpts RunOptions
|
|
}
|
|
|
|
func (e *Engine) Transcript() error {
|
|
return defaultPool.WithModel(e.cfg.Model, func(m whisper.Model) error {
|
|
return e.transcribeWithModel(m)
|
|
})
|
|
}
|
|
|
|
func (e *Engine) transcribeWithModel(model whisper.Model) error {
|
|
data, cleanup, err := prepareAudioPCM(e.cfg.AudioPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer cleanup()
|
|
e.model = model
|
|
e.ctx, err = e.model.NewContext()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
e.ctx.SetThreads(e.cfg.Threads)
|
|
if e.cfg.SpeedUp {
|
|
e.ctx.SetAudioCtx(750)
|
|
}
|
|
e.ctx.SetTranslate(e.cfg.Translate)
|
|
if e.cfg.Prompt != "" {
|
|
e.ctx.SetInitialPrompt(e.cfg.Prompt)
|
|
}
|
|
e.ctx.SetMaxContext(int(e.cfg.MaxContext))
|
|
if e.cfg.Debug {
|
|
log.Info().Msgf("%s", e.ctx.SystemInfo())
|
|
}
|
|
if e.cfg.Language != "" {
|
|
_ = e.ctx.SetLanguage(e.cfg.Language)
|
|
}
|
|
if e.cfg.BeamSize > 0 {
|
|
e.ctx.SetBeamSize(int(e.cfg.BeamSize))
|
|
}
|
|
if e.cfg.EntropyThold > 0 {
|
|
e.ctx.SetEntropyThold(float32(e.cfg.EntropyThold))
|
|
}
|
|
if err := prepareVAD(&e.cfg.VAD, ""); err != nil {
|
|
return err
|
|
}
|
|
ApplyVAD(e.ctx, e.cfg.VAD)
|
|
log.Debug().Msg("start transcribe process")
|
|
e.ctx.ResetTimings()
|
|
if err := e.ctx.Process(data, e.cbEncoderBegin(), e.cbSegment(), e.cbProgress()); err != nil {
|
|
return err
|
|
}
|
|
if e.cfg.Debug {
|
|
e.ctx.PrintTimings()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *Engine) cbEncoderBegin() func() bool {
|
|
return func() bool { return true }
|
|
}
|
|
|
|
func (e *Engine) cbSegment() func(segment whisper.Segment) {
|
|
return func(segment whisper.Segment) {
|
|
e.segments = append(e.segments, segment)
|
|
if !e.cfg.PrintSegment {
|
|
return
|
|
}
|
|
log.Info().Msgf(
|
|
"[%6s -> %6s] %s",
|
|
segment.Start.Truncate(time.Millisecond),
|
|
segment.End.Truncate(time.Millisecond),
|
|
segment.Text,
|
|
)
|
|
}
|
|
}
|
|
|
|
func (e *Engine) cbProgress() func(progress int) {
|
|
return func(progress int) {
|
|
if progress > 100 {
|
|
progress = 100
|
|
}
|
|
if e.progress == progress {
|
|
return
|
|
}
|
|
e.progress = progress
|
|
if e.cfg.PrintProgress {
|
|
log.Info().Msgf("current progress: %d%%", progress)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (e *Engine) getOutputPath(format string) string {
|
|
ext := filepath.Ext(e.cfg.AudioPath)
|
|
filename := filepath.Base(e.cfg.AudioPath)
|
|
if e.cfg.OutputFilename != "" {
|
|
filename = e.cfg.OutputFilename
|
|
}
|
|
folder := filepath.Dir(e.cfg.AudioPath)
|
|
if e.cfg.OutputFolder != "" {
|
|
folder = e.cfg.OutputFolder
|
|
}
|
|
return path.Join(folder, strings.TrimSuffix(filename, ext)+"."+format)
|
|
}
|
|
|
|
func (e *Engine) Save(format string) error {
|
|
outputPath := e.getOutputPath(format)
|
|
log.Info().Str("output-path", outputPath).Str("output-format", format).Msg("save text to file")
|
|
text := ""
|
|
switch OutputFormat(format) {
|
|
case FormatSrt:
|
|
for i, segment := range e.segments {
|
|
text += fmt.Sprintf("%d\n", i+1)
|
|
text += fmt.Sprintf("%s --> %s\n", srtTimestamp(segment.Start), srtTimestamp(segment.End))
|
|
text += segment.Text + "\n\n"
|
|
|
|
}
|
|
case FormatTxt:
|
|
for _, segment := range e.segments {
|
|
text += segment.Text
|
|
}
|
|
case FormatCSV:
|
|
text = "start,end,text\n"
|
|
for _, segment := range e.segments {
|
|
text += fmt.Sprintf("%s,%s,\"%s\"\n", segment.Start, segment.End, segment.Text)
|
|
}
|
|
}
|
|
if err := os.WriteFile(outputPath, []byte(text), 0o644); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type (
	// Word is a single whitespace-delimited word from a transcript, with an
	// estimated time span. In this package Start and Stop are filled in as
	// milliseconds (see segmentWords).
	Word struct {
		Word  string `json:"word"`
		Start int    `json:"start"`
		Stop  int    `json:"stop"`
	}

	// TranscriptResult is the transcription output: the formatted text and,
	// when available, per-word timings (the "words" key is omitted when empty).
	TranscriptResult struct {
		Text  string `json:"text"`
		Words []Word `json:"words,omitempty"`
	}
)
|
|
|
|
func (e *Engine) SetTranscriptText(text string) {
|
|
if len(e.segments) == 0 {
|
|
e.segments = []whisper.Segment{{Text: text}}
|
|
return
|
|
}
|
|
start := e.segments[0].Start
|
|
end := e.segments[len(e.segments)-1].End
|
|
e.segments = []whisper.Segment{{Text: text, Start: start, End: end}}
|
|
}
|
|
|
|
func (e *Engine) Result() TranscriptResult {
|
|
segments := e.segments
|
|
text := FormatSegments(segments, e.runOpts.Turns, e.runOpts.Format)
|
|
if e.runOpts.PunctuateRestore != nil && text != "" {
|
|
if updated, err := e.runOpts.PunctuateRestore(text); err == nil && strings.TrimSpace(updated) != "" {
|
|
text = updated
|
|
}
|
|
}
|
|
var words []Word
|
|
for _, segment := range segments {
|
|
words = append(words, segmentWords(segment)...)
|
|
}
|
|
return TranscriptResult{
|
|
Text: text,
|
|
Words: words,
|
|
}
|
|
}
|
|
|
|
func segmentWords(segment whisper.Segment) []Word {
|
|
parts := strings.Fields(strings.TrimSpace(segment.Text))
|
|
if len(parts) == 0 {
|
|
return nil
|
|
}
|
|
startMs := int(segment.Start / time.Millisecond)
|
|
endMs := int(segment.End / time.Millisecond)
|
|
if endMs < startMs {
|
|
endMs = startMs
|
|
}
|
|
span := endMs - startMs
|
|
if span <= 0 {
|
|
span = 1
|
|
}
|
|
step := span / len(parts)
|
|
if step < 1 {
|
|
step = 1
|
|
}
|
|
out := make([]Word, 0, len(parts))
|
|
for i, part := range parts {
|
|
wStart := startMs + i*step
|
|
wStop := wStart + step
|
|
if i == len(parts)-1 {
|
|
wStop = endMs
|
|
}
|
|
out = append(out, Word{
|
|
Word: part,
|
|
Start: wStart,
|
|
Stop: wStop,
|
|
})
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (e *Engine) Close() error {
|
|
// Models are owned by ModelPool; do not close shared weights here.
|
|
e.ctx = nil
|
|
e.model = nil
|
|
return nil
|
|
}
|