diff --git a/gpt2.py b/gpt2.py
index ff7d7d6..167bf1a 100644
--- a/gpt2.py
+++ b/gpt2.py
@@ -35,8 +35,9 @@ def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k]
     return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
 
 
-def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
+def mha(x, c_attn, c_proj, n_head, kvcache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
     # qkv projection
+    # when we pass kvcache, n_seq = 1, so we will compute new_q, new_k and new_v
     x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]
 
     # split into qkv
@@ -44,11 +45,11 @@ def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_emb
 
     if kvcache:
         # qkv
-        q, k, v = qkv
+        new_q, new_k, new_v = qkv  # new_q, new_k, new_v = [1, n_embd]
         old_k, old_v = kvcache
-        k = np.vstack([old_k, k])
-        v = np.vstack([old_v, v])
-        qkv = [q, k, v]
+        k = np.vstack([old_k, new_k])  # k = [n_seq, n_embd], where n_seq = prev_n_seq + 1
+        v = np.vstack([old_v, new_v])  # v = [n_seq, n_embd], where n_seq = prev_n_seq + 1
+        qkv = [new_q, k, v]
 
     current_cache = [qkv[1], qkv[2]]
 
@@ -56,9 +57,11 @@ def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_emb
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [n_head, 3, n_seq, n_embd/n_head]
 
     # causal mask to hide future inputs from being attended to
-    if kvcache:
+    if kvcache:
+        # when we pass kvcache, we are passing a single token as input which needs to attend to all previous tokens, so we create a mask of all 0s
         causal_mask = np.zeros((1, k.shape[0]))
     else:
+        # create triangular causal mask
         causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
 
     # perform attention over each head
@@ -74,7 +77,7 @@ def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_emb
     return x, current_cache
 
 
-def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
+def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
     # multi-head causal self attention
     attn_out, kvcache_updated = mha(layer_norm(x, **ln_1), **attn, n_head=n_head, kvcache=kvcache)
     x = x + attn_out  # [n_seq, n_embd] -> [n_seq, n_embd]
@@ -85,7 +88,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache):  # [n_seq, n_e
     return x, kvcache_updated
 
 
-def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, kvcache):  # [n_seq] -> [n_seq, n_vocab]
+def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, kvcache=None):  # [n_seq] -> [n_seq, n_vocab]
     if not kvcache:
         kvcache = [None]*len(blocks)
         wpe_out = wpe[range(len(inputs))]
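To see how the cache gets threaded through, here is a minimal sketch of a generation loop driving the `gpt2` signature from this diff. The `generate` function below is an illustration, not part of the patch: it assumes `gpt2` returns both the logits and the updated per-block `kvcache` (its return statement is outside the hunks shown), and that when a cache is passed in, `gpt2` itself narrows the work to the most recent token, as the `n_seq = 1` comment in `mha` implies.

```python
import numpy as np

def generate(inputs, params, n_head, n_tokens_to_generate):
    # illustrative driver loop, assuming gpt2() above returns (logits, kvcache)
    # where kvcache is a list with one (k, v) entry per transformer block
    kvcache = None
    for _ in range(n_tokens_to_generate):
        # first call: kvcache is None, so the whole prompt is processed and
        # the cache is built; later calls: only the newest token's q/k/v are
        # computed, and the cached k/v stand in for the earlier positions
        logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=kvcache)
        next_id = int(np.argmax(logits[-1]))  # greedy decoding
        inputs = inputs + [next_id]  # inputs is a plain Python list of token ids
    return inputs[len(inputs) - n_tokens_to_generate:]
```

The payoff of the default `kvcache=None` arguments is that this loop can call the same functions for both phases: every call after the first projects a single token through `c_attn` and appends its k/v to the cache with `np.vstack`, instead of recomputing k and v for the entire prefix on each step.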