admin b5c083e06f
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
first commit
2026-06-04 18:10:52 +07:00

89 lines
2.1 KiB
C++

#include "sp_wrap.h"

#include <cstdlib>
#include <cstring>
#include <exception>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include <sentencepiece_processor.h>
// Handle type passed across the C API: each handle owns exactly one
// SentencePiece processor. Its lifetime is managed by sp_load / sp_free.
// NOTE(review): presumably declared opaquely in sp_wrap.h — keep the
// class-key `struct` in sync with that declaration.
struct SPProcessor {
  // The wrapped processor; every exported function forwards to it.
  sentencepiece::SentencePieceProcessor proc{};
};
// Duplicates `msg` into a NUL-terminated buffer obtained from std::malloc.
// Despite the name, it is used for any string handed back across the C API
// (error messages and token pieces alike). The caller owns the returned
// buffer and must release it with free(). Returns nullptr if the allocation
// fails. Embedded NUL bytes in `msg` are copied verbatim.
static char *copy_err(const std::string &msg) {
  const std::size_t len = msg.size();
  char *const dup = static_cast<char *>(std::malloc(len + 1));
  if (dup == nullptr) {
    return nullptr;
  }
  std::memcpy(dup, msg.data(), len);
  dup[len] = '\0';
  return dup;
}
// Creates a processor handle and loads the SentencePiece model at `path`.
//
// Returns a new handle (release it with sp_free) on success. On failure it
// returns nullptr and, when `err` is non-null, stores a malloc'd error
// message in *err; the caller must free() that message. *err is set to
// nullptr on entry so callers can rely on it being initialised.
//
// This function is exported through a C ABI, so no C++ exception may escape
// it. Allocation failures inside construction or Load are caught and
// reported through `err` instead.
SPProcessor *sp_load(const char *path, char **err) {
  if (err != nullptr) {
    *err = nullptr;
  }
  if (path == nullptr) {
    // Load() would build a string view from a null pointer (undefined behaviour).
    if (err != nullptr) {
      *err = copy_err("sp_load: path is null");
    }
    return nullptr;
  }
  try {
    // Owned by the unique_ptr until Load succeeds, so every early exit frees it.
    std::unique_ptr<SPProcessor> p(new SPProcessor());
    const auto status = p->proc.Load(path);
    if (!status.ok()) {
      if (err != nullptr) {
        *err = copy_err(status.ToString());
      }
      return nullptr;
    }
    return p.release();
  } catch (const std::exception &e) {
    if (err != nullptr) {
      *err = copy_err(e.what());
    }
    return nullptr;
  } catch (...) {
    if (err != nullptr) {
      *err = copy_err("sp_load: unknown exception");
    }
    return nullptr;
  }
}
// Destroys a handle previously returned by sp_load. Passing nullptr is a
// harmless no-op.
void sp_free(SPProcessor *p) {
  if (p == nullptr) {
    return;
  }
  delete p;
}
// Returns the beginning-of-sentence id reported by the loaded model.
// `p` must be a live handle from sp_load; it is not null-checked here.
int sp_bos_id(const SPProcessor *p) {
  const auto &model = p->proc;
  return model.bos_id();
}
// Returns the end-of-sentence id reported by the loaded model.
// `p` must be a live handle from sp_load; it is not null-checked here.
int sp_eos_id(const SPProcessor *p) {
  const auto &model = p->proc;
  return model.eos_id();
}
// Returns the padding id reported by the loaded model.
// `p` must be a live handle from sp_load; it is not null-checked here.
int sp_pad_id(const SPProcessor *p) {
  const auto &model = p->proc;
  return model.pad_id();
}
// Encodes the NUL-terminated `text` into token ids.
//
// On success, returns 1 and stores a malloc'd array of *out_len ids in
// *out_ids; the caller must free() it. An empty encoding also returns 1,
// with *out_ids == nullptr and *out_len == 0.
//
// On failure, returns 0 and, when `err` is non-null, stores a malloc'd error
// message in *err (free() it). Failures include: null arguments, an Encode
// error, a result too large for an int length, allocation failure, or an
// exception raised inside SentencePiece/the STL.
//
// *err, *out_ids and *out_len are reset on entry whenever their pointers are
// non-null. No C++ exception escapes this C-ABI function.
int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err) {
  if (err != nullptr) {
    *err = nullptr;
  }
  if (out_ids != nullptr) {
    *out_ids = nullptr;
  }
  if (out_len != nullptr) {
    *out_len = 0;
  }
  // The outputs are written unconditionally on success, so they are mandatory.
  if (p == nullptr || text == nullptr || out_ids == nullptr || out_len == nullptr) {
    if (err != nullptr) {
      *err = copy_err("sp_encode: null argument");
    }
    return 0;
  }
  try {
    std::vector<int> ids;
    const auto status = p->proc.Encode(text, &ids);
    if (!status.ok()) {
      if (err != nullptr) {
        *err = copy_err(status.ToString());
      }
      return 0;
    }
    if (ids.empty()) {
      return 1;
    }
    // *out_len is an int; refuse rather than silently truncate the count.
    if (ids.size() > static_cast<std::size_t>(std::numeric_limits<int>::max())) {
      if (err != nullptr) {
        *err = copy_err("sp_encode: too many ids");
      }
      return 0;
    }
    const std::size_t bytes = sizeof(int) * ids.size();
    int *buf = static_cast<int *>(std::malloc(bytes));
    if (buf == nullptr) {
      if (err != nullptr) {
        *err = copy_err("malloc failed");
      }
      return 0;
    }
    std::memcpy(buf, ids.data(), bytes);
    *out_ids = buf;
    *out_len = static_cast<int>(ids.size());
    return 1;
  } catch (const std::exception &e) {
    if (err != nullptr) {
      *err = copy_err(e.what());
    }
    return 0;
  } catch (...) {
    if (err != nullptr) {
      *err = copy_err("sp_encode: unknown exception");
    }
    return 0;
  }
}
// Returns the surface piece for token `id` as a malloc'd, NUL-terminated
// string; the caller must free() it.
//
// Returns nullptr on failure and, when `err` is non-null, stores a malloc'd
// message in *err (free() it). Failures are: a null handle, an id outside
// [0, piece count), or allocation failure. *err is reset to nullptr on entry.
char *sp_id_to_piece(const SPProcessor *p, int id, char **err) {
  if (err != nullptr) {
    *err = nullptr;
  }
  if (p == nullptr) {
    if (err != nullptr) {
      *err = copy_err("sp_id_to_piece: null processor");
    }
    return nullptr;
  }
  // SentencePiece's IdToPiece indexes the model's piece list without a range
  // check, so an out-of-range id would be undefined behaviour; reject it here.
  if (id < 0 || id >= p->proc.GetPieceSize()) {
    if (err != nullptr) {
      *err = copy_err("sp_id_to_piece: id out of range");
    }
    return nullptr;
  }
  // Bind by reference: no need for an intermediate std::string copy.
  const std::string &piece = p->proc.IdToPiece(id);
  char *out = copy_err(piece);
  if (out == nullptr && err != nullptr) {
    // Report allocation failure the same way sp_encode does.
    *err = copy_err("malloc failed");
  }
  return out;
}