720 lines
22 KiB
C++
720 lines
22 KiB
C++
#include "binding.h"
|
|
|
|
#include "common.h"
|
|
#include "llama.h"
|
|
#include "sampling.h"
|
|
|
|
#include <algorithm>
|
|
#include <cstdio>
|
|
#include <cstring>
|
|
#include <fstream>
|
|
#include <regex>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
struct llama_binding_state {
|
|
common_init_result_ptr init;
|
|
llama_model * model = nullptr;
|
|
llama_context * ctx = nullptr;
|
|
common_sampler * smpl = nullptr;
|
|
bool embeddings = false;
|
|
};
|
|
|
|
static llama_binding_state * binding_state(void * state_pr) {
|
|
return static_cast<llama_binding_state *>(state_pr);
|
|
}
|
|
|
|
static void parse_tensor_split(const char * tensorsplit, float * out, size_t n) {
|
|
for (size_t i = 0; i < n; ++i) {
|
|
out[i] = 0.0f;
|
|
}
|
|
if (tensorsplit == nullptr || tensorsplit[0] == '\0') {
|
|
return;
|
|
}
|
|
std::string arg_next = tensorsplit;
|
|
const std::regex regex{R"([,/]+)"};
|
|
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
|
|
std::vector<std::string> split_arg{it, {}};
|
|
for (size_t i = 0; i < split_arg.size() && i < n; ++i) {
|
|
out[i] = std::stof(split_arg[i]);
|
|
}
|
|
}
|
|
|
|
static void apply_model_load_options(
|
|
common_params & params,
|
|
int n_ctx,
|
|
int n_seed,
|
|
bool memory_f16,
|
|
bool mlock,
|
|
bool embeddings,
|
|
bool mmap,
|
|
int n_gpu,
|
|
int n_batch,
|
|
const char * maingpu,
|
|
const char * tensorsplit,
|
|
bool numa,
|
|
float rope_freq_base,
|
|
float rope_freq_scale,
|
|
const char * lora,
|
|
const char * lora_base,
|
|
bool perplexity) {
|
|
(void) lora_base;
|
|
|
|
if (n_ctx > 0) {
|
|
params.n_ctx = n_ctx;
|
|
}
|
|
if (n_seed >= 0) {
|
|
params.sampling.seed = (uint32_t) n_seed;
|
|
}
|
|
params.use_mlock = mlock;
|
|
params.embedding = embeddings;
|
|
params.use_mmap = mmap;
|
|
params.n_gpu_layers = n_gpu;
|
|
params.n_batch = n_batch > 0 ? n_batch : params.n_batch;
|
|
params.n_ubatch = std::min(params.n_batch, params.n_ubatch);
|
|
params.numa = numa ? GGML_NUMA_STRATEGY_DISTRIBUTE : GGML_NUMA_STRATEGY_DISABLED;
|
|
params.warmup = false;
|
|
params.fit_params = false;
|
|
|
|
if (rope_freq_base > 0.0f) {
|
|
params.rope_freq_base = rope_freq_base;
|
|
}
|
|
if (rope_freq_scale > 0.0f) {
|
|
params.rope_freq_scale = rope_freq_scale;
|
|
}
|
|
|
|
if (memory_f16) {
|
|
params.cache_type_k = GGML_TYPE_F16;
|
|
params.cache_type_v = GGML_TYPE_F16;
|
|
}
|
|
|
|
if (maingpu != nullptr && maingpu[0] != '\0') {
|
|
params.main_gpu = std::stoi(maingpu);
|
|
}
|
|
|
|
parse_tensor_split(tensorsplit, params.tensor_split, sizeof(params.tensor_split) / sizeof(params.tensor_split[0]));
|
|
|
|
if (perplexity) {
|
|
params.compute_ppl = true;
|
|
}
|
|
|
|
if (lora != nullptr && lora[0] != '\0') {
|
|
common_adapter_lora_info la;
|
|
la.path = lora;
|
|
la.scale = 1.0f;
|
|
params.lora_adapters.push_back(la);
|
|
}
|
|
}
|
|
|
|
static bool check_antiprompt(
|
|
const std::string & output,
|
|
const std::vector<std::string> & antiprompt,
|
|
bool interactive) {
|
|
for (const auto & ap : antiprompt) {
|
|
if (ap.empty()) {
|
|
continue;
|
|
}
|
|
const size_t extra = interactive ? 0 : 2;
|
|
const size_t search_start = output.length() > ap.length() + extra
|
|
? output.length() - ap.length() - extra
|
|
: 0;
|
|
if (output.find(ap, search_start) != std::string::npos) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
extern "C" {
|
|
|
|
void * load_model(
|
|
const char * fname,
|
|
int n_ctx,
|
|
int n_seed,
|
|
bool memory_f16,
|
|
bool mlock,
|
|
bool embeddings,
|
|
bool mmap,
|
|
bool low_vram,
|
|
int n_gpu,
|
|
int n_batch,
|
|
const char * maingpu,
|
|
const char * tensorsplit,
|
|
bool numa,
|
|
float rope_freq_base,
|
|
float rope_freq_scale,
|
|
bool mul_mat_q,
|
|
const char * lora,
|
|
const char * lora_base,
|
|
bool perplexity) {
|
|
(void) low_vram;
|
|
(void) mul_mat_q;
|
|
|
|
common_init();
|
|
llama_backend_init();
|
|
|
|
common_params params;
|
|
params.model.path = fname;
|
|
|
|
apply_model_load_options(
|
|
params, n_ctx, n_seed, memory_f16, mlock, embeddings, mmap,
|
|
n_gpu, n_batch, maingpu, tensorsplit, numa,
|
|
rope_freq_base, rope_freq_scale, lora, lora_base, perplexity);
|
|
|
|
llama_numa_init(params.numa);
|
|
|
|
auto * binding = new llama_binding_state();
|
|
binding->init = common_init_from_params(params);
|
|
if (!binding->init || binding->init->context() == nullptr) {
|
|
delete binding;
|
|
return nullptr;
|
|
}
|
|
|
|
binding->model = binding->init->model();
|
|
binding->ctx = binding->init->context();
|
|
binding->smpl = binding->init->sampler(0);
|
|
binding->embeddings = embeddings;
|
|
|
|
return binding;
|
|
}
|
|
|
|
void llama_binding_free_model(void * state_pr) {
|
|
delete binding_state(state_pr);
|
|
}
|
|
|
|
int load_state(void * state_pr, char * statefile, char * modes) {
|
|
(void) modes;
|
|
auto * state = binding_state(state_pr);
|
|
if (state == nullptr || state->ctx == nullptr) {
|
|
return 1;
|
|
}
|
|
|
|
std::vector<llama_token> tokens(llama_n_ctx(state->ctx));
|
|
size_t n_out = 0;
|
|
if (!llama_state_load_file(state->ctx, statefile, tokens.data(), tokens.size(), &n_out)) {
|
|
return 1;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
void save_state(void * state_pr, char * dst, char * modes) {
|
|
(void) modes;
|
|
auto * state = binding_state(state_pr);
|
|
if (state == nullptr || state->ctx == nullptr) {
|
|
return;
|
|
}
|
|
llama_state_save_file(state->ctx, dst, nullptr, 0);
|
|
}
|
|
|
|
void * llama_allocate_params(
|
|
const char * prompt,
|
|
int seed,
|
|
int threads,
|
|
int tokens,
|
|
int top_k,
|
|
float top_p,
|
|
float temp,
|
|
float repeat_penalty,
|
|
int repeat_last_n,
|
|
bool ignore_eos,
|
|
bool memory_f16,
|
|
int n_batch,
|
|
int n_keep,
|
|
const char ** antiprompt,
|
|
int antiprompt_count,
|
|
float tfs_z,
|
|
float typical_p,
|
|
float frequency_penalty,
|
|
float presence_penalty,
|
|
int mirostat,
|
|
float mirostat_eta,
|
|
float mirostat_tau,
|
|
bool penalize_nl,
|
|
const char * logit_bias,
|
|
const char * session_file,
|
|
bool prompt_cache_all,
|
|
bool mlock,
|
|
bool mmap,
|
|
const char * maingpu,
|
|
const char * tensorsplit,
|
|
bool prompt_cache_ro,
|
|
const char * grammar,
|
|
float rope_freq_base,
|
|
float rope_freq_scale,
|
|
float negative_prompt_scale,
|
|
const char * negative_prompt,
|
|
int n_draft) {
|
|
(void) tfs_z;
|
|
(void) penalize_nl;
|
|
(void) negative_prompt_scale;
|
|
(void) negative_prompt;
|
|
(void) memory_f16;
|
|
|
|
auto * params = new common_params();
|
|
params->prompt = prompt != nullptr ? prompt : "";
|
|
params->n_predict = tokens;
|
|
params->n_batch = n_batch > 0 ? n_batch : params->n_batch;
|
|
params->n_keep = n_keep;
|
|
params->use_mlock = mlock;
|
|
params->use_mmap = mmap;
|
|
params->path_prompt_cache = session_file != nullptr ? session_file : "";
|
|
params->prompt_cache_all = prompt_cache_all;
|
|
params->prompt_cache_ro = prompt_cache_ro;
|
|
|
|
if (rope_freq_base > 0.0f) {
|
|
params->rope_freq_base = rope_freq_base;
|
|
}
|
|
if (rope_freq_scale > 0.0f) {
|
|
params->rope_freq_scale = rope_freq_scale;
|
|
}
|
|
|
|
params->sampling.seed = seed >= 0 ? (uint32_t) seed : LLAMA_DEFAULT_SEED;
|
|
params->cpuparams.n_threads = threads > 0 ? threads : 4;
|
|
params->cpuparams_batch.n_threads = params->cpuparams.n_threads;
|
|
params->sampling.top_k = top_k;
|
|
params->sampling.top_p = top_p;
|
|
params->sampling.temp = temp;
|
|
params->sampling.penalty_repeat = repeat_penalty;
|
|
params->sampling.penalty_last_n = repeat_last_n;
|
|
params->sampling.penalty_freq = frequency_penalty;
|
|
params->sampling.penalty_present = presence_penalty;
|
|
params->sampling.typ_p = typical_p > 0 ? typical_p : 1.0f;
|
|
params->sampling.mirostat = mirostat;
|
|
params->sampling.mirostat_eta = mirostat_eta;
|
|
params->sampling.mirostat_tau = mirostat_tau;
|
|
params->sampling.ignore_eos = ignore_eos;
|
|
|
|
if (grammar != nullptr && grammar[0] != '\0') {
|
|
params->sampling.grammar = common_grammar(COMMON_GRAMMAR_TYPE_USER, grammar);
|
|
}
|
|
|
|
if (maingpu != nullptr && maingpu[0] != '\0') {
|
|
params->main_gpu = std::stoi(maingpu);
|
|
}
|
|
parse_tensor_split(tensorsplit, params->tensor_split, sizeof(params->tensor_split) / sizeof(params->tensor_split[0]));
|
|
|
|
if (antiprompt_count > 0 && antiprompt != nullptr) {
|
|
params->antiprompt = create_vector(antiprompt, antiprompt_count);
|
|
}
|
|
|
|
if (logit_bias != nullptr && logit_bias[0] != '\0') {
|
|
std::stringstream ss(logit_bias);
|
|
llama_token key;
|
|
char sign = 0;
|
|
std::string value_str;
|
|
if (ss >> key >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
|
|
params->sampling.logit_bias.push_back({key, std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f)});
|
|
}
|
|
}
|
|
|
|
params->speculative.draft.n_max = n_draft > 0 ? n_draft : params->speculative.draft.n_max;
|
|
|
|
return params;
|
|
}
|
|
|
|
void llama_free_params(void * params_ptr) {
|
|
delete static_cast<common_params *>(params_ptr);
|
|
}
|
|
|
|
int eval(void * params_ptr, void * state_pr, char * text) {
|
|
auto * params = static_cast<common_params *>(params_ptr);
|
|
auto * state = binding_state(state_pr);
|
|
if (state == nullptr || state->ctx == nullptr) {
|
|
return 1;
|
|
}
|
|
|
|
std::string str = text != nullptr ? text : params->prompt;
|
|
auto embd = common_tokenize(state->ctx, str, true, true);
|
|
if (embd.empty()) {
|
|
return 1;
|
|
}
|
|
|
|
int n_past = 0;
|
|
if (!common_prompt_batch_decode(state->ctx, embd, n_past, params->n_batch, "", false)) {
|
|
return 1;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
int get_embeddings(void * params_ptr, void * state_pr, float * res_embeddings) {
|
|
auto * params = static_cast<common_params *>(params_ptr);
|
|
auto * state = binding_state(state_pr);
|
|
if (state == nullptr || state->ctx == nullptr || !state->embeddings) {
|
|
return 1;
|
|
}
|
|
|
|
auto embd = common_tokenize(state->ctx, params->prompt, true, true);
|
|
if (!embd.empty()) {
|
|
int n_past = 0;
|
|
if (!common_prompt_batch_decode(state->ctx, embd, n_past, params->n_batch, "", false)) {
|
|
return 1;
|
|
}
|
|
}
|
|
|
|
const int n_embd = llama_model_n_embd(state->model);
|
|
const float * emb = llama_get_embeddings_ith(state->ctx, -1);
|
|
if (emb == nullptr) {
|
|
emb = llama_get_embeddings(state->ctx);
|
|
}
|
|
if (emb == nullptr) {
|
|
return 1;
|
|
}
|
|
|
|
for (int i = 0; i < n_embd; ++i) {
|
|
res_embeddings[i] = emb[i];
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
int get_token_embeddings(void * params_ptr, void * state_pr, int * tokens, int tokenSize, float * res_embeddings) {
|
|
auto * params = static_cast<common_params *>(params_ptr);
|
|
auto * state = binding_state(state_pr);
|
|
if (state == nullptr || state->ctx == nullptr) {
|
|
return 1;
|
|
}
|
|
|
|
std::string text;
|
|
for (int i = 0; i < tokenSize; ++i) {
|
|
text += common_token_to_piece(state->ctx, tokens[i]);
|
|
}
|
|
params->prompt = text;
|
|
return get_embeddings(params_ptr, state_pr, res_embeddings);
|
|
}
|
|
|
|
int llama_tokenize_string(void * params_ptr, void * state_pr, int * result) {
|
|
auto * params = static_cast<common_params *>(params_ptr);
|
|
auto * state = binding_state(state_pr);
|
|
if (state == nullptr || state->ctx == nullptr) {
|
|
return -1;
|
|
}
|
|
|
|
const llama_vocab * vocab = llama_model_get_vocab(state->model);
|
|
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
|
const int32_t max_tokens = params->n_ctx > 0 ? params->n_ctx : 4096;
|
|
|
|
return llama_tokenize(
|
|
vocab,
|
|
params->prompt.c_str(),
|
|
(int32_t) params->prompt.size(),
|
|
reinterpret_cast<llama_token *>(result),
|
|
max_tokens,
|
|
add_bos,
|
|
true);
|
|
}
|
|
|
|
int llama_predict(void * params_ptr, void * state_pr, char * result, bool debug) {
|
|
auto * params = static_cast<common_params *>(params_ptr);
|
|
auto * state = binding_state(state_pr);
|
|
if (state == nullptr || state->ctx == nullptr || state->smpl == nullptr) {
|
|
return 1;
|
|
}
|
|
|
|
llama_context * ctx = state->ctx;
|
|
llama_model * model = state->model;
|
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
llama_memory_t mem = llama_get_memory(ctx);
|
|
|
|
common_sampler_ptr smpl_ptr(common_sampler_init(model, params->sampling));
|
|
if (!smpl_ptr) {
|
|
return 1;
|
|
}
|
|
common_sampler * smpl = smpl_ptr.get();
|
|
|
|
const int n_ctx = llama_n_ctx(ctx);
|
|
if (params->n_predict < 0) {
|
|
params->n_predict = 128;
|
|
}
|
|
|
|
llama_set_n_threads(ctx, params->cpuparams.n_threads, params->cpuparams_batch.n_threads);
|
|
|
|
std::string path_session = params->path_prompt_cache;
|
|
std::vector<llama_token> session_tokens;
|
|
|
|
if (!path_session.empty()) {
|
|
session_tokens.resize(n_ctx);
|
|
size_t n_out = 0;
|
|
if (std::ifstream(path_session).good()) {
|
|
llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size(), &n_out);
|
|
session_tokens.resize(n_out);
|
|
}
|
|
}
|
|
|
|
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
|
std::vector<llama_token> embd_inp = common_tokenize(ctx, params->prompt, add_bos, true);
|
|
if (embd_inp.empty()) {
|
|
embd_inp.push_back(llama_vocab_bos(vocab));
|
|
}
|
|
|
|
if ((int) embd_inp.size() > n_ctx - 4) {
|
|
return 1;
|
|
}
|
|
|
|
if (params->n_keep < 0 || params->n_keep > (int) embd_inp.size()) {
|
|
params->n_keep = (int) embd_inp.size();
|
|
}
|
|
|
|
common_sampler_reset(smpl);
|
|
|
|
int n_past = 0;
|
|
int n_remain = params->n_predict;
|
|
int n_consumed = 0;
|
|
int n_session_consumed = 0;
|
|
bool is_antiprompt = false;
|
|
bool need_save_session = !path_session.empty() && !params->prompt_cache_ro;
|
|
|
|
std::vector<llama_token> embd;
|
|
std::string res;
|
|
|
|
while (n_remain > 0 && !is_antiprompt) {
|
|
if (!embd.empty()) {
|
|
const int max_embd_size = n_ctx - 4;
|
|
if ((int) embd.size() > max_embd_size) {
|
|
embd.resize(max_embd_size);
|
|
}
|
|
|
|
if (n_past + (int) embd.size() >= n_ctx) {
|
|
const int n_left = n_past - params->n_keep;
|
|
const int n_discard = n_left / 2;
|
|
llama_memory_seq_rm(mem, 0, params->n_keep, params->n_keep + n_discard);
|
|
llama_memory_seq_add(mem, 0, params->n_keep + n_discard, n_past, -n_discard);
|
|
n_past -= n_discard;
|
|
path_session.clear();
|
|
}
|
|
|
|
if (n_session_consumed < (int) session_tokens.size()) {
|
|
size_t i = 0;
|
|
for (; i < embd.size(); ++i) {
|
|
if (embd[i] != session_tokens[n_session_consumed]) {
|
|
session_tokens.resize(n_session_consumed);
|
|
break;
|
|
}
|
|
n_past++;
|
|
n_session_consumed++;
|
|
if (n_session_consumed >= (int) session_tokens.size()) {
|
|
++i;
|
|
break;
|
|
}
|
|
}
|
|
if (i > 0) {
|
|
embd.erase(embd.begin(), embd.begin() + i);
|
|
}
|
|
}
|
|
|
|
if (!embd.empty()) {
|
|
const bool save_now = need_save_session && n_consumed >= (int) embd_inp.size();
|
|
if (!common_prompt_batch_decode(ctx, embd, n_past, params->n_batch, path_session, save_now)) {
|
|
return 1;
|
|
}
|
|
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
|
n_session_consumed = session_tokens.size();
|
|
need_save_session = false;
|
|
}
|
|
}
|
|
|
|
embd.clear();
|
|
|
|
if ((int) embd_inp.size() <= n_consumed) {
|
|
const llama_token id = common_sampler_sample(smpl, ctx, -1);
|
|
common_sampler_accept(smpl, id, true);
|
|
embd.push_back(id);
|
|
|
|
auto piece = common_token_to_piece(ctx, id);
|
|
if (!tokenCallback(state_pr, const_cast<char *>(piece.c_str()))) {
|
|
break;
|
|
}
|
|
|
|
res += piece;
|
|
--n_remain;
|
|
|
|
if (llama_vocab_is_eog(vocab, id)) {
|
|
break;
|
|
}
|
|
} else {
|
|
while ((int) embd_inp.size() > n_consumed) {
|
|
embd.push_back(embd_inp[n_consumed]);
|
|
common_sampler_accept(smpl, embd_inp[n_consumed], false);
|
|
++n_consumed;
|
|
if ((int) embd.size() >= params->n_batch) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
for (const auto id : embd) {
|
|
res += common_token_to_piece(ctx, id);
|
|
}
|
|
|
|
if ((int) embd_inp.size() <= n_consumed && !params->antiprompt.empty()) {
|
|
is_antiprompt = check_antiprompt(res, params->antiprompt, false);
|
|
}
|
|
}
|
|
|
|
if (!path_session.empty() && params->prompt_cache_all && !params->prompt_cache_ro) {
|
|
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
|
}
|
|
|
|
if (debug) {
|
|
common_perf_print(ctx, smpl);
|
|
}
|
|
|
|
if (result != nullptr) {
|
|
std::strncpy(result, res.c_str(), params->n_predict > 0 ? (size_t) params->n_predict : res.size());
|
|
result[params->n_predict > 0 ? params->n_predict - 1 : res.size()] = '\0';
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
int speculative_sampling(void * params_ptr, void * target_model, void * draft_model, char * result, bool debug) {
|
|
auto * params = static_cast<common_params *>(params_ptr);
|
|
auto * tgt = binding_state(target_model);
|
|
auto * dft = binding_state(draft_model);
|
|
if (tgt == nullptr || dft == nullptr || tgt->ctx == nullptr || dft->ctx == nullptr) {
|
|
return 1;
|
|
}
|
|
|
|
llama_context * ctx_tgt = tgt->ctx;
|
|
llama_context * ctx_dft = dft->ctx;
|
|
const llama_vocab * vocab = llama_model_get_vocab(tgt->model);
|
|
|
|
common_sampler_ptr smpl_ptr(common_sampler_init(tgt->model, params->sampling));
|
|
if (!smpl_ptr) {
|
|
return 1;
|
|
}
|
|
common_sampler * smpl_tgt = smpl_ptr.get();
|
|
|
|
auto inp = common_tokenize(ctx_tgt, params->prompt, true, true);
|
|
const int max_tokens = llama_n_ctx(ctx_tgt) - 4;
|
|
if ((int) inp.size() > max_tokens) {
|
|
return 1;
|
|
}
|
|
|
|
int n_past_tgt = 0;
|
|
int n_past_dft = 0;
|
|
if (!inp.empty()) {
|
|
if (!common_prompt_batch_decode(ctx_tgt, inp, n_past_tgt, params->n_batch, "", false)) {
|
|
return 1;
|
|
}
|
|
if (!common_prompt_batch_decode(ctx_dft, inp, n_past_dft, params->n_batch, "", false)) {
|
|
return 1;
|
|
}
|
|
}
|
|
|
|
const int n_draft = params->speculative.draft.n_max > 0 ? params->speculative.draft.n_max : 16;
|
|
int n_predict = 0;
|
|
std::string res;
|
|
bool has_eos = false;
|
|
|
|
std::vector<llama_token> drafted;
|
|
std::vector<llama_token> last_tokens(llama_n_ctx(ctx_tgt), 0);
|
|
for (auto id : inp) {
|
|
last_tokens.erase(last_tokens.begin());
|
|
last_tokens.push_back(id);
|
|
}
|
|
|
|
while (n_predict < params->n_predict && !has_eos) {
|
|
int i_dft = 0;
|
|
while (true) {
|
|
const llama_token id = common_sampler_sample(smpl_tgt, ctx_tgt, -1);
|
|
common_sampler_accept(smpl_tgt, id, true);
|
|
|
|
last_tokens.erase(last_tokens.begin());
|
|
last_tokens.push_back(id);
|
|
|
|
auto piece = common_token_to_piece(ctx_tgt, id);
|
|
if (!tokenCallback(draft_model, const_cast<char *>(piece.c_str()))) {
|
|
break;
|
|
}
|
|
res += piece;
|
|
|
|
if (llama_vocab_is_eog(vocab, id)) {
|
|
has_eos = true;
|
|
}
|
|
|
|
++n_predict;
|
|
|
|
if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
|
|
++i_dft;
|
|
continue;
|
|
}
|
|
|
|
llama_token dft_id = id;
|
|
llama_batch batch = llama_batch_get_one(&dft_id, 1);
|
|
if (llama_decode(ctx_dft, batch) != 0) {
|
|
return 1;
|
|
}
|
|
++n_past_dft;
|
|
|
|
drafted.clear();
|
|
drafted.push_back(id);
|
|
break;
|
|
}
|
|
|
|
if (n_predict >= params->n_predict || has_eos) {
|
|
break;
|
|
}
|
|
|
|
int n_past_cur = n_past_dft;
|
|
for (int i = 0; i < n_draft; ++i) {
|
|
float * logits = llama_get_logits(ctx_dft);
|
|
const int n_vocab = llama_vocab_n_tokens(vocab);
|
|
|
|
llama_token draft_id = 0;
|
|
float max_logit = logits[0];
|
|
for (llama_token t = 1; t < n_vocab; ++t) {
|
|
if (logits[t] > max_logit) {
|
|
max_logit = logits[t];
|
|
draft_id = t;
|
|
}
|
|
}
|
|
drafted.push_back(draft_id);
|
|
|
|
if (i == n_draft - 1) {
|
|
break;
|
|
}
|
|
|
|
llama_batch batch = llama_batch_get_one(&draft_id, 1);
|
|
if (llama_decode(ctx_dft, batch) != 0) {
|
|
return 1;
|
|
}
|
|
++n_past_cur;
|
|
}
|
|
|
|
llama_batch batch = llama_batch_get_one(drafted.data(), (int32_t) drafted.size());
|
|
if (llama_decode(ctx_tgt, batch) != 0) {
|
|
return 1;
|
|
}
|
|
++n_past_tgt;
|
|
|
|
if (!drafted.empty()) {
|
|
drafted.erase(drafted.begin());
|
|
}
|
|
}
|
|
|
|
if (debug) {
|
|
common_perf_print(ctx_tgt, smpl_tgt);
|
|
common_perf_print(ctx_dft, nullptr);
|
|
}
|
|
|
|
if (result != nullptr) {
|
|
std::strncpy(result, res.c_str(), params->n_predict > 0 ? (size_t) params->n_predict : res.size());
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
} // extern "C"
|
|
|
|
std::vector<std::string> create_vector(const char ** strings, int count) {
|
|
std::vector<std::string> vec;
|
|
for (int i = 0; i < count; ++i) {
|
|
vec.emplace_back(strings[i]);
|
|
}
|
|
return vec;
|
|
}
|
|
|
|
void delete_vector(std::vector<std::string> * vec) {
|
|
delete vec;
|
|
}
|