mirrored from https://github.com/jaymody/picoGPT
feat: added a few comments and renamed symbols for more clarity
parent 0c1dd6c466
commit d663909cfb
gpt2.py (17 changed lines)
@@ -35,8 +35,9 @@ def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k]
     return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v


-def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
+def mha(x, c_attn, c_proj, n_head, kvcache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
     # qkv projection
+    # when we pass kvcache, n_seq = 1, so we only compute new_q, new_k and new_v for the new token
     x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]

     # split into qkv
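The kvcache=None default and the new comment are the heart of the change: on the first call the whole prompt goes through the qkv projection, while on cached calls only the latest token does. A minimal numpy sketch of the resulting shapes (the sizes and the stand-in w/b below are made up for illustration and are not picoGPT's real parameters):

import numpy as np

# Toy shapes for illustration only (GPT-2 small's real n_embd is 768); w and b
# stand in for the c_attn weights and are random, not real model parameters.
n_embd = 8
w = np.random.randn(n_embd, 3 * n_embd)
b = np.zeros(3 * n_embd)

x_prompt = np.random.randn(5, n_embd)  # first call: the whole prompt, n_seq = 5
x_step = np.random.randn(1, n_embd)    # later calls with kvcache: one new token, n_seq = 1

print((x_prompt @ w + b).shape)  # (5, 24) -> split into q, k, v, each [5, n_embd]
print((x_step @ w + b).shape)    # (1, 24) -> split into new_q, new_k, new_v, each [1, n_embd]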
@@ -44,11 +45,11 @@ def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]

     if kvcache:
         # qkv
-        q, k, v = qkv
+        new_q, new_k, new_v = qkv  # new_q, new_k, new_v = [1, n_embd]
         old_k, old_v = kvcache
-        k = np.vstack([old_k, k])
-        v = np.vstack([old_v, v])
-        qkv = [q, k, v]
+        k = np.vstack([old_k, new_k])  # k = [n_seq, n_embd], where n_seq = prev_n_seq + 1
+        v = np.vstack([old_v, new_v])  # v = [n_seq, n_embd], where n_seq = prev_n_seq + 1
+        qkv = [new_q, k, v]

     current_cache = [qkv[1], qkv[2]]

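The renamed new_q/new_k/new_v make the cache update easier to follow: the freshly projected key and value rows are stacked under the cached ones, so k and v grow by one row per generated token. A small sketch of just that append, with toy shapes chosen purely for illustration:

import numpy as np

n_embd = 4                            # toy size; GPT-2 small uses 768
old_k = np.random.randn(6, n_embd)    # keys cached for the 6 tokens seen so far
old_v = np.random.randn(6, n_embd)
new_k = np.random.randn(1, n_embd)    # key/value row for the single new token
new_v = np.random.randn(1, n_embd)

k = np.vstack([old_k, new_k])         # [7, n_embd]: prev_n_seq + 1 rows
v = np.vstack([old_v, new_v])         # [7, n_embd]
current_cache = [k, v]                # the pair mha returns alongside its output
print(k.shape, v.shape)               # (7, 4) (7, 4)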
@@ -57,8 +58,10 @@ def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]

     # causal mask to hide future inputs from being attended to
     if kvcache:
+        # when we pass kvcache, the input is a single token that needs to attend to all previous tokens, so we create a mask of all 0s
         causal_mask = np.zeros((1, k.shape[0]))
     else:
+        # create triangular causal mask
         causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]

     # perform attention over each head
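The two mask branches keep cached decoding correct: a single cached query is the newest token, so it may attend to every key and gets an all-zeros [1, n_k] mask, while a full pass still needs the triangular mask that blanks out future positions with -1e10. A standalone sketch of both branches (the helper name build_mask is made up here and is not part of gpt2.py):

import numpy as np

def build_mask(n_q, n_k, using_cache):
    # mirrors the if/else above: zeros when decoding one cached token,
    # upper-triangular -1e10 otherwise
    if using_cache:
        return np.zeros((1, n_k))
    return (1 - np.tri(n_q)) * -1e10

print(build_mask(1, 4, using_cache=True))    # [[0. 0. 0. 0.]] -> attend to all 4 cached keys
print(build_mask(3, 3, using_cache=False))   # -1e10 above the diagonal hides future tokens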
@@ -74,7 +77,7 @@ def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
     return x, current_cache


-def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
+def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
     # multi-head causal self attention
     attn_out, kvcache_updated = mha(layer_norm(x, **ln_1), **attn, n_head=n_head, kvcache=kvcache)
     x = x + attn_out  # [n_seq, n_embd] -> [n_seq, n_embd]
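Since mha now hands back its (k, v) pair, transformer_block simply forwards it, and gpt2 presumably collects one updated pair per block. That collection loop is not visible in this diff, but it would look roughly like the illustrative helper below (run_blocks is a made-up name, not a function in gpt2.py):

def run_blocks(x, blocks, n_head, kvcache):
    # hypothetical sketch: apply every transformer block, feeding each its own
    # cached (k, v) pair and gathering the updated pairs for the next step
    new_kvcache = []
    for block, block_cache in zip(blocks, kvcache):
        x, updated = transformer_block(x, **block, n_head=n_head, kvcache=block_cache)
        new_kvcache.append(updated)
    return x, new_kvcache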
@@ -85,7 +88,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
     return x, kvcache_updated


-def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, kvcache):  # [n_seq] -> [n_seq, n_vocab]
+def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, kvcache=None):  # [n_seq] -> [n_seq, n_vocab]
     if not kvcache:
         kvcache = [None]*len(blocks)
         wpe_out = wpe[range(len(inputs))]
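With kvcache defaulting to None and gpt2 seeding one empty slot per block, a caller can pay the full-prompt cost once and then feed the model one token at a time. A hedged sketch of such a driver loop follows; it assumes gpt2 returns (logits, kvcache), which this diff does not show, and params/generate_with_cache are illustrative names:

import numpy as np

def generate_with_cache(inputs, params, n_head, n_tokens_to_generate):
    kvcache = None
    for _ in range(n_tokens_to_generate):
        # first pass: kvcache is None, the whole prompt is processed and the cache is built;
        # later passes: only the newest token's q/k/v need computing
        logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=kvcache)
        next_id = int(np.argmax(logits[-1]))   # greedy decoding on the last position
        inputs.append(next_id)
    return inputs[-n_tokens_to_generate:]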