Mirror of https://github.com/jaymody/picoGPT
feat: added k, v cache for inference speed up
parent 0550692d73
commit 0c1dd6c466
gpt2.py (52 changes)
@@ -35,66 +35,92 @@ def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k]
     return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v


-def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
+def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
     # qkv projection
     x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]

     # split into qkv
     qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]

+    if kvcache:
+        # qkv
+        q, k, v = qkv
+        old_k, old_v = kvcache
+        k = np.vstack([old_k, k])
+        v = np.vstack([old_v, v])
+        qkv = [q, k, v]
+
+    current_cache = [qkv[1], qkv[2]]
+
     # split into heads
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [n_head, 3, n_seq, n_embd/n_head]

     # causal mask to hide future inputs from being attended to
-    causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
+    if kvcache:
+        causal_mask = np.zeros((1, k.shape[0]))
+    else:
+        causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]

     # perform attention over each head
     out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [n_head, 3, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]

     # merge heads
     x = np.hstack(out_heads)  # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]

     # out projection
     x = linear(x, **c_proj)  # [n_seq, n_embd] -> [n_seq, n_embd]

-    return x
+    return x, current_cache


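The mha change above is the heart of the cache: on a warm call, x holds only the newest token, so the qkv projection yields single-row q, k, and v; the previously cached keys and values are stacked underneath the new row, and since that single query is allowed to attend to every cached position, the causal mask collapses to one row of zeros. A minimal standalone sketch with toy shapes (old_k, new_k, and friends are illustrative names, not part of the diff):

    import numpy as np

    # 3 tokens already cached, 1 new token, toy embedding width of 4
    old_k, old_v = np.random.randn(3, 4), np.random.randn(3, 4)
    new_k, new_v = np.random.randn(1, 4), np.random.randn(1, 4)

    k = np.vstack([old_k, new_k])  # (4, 4): the cache grows by one row per decode step
    v = np.vstack([old_v, new_v])  # (4, 4)

    # One query row that may see all 4 keys, so no -1e10 entries are needed
    causal_mask = np.zeros((1, k.shape[0]))  # (1, 4)
    print(k.shape, v.shape, causal_mask.shape)
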
-def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
+def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
     # multi-head causal self attention
-    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]
+    attn_out, kvcache_updated = mha(layer_norm(x, **ln_1), **attn, n_head=n_head, kvcache=kvcache)
+    x = x + attn_out  # [n_seq, n_embd] -> [n_seq, n_embd]

     # position-wise feed forward network
     x = x + ffn(layer_norm(x, **ln_2), **mlp)  # [n_seq, n_embd] -> [n_seq, n_embd]

-    return x
+    return x, kvcache_updated


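In transformer_block the cache only has to pass through the attention sub-layer; the feed-forward network is position-wise and keeps no state across tokens, so it is untouched. Each block therefore hands back exactly one [k, v] pair per call, and the caller accumulates one such pair per layer. A rough sketch of the resulting layout, with illustrative GPT-2 124M-like sizes that are not taken from the diff:

    import numpy as np

    n_layer, n_seq, n_embd = 12, 10, 768  # illustrative sizes only

    # One [k, v] pair per transformer block; each decode step appends one row to each
    kvcache = [[np.zeros((n_seq, n_embd)), np.zeros((n_seq, n_embd))] for _ in range(n_layer)]
    print(len(kvcache), kvcache[0][0].shape)  # 12 (10, 768)
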
-def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):  # [n_seq] -> [n_seq, n_vocab]
+def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, kvcache):  # [n_seq] -> [n_seq, n_vocab]
+    if not kvcache:
+        kvcache = [None]*len(blocks)
+        wpe_out = wpe[range(len(inputs))]
+    else:
+        wpe_out = wpe[[len(inputs)-1]]
+        inputs = [inputs[-1]]
+
     # token + positional embeddings
-    x = wte[inputs] + wpe[range(len(inputs))]  # [n_seq] -> [n_seq, n_embd]
+    x = wte[inputs] + wpe_out  # [n_seq] -> [n_seq, n_embd]

     # forward pass through n_layer transformer blocks
-    for block in blocks:
-        x = transformer_block(x, **block, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]
+    new_kvcache = []
+    for block, kvcache_block in zip(blocks, kvcache):
+        x, updated_cache = transformer_block(x, **block, n_head=n_head, kvcache=kvcache_block)  # [n_seq, n_embd] -> [n_seq, n_embd]
+        new_kvcache.append(updated_cache)  # TODO: inplace extend new cache instead of re-saving whole

     # projection to vocab
     x = layer_norm(x, **ln_f)  # [n_seq, n_embd] -> [n_seq, n_embd]
-    return x @ wte.T  # [n_seq, n_embd] -> [n_seq, n_vocab]
+    return x @ wte.T, new_kvcache  # [n_seq, n_embd] -> [n_seq, n_vocab]


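The gpt2 changes above handle the two call modes. On the first (cold) call the cache is seeded with one None per block and the whole prompt is embedded as before. On warm calls only the last token is fed through the network, but it still needs the position embedding for its true index, hence wpe[[len(inputs)-1]] is taken before inputs is truncated to its final element. A small sketch of the embedding arithmetic, using zero weight matrices and made-up token ids purely to show the shapes:

    import numpy as np

    n_vocab, n_ctx, n_embd = 50257, 1024, 768   # GPT-2 124M-like sizes, illustrative
    wte, wpe = np.zeros((n_vocab, n_embd)), np.zeros((n_ctx, n_embd))
    inputs = [464, 3290, 318, 257]              # hypothetical token ids, 4 tokens seen so far

    x_cold = wte[inputs] + wpe[range(len(inputs))]       # (4, 768): embed every position
    x_warm = wte[[inputs[-1]]] + wpe[[len(inputs) - 1]]  # (1, 768): newest token, position 3
    print(x_cold.shape, x_warm.shape)
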
 def generate(inputs, params, n_head, n_tokens_to_generate):
     from tqdm import tqdm

+    kvcache = None
     for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
-        logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass
+        logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=kvcache)  # model forward pass
         next_id = np.argmax(logits[-1])  # greedy sampling
         inputs = np.append(inputs, [next_id])  # append prediction to input

     return list(inputs[len(inputs) - n_tokens_to_generate :])  # only return generated ids


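In generate, the first iteration still runs the full prompt (kvcache is None, so gpt2 takes the cold path) and every later iteration processes only the freshly appended token against the stored keys and values, which is where the speed-up comes from. A rough end-to-end timing of the cached decode loop, assuming the repo's utils.load_encoder_hparams_and_params and downloaded 124M weights:

    import time
    from gpt2 import generate
    from utils import load_encoder_hparams_and_params

    encoder, hparams, params = load_encoder_hparams_and_params("124M", "models")
    ids = encoder.encode("Alan Turing theorized that computers would one day become")

    t0 = time.time()
    out = generate(ids, params, n_head=hparams["n_head"], n_tokens_to_generate=40)
    print(f"{time.time() - t0:.1f}s", encoder.decode(out))
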
-def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+def main(prompt: str = "Alan Turing theorized that computers would one day become", n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
     from utils import load_encoder_hparams_and_params

     # load encoder, hparams, and params from the released open-ai gpt-2 files
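The only visible change to main is a default prompt, so the script can now be invoked with no arguments. A hypothetical quick check from a Python shell in the repo root, assuming the 124M weights are downloaded and that main() still returns the decoded text as in upstream picoGPT:

    from gpt2 import main

    print(main())                                            # uses the new default prompt
    print(main("a custom prompt", n_tokens_to_generate=8))   # explicit arguments still work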