diff --git a/api/swagger.json b/api/swagger.json index 188ddf7..ba2fafb 100644 --- a/api/swagger.json +++ b/api/swagger.json @@ -1,143 +1,91 @@ { "swagger": "2.0", + "info": { + "title": "go-whisper-api", + "version": "1.0.0", + "description": "HTTP API распознавания речи на whisper.cpp: SPR (/spr/*) и OpenAI-совместимый STT (/v1/*). Swagger UI: GET /" + }, + "host": "localhost:8080", + "schemes": ["http"], "basePath": "/", + "consumes": ["application/json", "multipart/form-data"], + "produces": ["application/json", "text/plain", "audio/wav", "application/octet-stream"], + "tags": [ + { + "name": "spr", + "description": "Short Phrase Recognizer (асинхронная очередь, модели, waveform)" + }, + { + "name": "openai", + "description": "OpenAI Whisper API (синхронная транскрипция, Open WebUI)" + }, + { + "name": "meta", + "description": "Документация API" + } + ], "paths": { - "/spr/audio/{taskID}": { - "parameters": [ - { - "name": "taskID", - "in": "path", - "required": true, - "type": "string" - } - ], + "/swagger.json": { "get": { + "tags": ["meta"], + "summary": "OpenAPI/Swagger 2.0 спецификация", + "operationId": "getSwaggerJSON", + "produces": ["application/json"], "responses": { "200": { - "description": "Success" + "description": "Спецификация JSON", + "schema": { "type": "object" } } - }, - "operationId": "get_audio_stt", - "tags": [ - "spr" - ] - } - }, - "/spr/delete/{id}": { - "parameters": [ - { - "name": "id", - "in": "path", - "required": true, - "type": "string" } - ], - "delete": { - "responses": { - "200": { - "description": "Success" - } - }, - "operationId": "delete_model_delete", - "tags": [ - "spr" - ] - } - }, - "/spr/export/{id}": { - "parameters": [ - { - "name": "id", - "in": "path", - "required": true, - "type": "string" - } - ], - "get": { - "responses": { - "200": { - "description": "Success" - } - }, - "operationId": "get_model_export", - "tags": [ - "spr" - ] - } - }, - "/spr/hostname": { - "get": { - "responses": { - "200": { - "description": "Success" - } - }, - "operationId": "get_hostname_class", - "tags": [ - "spr" - ] - } - }, - "/spr/import/{id}": { - "parameters": [ - { - "name": "id", - "in": "path", - "required": true, - "type": "string" - } - ], - "post": { - "responses": { - "200": { - "description": "Success" - } - }, - "operationId": "post_model_import", - "parameters": [ - { - "name": "zip-model", - "in": "formData", - "type": "file", - "required": true, - "description": "prepared model zip file" - } - ], - "consumes": [ - "multipart/form-data" - ], - "tags": [ - "spr" - ] } }, "/spr/models": { "get": { + "tags": ["spr"], + "summary": "Список моделей STT", + "description": "Имена файлов `*.bin` в корне `api.models_dir` (без подкаталогов vad/punctuation).", + "operationId": "sprListModels", "responses": { "200": { - "description": "Success", - "schema": { - "$ref": "#/definitions/modelList" - } + "description": "OK", + "schema": { "$ref": "#/definitions/modelList" } + }, + "500": { + "description": "Ошибка чтения каталога", + "schema": { "$ref": "#/definitions/plainError" } } - }, - "operationId": "get_model_list", - "tags": [ - "spr" - ] + } + } + }, + "/spr/hostname": { + "get": { + "tags": ["spr"], + "summary": "Информация о хосте и сервере", + "operationId": "sprHostname", + "responses": { + "200": { + "description": "OK", + "schema": { "$ref": "#/definitions/hostnameInfo" } + } + } } }, "/spr/queue": { "get": { + "tags": ["spr"], + "summary": "Список задач в кэше", + "description": "Все задачи в `cache/waiting` и `cache/ready`: для каждого id — `created`, `status`.", + "operationId": "sprListQueue", "responses": { "200": { - "description": "Success" + "description": "OK", + "schema": { "$ref": "#/definitions/queueList" } + }, + "500": { + "description": "Ошибка", + "schema": { "$ref": "#/definitions/plainError" } } - }, - "operationId": "get_queue_stt", - "tags": [ - "spr" - ] + } } }, "/spr/queue/{taskID}": { @@ -146,29 +94,130 @@ "name": "taskID", "in": "path", "required": true, - "type": "string" + "type": "string", + "description": "UUID задачи" } ], - "delete": { + "get": { + "tags": ["spr"], + "summary": "Статус задачи", + "description": "При `status=ready` и каталоге в waiting — промоут в ready. `message=Success` когда готово.", + "operationId": "sprGetQueueTask", "responses": { "200": { - "description": "Success", - "schema": { - "type": "object", - "properties": { - "error": { "type": "integer" }, - "message": { "type": "string" } - } - } + "description": "OK", + "schema": { "$ref": "#/definitions/queueStatus" } }, "404": { - "description": "TaskNotFound" + "description": "Задача не найдена или ошибка транскрипции", + "schema": { "$ref": "#/definitions/apiError" } } - }, - "operationId": "delete_queue_del_stt", - "tags": [ - "spr" - ] + } + }, + "delete": { + "tags": ["spr"], + "summary": "Удалить задачу и каталог кэша", + "operationId": "sprDeleteQueueTask", + "responses": { + "200": { + "description": "OK", + "schema": { "$ref": "#/definitions/apiSuccess" } + }, + "404": { + "description": "TaskNotFound", + "schema": { "$ref": "#/definitions/apiError" } + } + } + } + }, + "/spr/stt/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string", + "description": "ID модели (имя файла без .bin, например ggml-small)" + } + ], + "post": { + "tags": ["spr"], + "summary": "Транскрипция аудио", + "description": "Загрузка multipart: одно из полей `audio`, `wav`, `file`. Аудио конвертируется в 16 kHz mono WAV. По умолчанию async (`async=1`) — ответ `taskID`; при `async=0` — синхронный JSON с `text` и `words`.", + "operationId": "sprTranscribe", + "consumes": ["multipart/form-data"], + "parameters": [ + { + "name": "audio", + "in": "formData", + "type": "file", + "description": "Аудиофайл (wav, mp3, flac, ogg, m4a, mp4, aac)" + }, + { + "name": "wav", + "in": "formData", + "type": "file", + "description": "Алиас для audio" + }, + { + "name": "file", + "in": "formData", + "type": "file", + "description": "Алиас для audio" + }, + { + "name": "async", + "in": "query", + "type": "integer", + "enum": [0, 1], + "default": 1, + "description": "1 — поставить в очередь; 0 — синхронный ответ" + }, + { + "name": "language", + "in": "query", + "type": "string", + "description": "Язык распознавания (по умолчанию из config `api.language`, обычно ru). Значение `auto` — автоопределение whisper." + }, + { + "name": "punctuation", + "in": "query", + "type": "integer", + "enum": [0, 1], + "description": "Восстановление пунктуации (если `punctuation.enabled` в config)" + }, + { + "name": "speakers", + "in": "query", + "type": "integer", + "enum": [0, 1], + "description": "1 — диаризация и метки «Спикер N:» (нужен build-sherpa и модели)" + }, + { + "name": "speaker_counter", + "in": "query", + "type": "integer", + "description": "Число спикеров: 0 — авто, -1 — отключить диаризацию для запроса, N>0 — подсказка" + } + ], + "responses": { + "200": { + "description": "Синхронный результат или taskID (async)", + "schema": { "$ref": "#/definitions/sttResponse" } + }, + "400": { + "description": "Нет файла / неверные параметры", + "schema": { "$ref": "#/definitions/apiError" } + }, + "404": { + "description": "Модель не найдена", + "schema": { "$ref": "#/definitions/apiError" } + }, + "405": { + "description": "Ошибка транскрипции (sync)", + "schema": { "$ref": "#/definitions/apiError" } + } + } } }, "/spr/result/{taskID}": { @@ -181,176 +230,42 @@ } ], "get": { + "tags": ["spr"], + "summary": "Результат async-задачи", + "operationId": "sprGetResult", "responses": { - "404": { - "description": "Not found", - "schema": { - "$ref": "#/definitions/error" - } - }, "200": { - "description": "Success", - "schema": { - "$ref": "#/definitions/resultTTS" - } + "description": "waiting/processing/ready", + "schema": { "$ref": "#/definitions/resultResponse" } + }, + "404": { + "description": "TaskNotFound или ошибка", + "schema": { "$ref": "#/definitions/apiError" } } - }, - "operationId": "get_result_stt", - "tags": [ - "spr" - ] + } } }, - "/spr/stt/{id}": { - "post": { + "/spr/audio/{taskID}": { + "parameters": [ + { + "name": "taskID", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "tags": ["spr"], + "summary": "WAV задачи (16 kHz mono)", + "operationId": "sprGetTaskAudio", + "produces": ["audio/wav"], "responses": { - "405": { - "description": "Error", - "schema": { - "$ref": "#/definitions/error" - } - }, + "200": { "description": "Файл audio.wav" }, "404": { - "description": "Not found", - "schema": { - "$ref": "#/definitions/error" - } - }, - "200": { - "description": "Success", - "schema": { - "$ref": "#/definitions/modelTTS" - } + "description": "Задача не найдена", + "schema": { "$ref": "#/definitions/plainError" } } - }, - "operationId": "post_model_test", - "parameters": [ - { - "name": "id", - "in": "path", - "required": true, - "type": "string", - "description": "NN Model ID" - }, - { - "name": "wav", - "in": "formData", - "type": "file", - "description": "file to recognize" - }, - { - "name": "async", - "in": "query", - "type": "integer", - "description": "async mode (default 1: enqueue to cache/waiting; 0: sync text response)", - "default": 1, - "enum": [ - 0, - 1 - ] - }, - { - "name": "speakers", - "in": "query", - "type": "integer", - "description": "find speakers", - "default": 0, - "enum": [ - 0, - 1 - ] - }, - { - "name": "speaker_counter", - "in": "query", - "type": "integer", - "description": "number of speakers. 0 for autodetect. -1 disable speaker detection.", - "default": 0 - }, - { - "name": "normalization", - "in": "query", - "type": "integer", - "description": "normalize text", - "default": 1, - "enum": [ - 0, - 1 - ] - }, - { - "name": "punctuation", - "in": "query", - "type": "integer", - "description": "punctuate text", - "default": 1, - "enum": [ - 0, - 1 - ] - }, - { - "name": "toxicity", - "in": "query", - "type": "integer", - "description": "toxicity analyzer", - "default": 1, - "enum": [ - 0, - 1 - ] - }, - { - "name": "emotion", - "in": "query", - "type": "integer", - "description": "emotion analyzer", - "default": 0, - "enum": [ - 0, - 1 - ] - }, - { - "name": "voice_analyzer", - "in": "query", - "type": "integer", - "description": "voice analyzer", - "default": 1, - "enum": [ - 0, - 1 - ] - }, - { - "name": "vad", - "in": "query", - "type": "string", - "description": "VAD type", - "default": "webrtc", - "enum": [ - "webrtc" - ] - }, - { - "name": "classifiers", - "in": "query", - "type": "string", - "description": "JSON with classification models to process each sentence" - }, - { - "name": "webhook", - "in": "query", - "type": "string", - "description": "webhook url to send stt async result" - } - ], - "consumes": [ - "multipart/form-data" - ], - "tags": [ - "spr" - ] + } } }, "/spr/waveform/{taskID}": { @@ -363,95 +278,396 @@ } ], "get": { + "tags": ["spr"], + "summary": "Пики waveform для UI", + "operationId": "sprGetWaveform", "responses": { "200": { - "description": "Success" + "description": "OK", + "schema": { "$ref": "#/definitions/waveformResponse" } + }, + "400": { + "description": "Неверный taskID", + "schema": { "$ref": "#/definitions/apiError" } + }, + "500": { + "description": "Ошибка чтения audio.json", + "schema": { "$ref": "#/definitions/plainError" } } - }, - "operationId": "get_audioarray_stt", - "tags": [ - "spr" - ] + } + } + }, + "/spr/import/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string", + "description": "ID новой модели (имя файла без расширения)" + } + ], + "post": { + "tags": ["spr"], + "summary": "Загрузить модель .bin", + "operationId": "sprImportModel", + "consumes": ["multipart/form-data"], + "parameters": [ + { + "name": "zip-model", + "in": "formData", + "type": "file", + "required": true, + "description": "Файл модели ggml (*.bin). ZIP не поддерживается." + }, + { + "name": "model", + "in": "formData", + "type": "file", + "description": "Алиас для zip-model" + } + ], + "responses": { + "200": { "description": "Модель сохранена" }, + "400": { + "description": "Нет файла / zip / ошибка записи", + "schema": { "$ref": "#/definitions/apiError" } + } + } + } + }, + "/spr/export/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "tags": ["spr"], + "summary": "Скачать модель .bin", + "operationId": "sprExportModel", + "produces": ["application/octet-stream"], + "responses": { + "200": { "description": "Бинарный файл модели" }, + "404": { + "description": "Модель не найдена", + "schema": { "$ref": "#/definitions/apiError" } + } + } + } + }, + "/spr/delete/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string" + } + ], + "delete": { + "tags": ["spr"], + "summary": "Удалить файл модели", + "operationId": "sprDeleteModel", + "responses": { + "200": { "description": "OK" }, + "404": { + "description": "Модель не найдена", + "schema": { "$ref": "#/definitions/apiError" } + } + } + } + }, + "/v1/models": { + "get": { + "tags": ["openai"], + "summary": "Список моделей (OpenAI format)", + "operationId": "openaiListModels", + "responses": { + "200": { + "description": "OK", + "schema": { "$ref": "#/definitions/openAIModelList" } + }, + "500": { + "description": "server_error", + "schema": { "$ref": "#/definitions/openAIError" } + } + } + } + }, + "/v1/audio/transcriptions": { + "post": { + "tags": ["openai"], + "summary": "Транскрипция (OpenAI Whisper)", + "description": "Всегда синхронно. Поле `model`: id локальной модели или `whisper-1` (маппинг через `api.default_model`). Пунктуация: query `?punctuation=1`.", + "operationId": "openaiTranscribe", + "consumes": ["multipart/form-data"], + "parameters": [ + { + "name": "file", + "in": "formData", + "type": "file", + "required": true, + "description": "Аудио (алиасы: audio, wav)" + }, + { + "name": "model", + "in": "formData", + "type": "string", + "required": true, + "description": "Например whisper-1 или ggml-large-v3-turbo" + }, + { + "name": "language", + "in": "formData", + "type": "string", + "description": "Код языка (ru, en, auto)" + }, + { + "name": "response_format", + "in": "formData", + "type": "string", + "enum": ["json", "text"], + "default": "json", + "description": "json — {\"text\":\"...\"}; text — plain body" + }, + { + "name": "punctuation", + "in": "query", + "type": "integer", + "enum": [0, 1], + "description": "Пунктуация (как в SPR)" + } + ], + "responses": { + "200": { + "description": "Транскрипт", + "schema": { "$ref": "#/definitions/openAITranscription" } + }, + "400": { + "description": "invalid_request_error", + "schema": { "$ref": "#/definitions/openAIError" } + }, + "500": { + "description": "server_error", + "schema": { "$ref": "#/definitions/openAIError" } + } + } + } + }, + "/v1/audio/transcriptions/": { + "post": { + "tags": ["openai"], + "summary": "Транскрипция (trailing slash)", + "operationId": "openaiTranscribeSlash", + "consumes": ["multipart/form-data"], + "parameters": [ + { "$ref": "#/parameters/openAIFile" }, + { "$ref": "#/parameters/openAIModel" } + ], + "responses": { + "200": { + "description": "Транскрипт", + "schema": { "$ref": "#/definitions/openAITranscription" } + }, + "400": { + "schema": { "$ref": "#/definitions/openAIError" } + }, + "500": { + "schema": { "$ref": "#/definitions/openAIError" } + } + } } } }, - "info": { - "title": "go-whisper-api", - "version": "5.008 release", - "description": "Whisper speech-to-text API (SPR-compatible)" - }, - "produces": [ - "application/json" - ], - "consumes": [ - "application/json" - ], - "tags": [ - { - "name": "spr", - "description": "NN Model operations" + "parameters": { + "openAIFile": { + "name": "file", + "in": "formData", + "type": "file", + "required": true + }, + "openAIModel": { + "name": "model", + "in": "formData", + "type": "string", + "required": true } - ], + }, "definitions": { "modelList": { + "type": "object", "properties": { "models": { "type": "array", - "items": { + "items": { "type": "string" }, + "description": "ID моделей" + } + } + }, + "hostnameInfo": { + "type": "object", + "properties": { + "error": { "type": "integer", "example": 0 }, + "message": { "type": "string", "example": "Success" }, + "hostname": { "type": "string" }, + "version": { "type": "string", "example": "go-whisper-api" }, + "cwd": { "type": "string" }, + "models": { "type": "string", "description": "api.models_dir" }, + "cache": { "type": "string", "description": "api.cache_dir" } + } + }, + "queueList": { + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "created": { "type": "string", "description": "2006-01-02 15:04:05" }, + "status": { "type": "string", - "description": "NN Model ID" + "enum": ["waiting", "processing", "ready", "error"] } } }, - "type": "object" + "description": "Ключ — taskID" }, - "error": { + "queueStatus": { + "type": "object", "properties": { - "error": { - "type": "integer", - "description": "Error flag" - }, - "message": { - "type": "string", - "description": "Error description" - } - }, - "type": "object" + "error": { "type": "integer" }, + "message": { "type": "string", "description": "Success, waiting, processing, …" }, + "status": { "type": "string" } + } }, - "modelTTS": { - "required": [ - "text" - ], + "apiSuccess": { + "type": "object", "properties": { - "text": { - "type": "string", - "description": "Recognized text" - } - }, - "type": "object" + "error": { "type": "integer", "example": 0 }, + "message": { "type": "string", "example": "Success" } + } }, - "resultTTS": { + "apiError": { + "type": "object", "properties": { + "error": { "type": "integer", "example": 1 }, + "message": { "type": "string" } + } + }, + "plainError": { + "type": "string", + "description": "Текст ошибки (http.Error)" + }, + "sttResponse": { + "type": "object", + "description": "async: только taskID; sync: model, text, words", + "properties": { + "taskID": { "type": "string" }, "model": { "type": "string" }, "text": { "type": "string" }, - "words": { "type": "array" }, - "toxicity": { "type": "object" }, - "emotion": { "type": "object" }, - "voice_analysis": { "type": "object" }, - "status": { "type": "string" }, + "words": { + "type": "array", + "items": { "$ref": "#/definitions/wordToken" } + } + } + }, + "resultResponse": { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["waiting", "processing", "ready", "error"], + "description": "Краткий статус (не готово)" + }, + "model": { "type": "string" }, + "text": { "type": "string" }, + "words": { + "type": "array", + "items": { "$ref": "#/definitions/wordToken" } + }, "taskID": { "type": "string" }, "created": { "type": "string" }, - "processed": { "type": "string" } - }, - "type": "object" - } - }, - "responses": { - "ParseError": { - "description": "When a mask can't be parsed" + "processed": { "type": "string" }, + "toxicity": { "$ref": "#/definitions/toxicityStub" }, + "emotion": { "type": "object" }, + "voice_analysis": { "type": "object" } + } }, - "MaskError": { - "description": "When any error occurs on mask" + "wordToken": { + "type": "object", + "properties": { + "word": { "type": "string" }, + "start": { "type": "integer", "description": "мс" }, + "stop": { "type": "integer", "description": "мс" } + } + }, + "toxicityStub": { + "type": "object", + "description": "Заглушка SPR (всегда нули)", + "properties": { + "insult": { "type": "number" }, + "obscenity": { "type": "number" }, + "threat": { "type": "number" }, + "politeness": { "type": "number" } + } + }, + "waveformResponse": { + "type": "object", + "properties": { + "error": { "type": "integer", "example": 0 }, + "waveform": { + "type": "array", + "items": { "type": "number" }, + "description": "Пики для отрисовки" + } + } + }, + "openAIModelList": { + "type": "object", + "properties": { + "object": { "type": "string", "example": "list" }, + "data": { + "type": "array", + "items": { "$ref": "#/definitions/openAIModel" } + } + } + }, + "openAIModel": { + "type": "object", + "properties": { + "id": { "type": "string" }, + "object": { "type": "string", "example": "model" }, + "created": { "type": "integer" }, + "owned_by": { "type": "string", "example": "go-whisper-api" } + } + }, + "openAITranscription": { + "type": "object", + "properties": { + "text": { "type": "string" } + } + }, + "openAIError": { + "type": "object", + "properties": { + "error": { + "type": "object", + "properties": { + "message": { "type": "string" }, + "type": { + "type": "string", + "enum": [ + "invalid_request_error", + "authentication_error", + "not_found_error", + "server_error" + ] + } + } + } + } } } -} \ No newline at end of file +} diff --git a/internal/apidoc/swagger_test.go b/internal/apidoc/swagger_test.go new file mode 100644 index 0000000..0cf24f3 --- /dev/null +++ b/internal/apidoc/swagger_test.go @@ -0,0 +1,131 @@ +package apidoc_test + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +func swaggerPath(t *testing.T) string { + t.Helper() + path := filepath.Join("..", "..", "api", "swagger.json") + abs, err := filepath.Abs(path) + if err != nil { + t.Fatal(err) + } + if _, err := os.Stat(abs); err != nil { + t.Fatal(err) + } + return abs +} + +var documentedPaths = []struct { + path string + method string +}{ + {"/swagger.json", "get"}, + {"/spr/models", "get"}, + {"/spr/hostname", "get"}, + {"/spr/queue", "get"}, + {"/spr/queue/{taskID}", "get"}, + {"/spr/queue/{taskID}", "delete"}, + {"/spr/stt/{id}", "post"}, + {"/spr/result/{taskID}", "get"}, + {"/spr/audio/{taskID}", "get"}, + {"/spr/waveform/{taskID}", "get"}, + {"/spr/import/{id}", "post"}, + {"/spr/export/{id}", "get"}, + {"/spr/delete/{id}", "delete"}, + {"/v1/models", "get"}, + {"/v1/audio/transcriptions", "post"}, + {"/v1/audio/transcriptions/", "post"}, +} + +func TestSwaggerJSON_validAndComplete(t *testing.T) { + data, err := os.ReadFile(swaggerPath(t)) + if err != nil { + t.Fatal(err) + } + var spec struct { + Swagger string `json:"swagger"` + Paths map[string]map[string]json.RawMessage `json:"paths"` + Info struct { + Title string `json:"title"` + } `json:"info"` + } + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if spec.Swagger != "2.0" { + t.Fatalf("swagger version: got %q", spec.Swagger) + } + for _, want := range documentedPaths { + methods, ok := spec.Paths[want.path] + if !ok { + t.Errorf("missing path %s", want.path) + continue + } + if _, ok := methods[want.method]; !ok { + t.Errorf("path %s missing method %s", want.path, want.method) + } + } +} + +func TestSwaggerJSON_sttQueryParamsMatchImplementation(t *testing.T) { + data, err := os.ReadFile(swaggerPath(t)) + if err != nil { + t.Fatal(err) + } + var spec struct { + Paths map[string]struct { + Post struct { + Parameters []struct { + Name string `json:"name"` + In string `json:"in"` + } `json:"parameters"` + } `json:"post"` + } `json:"paths"` + } + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatal(err) + } + stt := spec.Paths["/spr/stt/{id}"] + var queryNames []string + for _, p := range stt.Post.Parameters { + if p.In == "query" { + queryNames = append(queryNames, p.Name) + } + } + for _, required := range []string{"async", "language", "punctuation", "speakers", "speaker_counter"} { + if !contains(queryNames, required) { + t.Errorf("STT missing query param %q", required) + } + } + for _, removed := range []string{"webhook", "toxicity", "normalization", "vad", "classifiers"} { + if contains(queryNames, removed) { + t.Errorf("STT documents unused param %q", removed) + } + } +} + +func TestSwaggerJSON_hasOpenAIPaths(t *testing.T) { + data, err := os.ReadFile(swaggerPath(t)) + if err != nil { + t.Fatal(err) + } + s := string(data) + if !strings.Contains(s, "/v1/audio/transcriptions") { + t.Fatal("missing OpenAI transcription path") + } +} + +func contains(ss []string, s string) bool { + for _, x := range ss { + if x == s { + return true + } + } + return false +} diff --git a/punctuation/heuristic.go b/punctuation/heuristic.go index 441b931..8c3a2d7 100644 --- a/punctuation/heuristic.go +++ b/punctuation/heuristic.go @@ -69,14 +69,6 @@ func heuristicEN(s string) string { return s } -func hasTerminalPunct(s string) bool { - s = strings.TrimSpace(s) - if s == "" { - return false - } - r, _ := utf8.DecodeLastRuneInString(s) - return r == '.' || r == '?' || r == '!' || r == '…' -} func ensureTerminalPunct(s string) string { if hasTerminalPunct(s) { diff --git a/punctuation/normalize.go b/punctuation/normalize.go new file mode 100644 index 0000000..27ac926 --- /dev/null +++ b/punctuation/normalize.go @@ -0,0 +1,70 @@ +package punctuation + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +// terminalPunctRunes — знаки, после которых не добавляем ещё одну фразовую точку. +var terminalPunctRunes = map[rune]bool{ + '.': true, '?': true, '!': true, '…': true, + ',': true, ';': true, ':': true, + ')': true, ']': true, '"': true, '\'': true, + '»': true, '”': true, '’': true, + '。': true, ',': true, '?': true, '!': true, +} + +// CleanExcessive collapses duplicate and conflicting punctuation marks. +func CleanExcessive(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return s + } + var b strings.Builder + b.Grow(len(s)) + prevClass := 0 // 0 none, 1 comma-like, 2 end, 3 other punct + for i := 0; i < len(s); { + r, size := utf8.DecodeRuneInString(s[i:]) + cls := punctClass(r) + if cls != 0 && cls == prevClass { + i += size + continue + } + if cls == 2 && prevClass == 1 { + // drop sentence end right after comma-like (e.g. "привет,.") + i += size + continue + } + b.WriteRune(r) + if cls != 0 { + prevClass = cls + } else if !unicode.IsSpace(r) { + prevClass = 0 + } + i += size + } + return strings.TrimSpace(b.String()) +} + +func punctClass(r rune) int { + switch r { + case ',', ',', '、', '،', ';', '؛', ':': + return 1 + case '.', '?', '!', '…', '。', '?', '!': + return 2 + } + if unicode.IsPunct(r) { + return 3 + } + return 0 +} + +func hasTerminalPunct(s string) bool { + s = strings.TrimSpace(s) + if s == "" { + return false + } + r, _ := utf8.DecodeLastRuneInString(s) + return terminalPunctRunes[r] +} diff --git a/punctuation/normalize_test.go b/punctuation/normalize_test.go new file mode 100644 index 0000000..5d93233 --- /dev/null +++ b/punctuation/normalize_test.go @@ -0,0 +1,46 @@ +package punctuation + +import ( + "context" + "strings" + "testing" +) + +func TestCleanExcessive(t *testing.T) { + cases := []struct { + in, want string + }{ + {"привет,,", "привет,"}, + {"привет,.", "привет,"}, + {"hello..", "hello."}, + {"what??", "what?"}, + {"ok!!!", "ok!"}, + {"а. б. в.", "а. б. в."}, + } + for _, tc := range cases { + got := CleanExcessive(tc.in) + if got != tc.want { + t.Errorf("CleanExcessive(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestHasTerminalPunct_comma(t *testing.T) { + if !hasTerminalPunct("привет,") { + t.Fatal("comma should count as terminal for heuristic") + } + if hasTerminalPunct("привет") { + t.Fatal("bare word should not") + } +} + +func TestHeuristic_noCommaPeriod(t *testing.T) { + h := Heuristic{} + out, err := h.Restore(context.Background(), "привет, мир", "ru") + if err != nil { + t.Fatal(err) + } + if strings.Contains(out, ",.") { + t.Fatalf("unexpected comma+period: %q", out) + } +} diff --git a/punctuation/punctuation.go b/punctuation/punctuation.go index 3a995fb..c04d3b7 100644 --- a/punctuation/punctuation.go +++ b/punctuation/punctuation.go @@ -116,7 +116,11 @@ func Apply(ctx context.Context, r Restorer, enabled bool, text, language string) if text == "" { return text, nil } - return r.Restore(ctx, text, language) + out, err := r.Restore(ctx, text, language) + if err != nil { + return "", err + } + return CleanExcessive(out), nil } func Close(r Restorer) { diff --git a/whisper/format.go b/whisper/format.go index 4dfd7f7..22e86b9 100644 --- a/whisper/format.go +++ b/whisper/format.go @@ -135,7 +135,8 @@ func min32(a, b float32) float32 { return b } -// PunctuateSegments runs punctuation on each segment separately (preserves line breaks). +// PunctuateSegments runs punctuation per Whisper segment (legacy helper). +// Prefer punctuating the full transcript after FormatSegments (see Engine.Result). func PunctuateSegments(segments []wpkg.Segment, restore func(text string) (string, error)) ([]wpkg.Segment, error) { out := make([]wpkg.Segment, len(segments)) copy(out, segments) diff --git a/whisper/whisper.go b/whisper/whisper.go index 24c7b9f..97cae5b 100644 --- a/whisper/whisper.go +++ b/whisper/whisper.go @@ -186,13 +186,12 @@ func (e *Engine) SetTranscriptText(text string) { func (e *Engine) Result() TranscriptResult { segments := e.segments - if e.runOpts.PunctuateRestore != nil { - updated, err := PunctuateSegments(segments, e.runOpts.PunctuateRestore) - if err == nil { - segments = updated + text := FormatSegments(segments, e.runOpts.Turns, e.runOpts.Format) + if e.runOpts.PunctuateRestore != nil && text != "" { + if updated, err := e.runOpts.PunctuateRestore(text); err == nil && strings.TrimSpace(updated) != "" { + text = updated } } - text := FormatSegments(segments, e.runOpts.Turns, e.runOpts.Format) var words []Word for _, segment := range segments { words = append(words, segmentWords(segment)...)