Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
178 lines
3.6 KiB
Go
178 lines
3.6 KiB
Go
//go:build xlm
|
|
|
|
package punctuation
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
|
|
ort "github.com/yalue/onnxruntime_go"
|
|
)
|
|
|
|
// ortOnce guards the one-time ONNX Runtime environment setup done in ensureORT.
var ortOnce sync.Once

// ortErr holds the result of that setup so every later caller observes it.
var ortErr error
|
|
|
|
func ensureORT() error {
|
|
ortOnce.Do(func() {
|
|
if p := resolveONNXRuntimeLib(); p != "" {
|
|
ort.SetSharedLibraryPath(p)
|
|
}
|
|
ortErr = ort.InitializeEnvironment()
|
|
})
|
|
return ortErr
|
|
}
|
|
|
|
func resolveONNXRuntimeLib() string {
|
|
if p := strings.TrimSpace(os.Getenv("ONNXRUNTIME_SHARED_LIBRARY_PATH")); p != "" {
|
|
return p
|
|
}
|
|
for _, p := range onnxRuntimeCandidates() {
|
|
if st, err := os.Stat(p); err == nil && !st.IsDir() {
|
|
return p
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func onnxRuntimeCandidates() []string {
|
|
arch := sherpaLibArch()
|
|
ver := sherpaLinuxModuleVersion()
|
|
var out []string
|
|
for _, root := range goModCacheRoots() {
|
|
if ver != "" {
|
|
out = append(out, filepath.Join(root,
|
|
"github.com/k2-fsa/sherpa-onnx-go-linux@"+ver,
|
|
"lib", arch, "libonnxruntime.so"))
|
|
}
|
|
}
|
|
if exe, err := os.Executable(); err == nil {
|
|
exeDir := filepath.Dir(exe)
|
|
out = append(out,
|
|
filepath.Join(exeDir, "libonnxruntime.so"),
|
|
filepath.Join(exeDir, "lib", "libonnxruntime.so"),
|
|
filepath.Join(exeDir, "..", "lib", "libonnxruntime.so"),
|
|
)
|
|
if modRoot := findModuleRoot(exeDir); modRoot != "" {
|
|
out = append(out, filepath.Join(modRoot, "lib", "libonnxruntime.so"))
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// goModCacheRoots returns candidate Go module cache directories, in priority
// order:
//
//   - $GOMODCACHE
//   - <entry>/pkg/mod for each entry of $GOPATH
//   - <home>/go/pkg/mod (the default GOPATH location)
//
// Every path is whitespace-trimmed and cleaned before de-duplication, so
// "/x/go/pkg/mod/" and "/x/go/pkg/mod" count as one root. Empty and relative
// values are dropped. The go command rejects relative GOPATH/GOMODCACHE
// entries too, and a relative root would otherwise be resolved against the
// current working directory. For example, an empty GOPATH list entry ("a::b")
// used to turn into the relative path "pkg/mod".
func goModCacheRoots() []string {
	var roots []string
	seen := make(map[string]struct{})

	add := func(p string) {
		p = strings.TrimSpace(p)
		if p == "" {
			return
		}
		p = filepath.Clean(p)
		if !filepath.IsAbs(p) {
			return
		}
		if _, dup := seen[p]; dup {
			return
		}
		seen[p] = struct{}{}
		roots = append(roots, p)
	}

	add(os.Getenv("GOMODCACHE"))

	// filepath.SplitList returns an empty slice for an unset/empty GOPATH.
	for _, entry := range filepath.SplitList(os.Getenv("GOPATH")) {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		add(filepath.Join(entry, "pkg", "mod"))
	}

	if home, err := os.UserHomeDir(); err == nil {
		add(filepath.Join(home, "go", "pkg", "mod"))
	}
	return roots
}
|
|
|
|
// findModuleRoot returns the nearest directory at or above start that contains
// a go.mod file. If the search from start finds nothing, it repeats the search
// from the current working directory. It returns "" when both searches fail or
// the working directory cannot be determined.
func findModuleRoot(start string) string {
	if root := findModuleRootFrom(start); root != "" {
		return root
	}
	if cwd, err := os.Getwd(); err == nil {
		return findModuleRootFrom(cwd)
	}
	return ""
}

// findModuleRootFrom walks upward from dir looking for a go.mod file and
// returns the first directory that has one. It examines at most
// maxModuleRootDepth directories (dir itself plus its nearest ancestors) and
// stops early at the filesystem root. It returns "" if no go.mod is found or
// dir is empty.
func findModuleRootFrom(dir string) string {
	// Bounds the upward walk so a stray go.mod far above the start directory is
	// not picked up by accident.
	const maxModuleRootDepth = 8

	if dir == "" {
		return ""
	}
	for i := 0; i < maxModuleRootDepth; i++ {
		if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
			return dir
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			break // reached the filesystem root
		}
		dir = parent
	}
	return ""
}
|
|
|
|
func sherpaLinuxModuleVersion() string {
|
|
for _, dir := range []string{findModuleRootFrom(mustCwd()), ""} {
|
|
if dir == "" {
|
|
continue
|
|
}
|
|
if v := readSherpaVersion(filepath.Join(dir, "go.mod")); v != "" {
|
|
return v
|
|
}
|
|
}
|
|
if exe, err := os.Executable(); err == nil {
|
|
if root := findModuleRoot(filepath.Dir(exe)); root != "" {
|
|
return readSherpaVersion(filepath.Join(root, "go.mod"))
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// readSherpaVersion returns the version of github.com/k2-fsa/sherpa-onnx-go-linux
// required by the go.mod file at goModPath (e.g. "v1.13.2"). That string must
// match the module-cache directory suffix
// (github.com/k2-fsa/sherpa-onnx-go-linux@v1.13.2). It returns "" if the file
// cannot be read or has no require entry for that module.
//
// Only require directives are considered, in both the single-line form
// ("require <path> <version>") and the block form ("require ( ... )").
// Comments such as "// indirect" are ignored, and the module path must match
// exactly. As a result, replace lines, comment text and modules that merely
// share the prefix are never mistaken for the requirement.
func readSherpaVersion(goModPath string) string {
	const modulePath = "github.com/k2-fsa/sherpa-onnx-go-linux"

	data, err := os.ReadFile(goModPath)
	if err != nil {
		return ""
	}

	inRequireBlock := false
	for _, line := range strings.Split(string(data), "\n") {
		if i := strings.Index(line, "//"); i >= 0 {
			line = line[:i]
		}
		// strings.Fields also discards a trailing '\r' from CRLF files.
		fields := strings.Fields(line)
		if len(fields) == 0 {
			continue
		}

		switch {
		case inRequireBlock:
			if fields[0] == ")" {
				inRequireBlock = false
				continue
			}
		case fields[0] == "require(" ||
			(fields[0] == "require" && len(fields) >= 2 && fields[1] == "("):
			inRequireBlock = true
			continue
		case fields[0] == "require":
			fields = fields[1:] // single-line form: require <path> <version>
		default:
			continue // module, go, toolchain, replace, exclude, retract, ...
		}

		if len(fields) >= 2 && fields[0] == modulePath {
			return fields[1]
		}
	}
	return ""
}
|
|
|
|
// sherpaLibArch maps runtime.GOARCH to the target-triple directory name used
// under lib/ in the sherpa-onnx-go-linux module. Architectures other than arm64
// and arm fall back to the x86_64 triple.
func sherpaLibArch() string {
	triple := "x86_64-unknown-linux-gnu"
	switch runtime.GOARCH {
	case "arm64":
		triple = "aarch64-unknown-linux-gnu"
	case "arm":
		triple = "arm-unknown-linux-gnueabihf"
	}
	return triple
}
|
|
|
|
// mustCwd returns the current working directory, or "" if it cannot be
// determined. Despite the name, it never panics.
func mustCwd() string {
	if cwd, err := os.Getwd(); err == nil {
		return cwd
	}
	return ""
}
|