Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions common/ngram-map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,15 @@ static std::string common_tokens_to_str(const llama_tokens & inp, size_t start,
* @return Vector of draft tokens, empty if no matching pattern is found
*/
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const common_ngram_simple_config & config,
const llama_tokens & tokens, llama_token sampled) {

// Simple implementation of self-speculative decoding without a draft model.
//
const size_t cur_len = tokens.size();
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (state.idx_last_check + state.config.check_rate > cur_len) {
llama_tokens draft_tokens;
return draft_tokens;
}

size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
const size_t n_draft_min = config.size_ngram; // size of n-gram to lookup in token history
const size_t n_draft_max = config.size_mgram; // the m-gram following the found n-gram is used for draft

// vector for tokens we want to verify.
// return empty vector if there is no match.
Expand All @@ -80,9 +74,6 @@ llama_tokens common_ngram_simple_draft(
}
pattern.push_back(sampled); // add the last token to the pattern

// We do a search in the token history.
state.idx_last_check = cur_len;

size_t match_pos = 0; // we ignore position 0, position 0 == no match
// search backwards, but skip the current match (we are currently there)
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
Expand Down
16 changes: 1 addition & 15 deletions common/ngram-map.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,9 @@ struct common_ngram_simple_config {
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
};

// current state (and config) of n-gram simple.
// Bundles a copy of the configuration with the one piece of mutable bookkeeping
// (idx_last_check) kept between calls to common_ngram_simple_draft.
struct common_ngram_simple_state {
// by-value copy of the configuration handed to the constructor
common_ngram_simple_config config;

// Starts at 0. common_ngram_simple_draft compares idx_last_check + config.check_rate
// against the current token count and stores the token count here when it performs a search.
size_t idx_last_check = 0; // index of last check in context history (mutable)

// Copies `config` into the member of the same name; idx_last_check keeps its default of 0.
common_ngram_simple_state(const common_ngram_simple_config & config)
: config(config) {}
};

// Searches for an n-gram in the history and checks whether a draft sequence should be generated.
// state:   the ngram simple state to search in.
// config:  the n-gram configuration (size_ngram, size_mgram, check_rate) to use.
// tokens:  the tokens generated so far.
// sampled: the token that was just sampled.
// returns: the draft tokens; empty if no matching pattern is found.
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const common_ngram_simple_config & config,
const llama_tokens & tokens, llama_token sampled);


Expand Down
20 changes: 14 additions & 6 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,12 +463,14 @@ struct common_speculative_state_eagle3 : public common_speculative_state {

// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_state {
common_ngram_simple_state state;
common_ngram_simple_config config;

uint16_t check_id = 0; // used to control the frequency of generating drafts

common_speculative_state_ngram_simple(
enum common_speculative_type type,
common_ngram_simple_state state)
: common_speculative_state(type), state(state) {}
common_ngram_simple_config config)
: common_speculative_state(type), config(config) {}

void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
Expand All @@ -479,7 +481,13 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
result = common_ngram_simple_draft(state, prompt_tgt, id_last);
++check_id;
if (check_id < config.check_rate) {
return;
}
check_id = 0;

result = common_ngram_simple_draft(config, prompt_tgt, id_last);
GGML_UNUSED(params);
}

Expand Down Expand Up @@ -889,14 +897,14 @@ common_speculative * common_speculative_init(
uint16_t mgram_size_value = ngram_map.size_value;
uint16_t check_rate = ngram_map.check_rate;

auto config_simple = common_ngram_simple_config{
auto config_simple = common_ngram_simple_config {
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value,
/* .check_rate = */ check_rate
};
auto state = std::make_unique<common_speculative_state_ngram_simple>(
/* .type = */ config.type,
/* .state = */ common_ngram_simple_state(config_simple)
/* .state = */ config_simple
);
impls.push_back(std::move(state));
break;
Expand Down
Loading