Mirror of https://github.com/jaymody/picoGPT
Merge pull request #4 from jameshfisher/patch-1
Fix mixed up dimensions in shape comments
commit 018a1e1796
 gpt2.py | 4 ++--
@@ -43,13 +43,13 @@ def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
     qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
 
     # split into heads
-    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [n_head, 3, n_seq, n_embd/n_head]
+    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
 
     # causal mask to hide future inputs from being attended to
     causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
 
     # perform attention over each head
-    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [n_head, 3, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
+    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
 
     # merge heads
     x = np.hstack(out_heads)  # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]
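For reference, the corrected shape comments can be sanity-checked with a small NumPy sketch outside the repo. The toy sizes below (n_seq=5, n_embd=12, n_head=4) are assumptions chosen only for illustration, and the sketch reproduces just the split/zip steps, not the repo's linear or attention helpers:

import numpy as np

# toy sizes, assumed for illustration only
n_seq, n_embd, n_head = 5, 12, 4

x = np.random.randn(n_seq, 3 * n_embd)  # stand-in for the qkv projection output: [n_seq, 3*n_embd]
qkv = np.split(x, 3, axis=-1)  # 3 arrays of [n_seq, n_embd]
qkv_heads = list(map(lambda t: np.split(t, n_head, axis=-1), qkv))

# outer length 3 (q, k, v), inner length n_head, each leaf [n_seq, n_embd/n_head],
# i.e. the corrected [3, n_head, n_seq, n_embd/n_head]
print(len(qkv_heads), len(qkv_heads[0]), qkv_heads[0][0].shape)  # -> 3 4 (5, 3)

# zip(*qkv_heads) then yields one (q, k, v) triple per head, so the attention
# outputs form [n_head, n_seq, n_embd/n_head], which np.hstack merges back to [n_seq, n_embd]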