From 840afb5e4eac79a8bed7620b6a015e5e95237fcc Mon Sep 17 00:00:00 2001
From: Ilya Prockofiev <125394064+somepatt@users.noreply.github.com>
Date: Tue, 10 Feb 2026 10:15:12 +0300
Subject: [PATCH] add kv cache on inference

---
 Coconut/coconut.py | 40 ++++++++++++++++++++++++++++++++--------
 1 file changed, 32 insertions(+), 8 deletions(-)

diff --git a/Coconut/coconut.py b/Coconut/coconut.py
index e0c200b..151adcc 100755
--- a/Coconut/coconut.py
+++ b/Coconut/coconut.py
@@ -214,6 +214,8 @@ def generate(
 
         tokens = input_ids[0].detach().tolist()
 
+        past_key_values = None
+
         labels = input_ids.clone()  # placeholder. not used.
         outputs = self.forward(
             input_ids,
@@ -224,27 +226,44 @@ def generate(
             ).reshape(1, -1),
         )
         inputs_embeds = outputs.inputs_embeds
+
+        pref_out = self.base_causallm(
+            inputs_embeds=inputs_embeds,
+            use_cache=True
+        )
+
+        self.gen_forward_cnt += 1
+
+        past_key_values = pref_out.past_key_values
 
         # get the first token using the current hidden state
-        next_token = torch.argmax(outputs.logits[0, -1]).item()
+        next_token = torch.argmax(pref_out.logits[0, -1]).item()
         tokens.append(next_token)
 
         new_token_embed = self.embedding(
-            torch.tensor(next_token, device=input_ids.device)
+            torch.tensor([next_token], device=input_ids.device)
         ).view(1, 1, -1)
-        new_inputs_embeds = torch.cat((inputs_embeds, new_token_embed), dim=1)
+        full_inputs_embeds = torch.cat((inputs_embeds, new_token_embed), dim=1)
+
         # get other tokens
         for _ in range(max_new_tokens - 1):
-            outputs = self.base_causallm(inputs_embeds=new_inputs_embeds)
+            outputs = self.base_causallm(
+                inputs_embeds=new_token_embed,
+                past_key_values=past_key_values,
+                use_cache=True
+            )
+
+            past_key_values = outputs.past_key_values
+
             self.gen_forward_cnt += 1
             next_token = torch.argmax(outputs.logits[0, -1]).item()
             if next_token == self.eos_token_id:
                 break
             tokens.append(next_token)
             new_token_embed = self.embedding(
-                torch.tensor(next_token, device=input_ids.device)
+                torch.tensor([next_token], device=input_ids.device)
             ).view(1, 1, -1)
-            new_inputs_embeds = torch.cat((new_inputs_embeds, new_token_embed), dim=1)
+            full_inputs_embeds = torch.cat((full_inputs_embeds, new_token_embed), dim=1)
 
         if synced_gpus:
             # in FSDP, the number of forward pass need to be the same across devices
@@ -252,11 +271,16 @@ def generate(
                 self.gen_forward_cnt < max_new_tokens + MAX_N_LATENT
             ):  # leave some room for latent tokens
                 self.gen_forward_cnt += 1
-                _ = self.base_causallm(inputs_embeds=new_inputs_embeds)
+                _ = self.base_causallm(
+                    inputs_embeds=new_token_embed,
+                    past_key_values=past_key_values,
+                    use_cache=True
+                )
+                past_key_values = _.past_key_values
 
         if output_embedding:
             # for analysis purpose
-            return torch.tensor(tokens).view(1, -1), new_inputs_embeds
+            return torch.tensor(tokens).view(1, -1), full_inputs_embeds
 
         else:
             return torch.tensor(tokens).view(1, -1)
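
For reference, the pattern this patch applies inside Coconut's generate() is the standard prefill-then-decode loop with a reused KV cache: the prompt embeddings are run once with use_cache=True, and every later step feeds only the newest token's embedding together with past_key_values instead of re-encoding the whole growing prefix. The sketch below is illustrative only and is not part of the patch; the "gpt2" checkpoint, the prompt text, and the 15-step budget are placeholder choices, not anything taken from the Coconut code.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder model/tokenizer, chosen only to make the sketch runnable.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model.eval()

    prompt_ids = tokenizer("2 + 2 =", return_tensors="pt").input_ids
    embedding = model.get_input_embeddings()
    inputs_embeds = embedding(prompt_ids)  # (1, seq_len, hidden)

    with torch.no_grad():
        # Prefill: encode the full prompt once and keep its KV cache.
        out = model(inputs_embeds=inputs_embeds, use_cache=True)
        past_key_values = out.past_key_values
        next_token = torch.argmax(out.logits[0, -1]).item()
        tokens = prompt_ids[0].tolist() + [next_token]

        for _ in range(15):
            # Decode step: pass only the newest token's embedding plus the
            # cache, rather than re-running the entire prefix each iteration.
            step_embed = embedding(torch.tensor([[next_token]]))
            out = model(inputs_embeds=step_embed,
                        past_key_values=past_key_values,
                        use_cache=True)
            past_key_values = out.past_key_values
            next_token = torch.argmax(out.logits[0, -1]).item()
            if next_token == tokenizer.eos_token_id:
                break
            tokens.append(next_token)

    print(tokenizer.decode(tokens))

The diff follows the same shape: the extra base_causallm call right after self.forward plays the prefill role (Coconut's own forward does not return a cache), and the loop body switches from the full full_inputs_embeds sequence to the single new_token_embed plus past_key_values.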