2025-10-23 13:06:22 +07:00

447 lines
11 KiB
Go

package wav
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"time"
"rtp-app/pkg/audio"
"rtp-app/pkg/riff"
)
var (
CIDList = [4]byte{'L', 'I', 'S', 'T'}
CIDSmpl = [4]byte{'s', 'm', 'p', 'l'}
CIDInfo = []byte{'I', 'N', 'F', 'O'}
CIDCue = [4]byte{'c', 'u', 'e', 0x20}
)
type Decoder struct {
r io.ReadSeeker
parser *riff.Parser
NumChans uint16
BitDepth uint16
SampleRate uint32
AvgBytesPerSec uint32
WavAudioFormat uint16
err error
PCMSize int
pcmDataAccessed bool
PCMChunk *riff.Chunk
Metadata *Metadata
}
func NewDecoder(r io.ReadSeeker) *Decoder {
return &Decoder{r: r, parser: riff.New(r)}
}
func (d *Decoder) Seek(offset int64, whence int) (int64, error) {
return d.r.Seek(offset, whence)
}
func (d *Decoder) Rewind() error {
_, err := d.r.Seek(0, io.SeekStart)
if err != nil {
return fmt.Errorf("failed to seek back to the start %w", err)
}
d.parser = riff.New(d.r)
d.pcmDataAccessed = false
d.PCMChunk = nil
d.err = nil
d.NumChans = 0
err = d.FwdToPCM()
if err != nil {
return fmt.Errorf("failed to seek to the PCM data: %w", err)
}
return nil
}
func (d *Decoder) SampleBitDepth() int32 {
if d == nil {
return 0
}
return int32(d.BitDepth)
}
func (d *Decoder) PCMLen() int64 {
if d == nil {
return 0
}
return int64(d.PCMSize)
}
func (d *Decoder) Err() error {
if d.err == io.EOF {
return nil
}
return d.err
}
func (d *Decoder) EOF() bool {
if d == nil || d.err == io.EOF {
return true
}
return false
}
func (d *Decoder) IsValidFile() bool {
d.err = d.readHeaders()
if d.err != nil {
return false
}
if d.NumChans < 1 {
return false
}
if d.BitDepth < 8 {
return false
}
if d, err := d.Duration(); err != nil || d <= 0 {
return false
}
return true
}
func (d *Decoder) ReadInfo() {
d.err = d.readHeaders()
}
func (d *Decoder) ReadMetadata() {
if d.Metadata != nil {
return
}
d.ReadInfo()
if d.Err() != nil || d.Metadata != nil {
return
}
var (
chunk *riff.Chunk
err error
)
for err == nil {
chunk, err = d.parser.NextChunk()
if err != nil {
break
}
switch chunk.ID {
case CIDList:
if err = DecodeListChunk(d, chunk); err != nil {
if err != io.EOF {
d.err = err
}
}
if d.Metadata != nil && d.Metadata.SamplerInfo != nil {
break
}
case CIDSmpl:
if err = DecodeSamplerChunk(d, chunk); err != nil {
if err != io.EOF {
d.err = err
}
}
case CIDCue:
if err = DecodeCueChunk(d, chunk); err != nil {
if err != io.EOF {
d.err = err
}
}
default:
chunk.Drain()
}
}
}
func (d *Decoder) FwdToPCM() error {
if d == nil {
return fmt.Errorf("Данные PCM не найдены")
}
d.err = d.readHeaders()
if d.err != nil {
return nil
}
var chunk *riff.Chunk
for d.err == nil {
chunk, d.err = d.NextChunk()
if d.err != nil {
return d.err
}
if chunk.ID == riff.DataFormatID {
d.PCMSize = chunk.Size
d.PCMChunk = chunk
break
}
if chunk.ID == CIDList {
DecodeListChunk(d, chunk)
}
chunk.Drain()
}
if chunk == nil {
return fmt.Errorf("Данные PCM не найдены")
}
d.pcmDataAccessed = true
return nil
}
func (d *Decoder) WasPCMAccessed() bool {
if d == nil {
return false
}
return d.pcmDataAccessed
}
func (d *Decoder) FullPCMBuffer() (*audio.IntBuffer, error) {
if !d.WasPCMAccessed() {
err := d.FwdToPCM()
if err != nil {
return nil, d.err
}
}
if d.PCMChunk == nil {
return nil, errors.New("Фрагмент PCM не найден")
}
format := &audio.Format{
NumChannels: int(d.NumChans),
SampleRate: int(d.SampleRate),
}
buf := &audio.IntBuffer{Data: make([]int, 4096), Format: format, SourceBitDepth: int(d.BitDepth)}
bytesPerSample := (d.BitDepth-1)/8 + 1
sampleBufData := make([]byte, bytesPerSample)
decodeF, err := sampleDecodeFunc(int(d.BitDepth))
if err != nil {
return nil, fmt.Errorf("не удалось получить функцию декодирования образца %v", err)
}
i := 0
for err == nil {
buf.Data[i], err = decodeF(d.PCMChunk, sampleBufData)
if err != nil {
break
}
i++
if i == len(buf.Data) {
buf.Data = append(buf.Data, make([]int, 4096)...)
}
}
buf.Data = buf.Data[:i]
if err == io.EOF {
err = nil
}
return buf, err
}
func (d *Decoder) PCMBuffer(buf *audio.IntBuffer) (n int, err error) {
if buf == nil {
return 0, nil
}
if !d.pcmDataAccessed {
err := d.FwdToPCM()
if err != nil {
return 0, d.err
}
}
if d.PCMChunk == nil {
return 0, ErrPCMChunkNotFound
}
format := &audio.Format{
NumChannels: int(d.NumChans),
SampleRate: int(d.SampleRate),
}
buf.SourceBitDepth = int(d.BitDepth)
decodeF, err := sampleDecodeFunc(int(d.BitDepth))
if err != nil {
return 0, fmt.Errorf("не удалось получить функцию декодирования образца %v", err)
}
bPerSample := bytesPerSample(int(d.BitDepth))
size := len(buf.Data) * bPerSample
tmpBuf := make([]byte, size)
var m int
m, err = d.PCMChunk.R.Read(tmpBuf)
if err != nil {
if err == io.EOF {
return m, nil
}
return m, err
}
if m == 0 {
return m, nil
}
bufR := bytes.NewReader(tmpBuf[:m])
sampleBuf := make([]byte, bPerSample, bPerSample)
var misaligned bool
if m%bPerSample > 0 {
misaligned = true
}
for n = 0; n < len(buf.Data); n++ {
buf.Data[n], err = decodeF(bufR, sampleBuf)
if err != nil {
if misaligned {
n--
}
break
}
}
buf.Format = format
if err == io.EOF {
err = nil
}
return n, err
}
func (d *Decoder) Format() *audio.Format {
if d == nil {
return nil
}
return &audio.Format{
NumChannels: int(d.NumChans),
SampleRate: int(d.SampleRate),
}
}
func (d *Decoder) NextChunk() (*riff.Chunk, error) {
if d.err = d.readHeaders(); d.err != nil {
d.err = fmt.Errorf("не удалось прочитать заголовок - %v", d.err)
return nil, d.err
}
var (
id [4]byte
size uint32
)
id, size, d.err = d.parser.IDnSize()
if d.err != nil {
d.err = fmt.Errorf("ошибка чтения заголовка фрагмента - %v", d.err)
return nil, d.err
}
if size%2 == 1 {
size++
}
c := &riff.Chunk{
ID: id,
Size: int(size),
R: io.LimitReader(d.r, int64(size)),
}
return c, d.err
}
func (d *Decoder) Duration() (time.Duration, error) {
if d == nil || d.parser == nil {
return 0, errors.New("не могу рассчитать продолжительность pointer=nil")
}
return d.parser.Duration()
}
func (d *Decoder) String() string {
return d.parser.String()
}
func (d *Decoder) readHeaders() error {
if d == nil || d.NumChans > 0 {
return nil
}
id, size, err := d.parser.IDnSize()
if err != nil {
return err
}
d.parser.ID = id
if d.parser.ID != riff.RiffID {
return fmt.Errorf("%s - %s", d.parser.ID, riff.ErrFmtNotSupported)
}
d.parser.Size = size
if err := binary.Read(d.r, binary.BigEndian, &d.parser.Format); err != nil {
return err
}
var chunk *riff.Chunk
var rewindBytes int64
for err == nil {
chunk, err = d.parser.NextChunk()
if err != nil {
break
}
if chunk.ID == riff.FmtID {
chunk.DecodeWavHeader(d.parser)
d.NumChans = d.parser.NumChannels
d.BitDepth = d.parser.BitsPerSample
d.SampleRate = d.parser.SampleRate
d.WavAudioFormat = d.parser.WavAudioFormat
d.AvgBytesPerSec = d.parser.AvgBytesPerSec
if rewindBytes > 0 {
d.r.Seek(-(rewindBytes + int64(chunk.Size) + 8), 1)
}
break
} else if chunk.ID == CIDList {
DecodeListChunk(d, chunk)
rewindBytes += int64(chunk.Size) + 8
} else if chunk.ID == CIDSmpl {
DecodeSamplerChunk(d, chunk)
rewindBytes += int64(chunk.Size) + 8
} else {
rewindBytes += int64(chunk.Size) + 8
io.CopyN(ioutil.Discard, d.r, int64(chunk.Size))
}
}
return d.err
}
func bytesPerSample(bitDepth int) int {
return bitDepth / 8
}
func sampleDecodeFunc(bitsPerSample int) (func(io.Reader, []byte) (int, error), error) {
switch bitsPerSample {
case 8:
return func(r io.Reader, buf []byte) (int, error) {
_, err := r.Read(buf[:1])
return int(buf[0]), err
}, nil
case 16:
return func(r io.Reader, buf []byte) (int, error) {
_, err := r.Read(buf[:2])
return int(int16(binary.LittleEndian.Uint16(buf[:2]))), err
}, nil
case 24:
return func(r io.Reader, buf []byte) (int, error) {
_, err := r.Read(buf[:3])
if err != nil {
return 0, err
}
return int(audio.Int24LETo32(buf[:3])), nil
}, nil
case 32:
return func(r io.Reader, buf []byte) (int, error) {
_, err := r.Read(buf[:4])
return int(int32(binary.LittleEndian.Uint32(buf[:4]))), err
}, nil
default:
return nil, fmt.Errorf("необработываемый битрейт:%d", bitsPerSample)
}
}
func sampleFloat64DecodeFunc(bitsPerSample int) (func([]byte) float64, error) {
bytesPerSample := bitsPerSample / 8
switch bytesPerSample {
case 1:
return func(s []byte) float64 {
return float64(uint8(s[0]))
}, nil
case 2:
return func(s []byte) float64 {
return float64(int(s[0]) + int(s[1])<<8)
}, nil
case 3:
return func(s []byte) float64 {
var output int32
output |= int32(s[2]) << 0
output |= int32(s[1]) << 8
output |= int32(s[0]) << 16
return float64(output)
}, nil
case 4:
return func(s []byte) float64 {
return float64(int(s[0]) + int(s[1])<<8 + int(s[2])<<16 + int(s[3])<<24)
}, nil
default:
return nil, fmt.Errorf("необработываемый битрейт:%d", bitsPerSample)
}
}