Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
93 lines
2.2 KiB
Go
93 lines
2.2 KiB
Go
//go:build xlm
|
|
|
|
package spwrap
|
|
|
|
/*
|
|
#cgo CXXFLAGS: -std=c++17
|
|
#cgo LDFLAGS: -lsentencepiece
|
|
#include <stdlib.h>
|
|
#include "sp_wrap.h"
|
|
*/
|
|
import "C"
|
|
import (
|
|
"fmt"
|
|
"unsafe"
|
|
)
|
|
|
|
type Processor struct {
|
|
p *C.SPProcessor
|
|
}
|
|
|
|
// Load reads the SentencePiece model file at path and returns a Processor
// that owns the native handle produced by the C wrapper's sp_load. Release
// it with Close; no finalizer is registered here, so a Processor that is
// never closed keeps its native resources alive.
//
// Load fails if path contains a NUL byte: the path is handed to C as a
// NUL-terminated string, so the wrapper would otherwise silently open a
// truncated path. When sp_load fails and supplies a message, that message
// is included in the returned error.
func Load(path string) (*Processor, error) {
	for i := 0; i < len(path); i++ {
		if path[i] == 0 {
			return nil, fmt.Errorf("sentencepiece: model path %q contains a NUL byte", path)
		}
	}

	cpath := C.CString(path)
	defer C.free(unsafe.Pointer(cpath))

	var errMsg *C.char
	handle := C.sp_load(cpath, &errMsg)
	// errMsg is caller-owned and released with C.free. Free it on every
	// path so a message set alongside a valid handle cannot leak;
	// C.free(NULL) is a no-op. Deferred calls run after the return values
	// (including C.GoString(errMsg)) have been evaluated.
	defer C.free(unsafe.Pointer(errMsg))

	if handle == nil {
		if errMsg != nil {
			return nil, fmt.Errorf("sentencepiece: %s", C.GoString(errMsg))
		}
		return nil, fmt.Errorf("sentencepiece: failed to load %s", path)
	}
	return &Processor{p: handle}, nil
}
|
|
|
|
func (proc *Processor) Close() {
|
|
if proc.p != nil {
|
|
C.sp_free(proc.p)
|
|
proc.p = nil
|
|
}
|
|
}
|
|
|
|
func (proc *Processor) BOSID() int {
|
|
return int(C.sp_bos_id(proc.p))
|
|
}
|
|
|
|
func (proc *Processor) EOSID() int {
|
|
return int(C.sp_eos_id(proc.p))
|
|
}
|
|
|
|
func (proc *Processor) PadID() int {
|
|
return int(C.sp_pad_id(proc.p))
|
|
}
|
|
|
|
// EncodeAsIDs tokenizes text with the loaded model and returns the token ids.
//
// It returns (nil, nil) when the wrapper reports no tokens. An error is
// returned if the Processor is nil or closed, if text contains a NUL byte
// (sp_encode receives a NUL-terminated string and would silently drop
// everything after it), or if sp_encode reports failure; the wrapper's
// message is included when it supplies one.
func (proc *Processor) EncodeAsIDs(text string) ([]int, error) {
	if proc == nil || proc.p == nil {
		return nil, fmt.Errorf("sentencepiece encode: processor is closed")
	}
	for i := 0; i < len(text); i++ {
		if text[i] == 0 {
			return nil, fmt.Errorf("sentencepiece encode: text contains NUL byte at offset %d", i)
		}
	}

	ctext := C.CString(text)
	defer C.free(unsafe.Pointer(ctext))

	var (
		ids    *C.int
		n      C.int
		errMsg *C.char
	)
	status := C.sp_encode(proc.p, ctext, &ids, &n, &errMsg)
	// errMsg is caller-owned (released with C.free); free it on every path.
	// C.free(NULL) is a no-op. Deferred calls run after return values,
	// including C.GoString(errMsg), are evaluated.
	defer C.free(unsafe.Pointer(errMsg))

	if status == 0 {
		if errMsg != nil {
			return nil, fmt.Errorf("sentencepiece encode: %s", C.GoString(errMsg))
		}
		return nil, fmt.Errorf("sentencepiece encode failed")
	}

	// On success the ids buffer belongs to us even when n is zero; free it
	// before any early return so an empty-but-allocated result cannot leak.
	if ids != nil {
		defer C.free(unsafe.Pointer(ids))
	}
	if ids == nil || n <= 0 {
		return nil, nil
	}

	src := unsafe.Slice(ids, int(n))
	out := make([]int, len(src))
	for i, v := range src {
		out[i] = int(v)
	}
	return out, nil
}
|
|
|
|
// IDToPiece returns the vocabulary piece the loaded model maps token id to.
//
// An error is returned if the Processor is nil or closed, if id does not
// fit in a C int (the conversion would otherwise wrap it to a different,
// possibly valid id and return the wrong piece), or if sp_id_to_piece
// reports failure; the wrapper's message is included when it supplies one.
func (proc *Processor) IDToPiece(id int) (string, error) {
	if proc == nil || proc.p == nil {
		return "", fmt.Errorf("sentencepiece id to piece: processor is closed")
	}
	cid := C.int(id)
	if int(cid) != id {
		return "", fmt.Errorf("sentencepiece id to piece: id %d out of range", id)
	}

	var errMsg *C.char
	piece := C.sp_id_to_piece(proc.p, cid, &errMsg)
	// errMsg is caller-owned (released with C.free); free it on every path.
	// C.free(NULL) is a no-op.
	defer C.free(unsafe.Pointer(errMsg))

	if piece == nil {
		if errMsg != nil {
			return "", fmt.Errorf("sentencepiece id to piece: %s", C.GoString(errMsg))
		}
		return "", fmt.Errorf("sentencepiece id to piece failed")
	}
	defer C.free(unsafe.Pointer(piece))
	return C.GoString(piece), nil
}
|