Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
89 lines
2.1 KiB
C++
89 lines
2.1 KiB
C++
#include "sp_wrap.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include <sentencepiece_processor.h>
|
|
|
|
struct SPProcessor {
|
|
sentencepiece::SentencePieceProcessor proc;
|
|
};
|
|
|
|
// Duplicates `msg` into a freshly malloc'd, NUL-terminated buffer that the
// caller releases with free(). All msg.size() bytes are copied plus the
// terminator, so embedded NUL bytes are preserved. Returns nullptr when the
// allocation fails.
static char *copy_err(const std::string &msg) {
  const std::size_t bytes = msg.size() + 1;  // payload + trailing '\0'
  void *storage = std::malloc(bytes);
  if (storage == nullptr) {
    return nullptr;
  }
  // memcpy hands back its destination pointer.
  return static_cast<char *>(std::memcpy(storage, msg.c_str(), bytes));
}
|
|
|
|
// Loads the SentencePiece model stored at `path` and returns a new handle.
// The caller owns the handle and releases it with sp_free.
//
// On failure returns nullptr. If `err` is non-null, *err receives a
// malloc-allocated message, or nullptr if even that allocation failed.
// *err is always reset to nullptr first, so it stays nullptr on success.
// NOTE(review): the caller presumably frees *err with free() — confirm
// against sp_wrap.h and its consumers.
//
// C++ exceptions are caught and reported through `err`, because this
// function is meant to be called through a C interface, where unwinding
// is not defined.
SPProcessor *sp_load(const char *path, char **err) {
  if (err != nullptr) {
    *err = nullptr;
  }

  // Record `msg` for the caller (when it asked for errors) and signal failure.
  const auto fail = [err](const std::string &msg) -> SPProcessor * {
    if (err != nullptr) {
      *err = copy_err(msg);
    }
    return nullptr;
  };

  if (path == nullptr) {
    // Building the model loader's string view from a null pointer is UB.
    return fail("sp_load: path is null");
  }

  try {
    // unique_ptr frees the handle on every early exit, including a throw
    // from Load().
    auto handle = std::make_unique<SPProcessor>();
    const auto status = handle->proc.Load(path);
    if (!status.ok()) {
      return fail(status.ToString());
    }
    return handle.release();  // ownership passes to the caller
  } catch (const std::exception &e) {
    return fail(e.what());
  } catch (...) {
    return fail("sp_load: unknown C++ exception");
  }
}
|
|
|
|
// Destroys a handle previously returned by sp_load. A null handle is ignored.
void sp_free(SPProcessor *p) {
  if (p == nullptr) {
    return;
  }
  delete p;
}
|
|
|
|
// Returns the beginning-of-sentence id reported by the loaded model.
// `p` must be a live handle from sp_load; it is not null-checked.
int sp_bos_id(const SPProcessor *p) {
  const auto &model = p->proc;
  return model.bos_id();
}
|
|
|
|
// Returns the end-of-sentence id reported by the loaded model.
// `p` must be a live handle from sp_load; it is not null-checked.
int sp_eos_id(const SPProcessor *p) {
  const auto &model = p->proc;
  return model.eos_id();
}
|
|
|
|
// Returns the padding id reported by the loaded model.
// `p` must be a live handle from sp_load; it is not null-checked.
int sp_pad_id(const SPProcessor *p) {
  const auto &model = p->proc;
  return model.pad_id();
}
|
|
|
|
int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err) {
|
|
if (err != nullptr) {
|
|
*err = nullptr;
|
|
}
|
|
if (out_ids != nullptr) {
|
|
*out_ids = nullptr;
|
|
}
|
|
if (out_len != nullptr) {
|
|
*out_len = 0;
|
|
}
|
|
std::vector<int> ids;
|
|
const auto status = p->proc.Encode(text, &ids);
|
|
if (!status.ok()) {
|
|
if (err != nullptr) {
|
|
*err = copy_err(status.ToString());
|
|
}
|
|
return 0;
|
|
}
|
|
if (ids.empty()) {
|
|
return 1;
|
|
}
|
|
int *buf = static_cast<int *>(std::malloc(sizeof(int) * ids.size()));
|
|
if (buf == nullptr) {
|
|
if (err != nullptr) {
|
|
*err = copy_err("malloc failed");
|
|
}
|
|
return 0;
|
|
}
|
|
for (size_t i = 0; i < ids.size(); i++) {
|
|
buf[i] = ids[i];
|
|
}
|
|
*out_ids = buf;
|
|
*out_len = static_cast<int>(ids.size());
|
|
return 1;
|
|
}
|
|
|
|
char *sp_id_to_piece(const SPProcessor *p, int id, char **err) {
|
|
if (err != nullptr) {
|
|
*err = nullptr;
|
|
}
|
|
const std::string piece = p->proc.IdToPiece(id);
|
|
return copy_err(piece);
|
|
}
|