Mirrored from https://github.com/jaymody/picoGPT
I can't spell.
parent 29e78cc52b
commit d4e955d0ca
gpt2.py | 8
@@ -45,11 +45,11 @@ def mha(x, c_attn, c_proj, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
     # split into heads
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv)) # [3, n_seq, n_embd] -> [n_head, 3, n_seq, n_embd/n_head]

-    # casual mask to hide future inputs from being attended to
-    casual_mask = (1 - np.tri(x.shape[0])) * -1e10 # [n_seq, n_seq]
+    # causal mask to hide future inputs from being attended to
+    causal_mask = (1 - np.tri(x.shape[0])) * -1e10 # [n_seq, n_seq]

     # perform attention over each head
-    out_heads = [attention(q, k, v, casual_mask) for q, k, v in zip(*qkv_heads)] # [n_head, 3, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
+    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)] # [n_head, 3, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]

     # merge heads
     x = np.hstack(out_heads) # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]
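The mask this hunk renames is the standard additive causal mask: entries on and below the diagonal are 0, entries above it (future positions) are -1e10, so the softmax inside attention drives their weights to effectively zero. A minimal illustration of that one expression (the n_seq value here is made up for the example, not taken from the commit):

import numpy as np

n_seq = 3  # illustrative sequence length, not from the commit
causal_mask = (1 - np.tri(n_seq)) * -1e10  # same expression as in the diff
# rows are query positions, columns are key positions:
# 0 on and below the diagonal (attendable), -1e10 above it (masked future tokens)
print(causal_mask)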
@@ -61,7 +61,7 @@ def mha(x, c_attn, c_proj, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]


 def transformer_block(x, mlp, attn, ln_1, ln_2, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
-    # multi-head casual self attention
+    # multi-head causal self attention
     x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]

     # position-wise feed forward network
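Only the comment changes in this hunk, but the surrounding context lines show the structure it describes: a pre-norm residual block in which each sub-layer (causal multi-head attention, then the position-wise feed-forward network) reads a layer-normed copy of x and has its output added back onto the stream. A schematic sketch of that wiring, using stand-in sub-layers rather than the repository's real layer_norm/mha/ffn:

import numpy as np

def layer_norm_stub(x):
    # placeholder normalization: zero mean, unit variance per position
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-5)

def transformer_block_sketch(x, attn_fn, ffn_fn):
    x = x + attn_fn(layer_norm_stub(x))  # multi-head causal self attention
    x = x + ffn_fn(layer_norm_stub(x))   # position-wise feed forward network
    return x

# identity stand-ins, just to show the residual wiring preserves [n_seq, n_embd]
x = np.random.randn(5, 8)
assert transformer_block_sketch(x, lambda h: h, lambda h: h).shape == x.shape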
@@ -24,8 +24,8 @@ def attention(q, k, v, mask):
 def mha(x, c_attn, c_proj, n_head):
     x = linear(x, **c_attn)
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1)))
-    casual_mask = (1 - np.tri(x.shape[0])) * -1e10
-    out_heads = [attention(q, k, v, casual_mask) for q, k, v in zip(*qkv_heads)]
+    causal_mask = (1 - np.tri(x.shape[0])) * -1e10
+    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]
     x = linear(np.hstack(out_heads), **c_proj)
     return x

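For completeness, here is a self-contained sketch of the (post-fix) mha from this last hunk, run with toy weights. The linear and attention helpers below are assumptions filled in for the example (a plain affine projection and masked scaled dot-product attention); only the mha body follows the diff:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def linear(x, w, b):  # assumed affine projection: [n_seq, d_in] -> [n_seq, d_out]
    return x @ w + b

def attention(q, k, v, mask):  # assumed masked scaled dot-product attention
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

def mha(x, c_attn, c_proj, n_head):  # body mirrors the hunk above
    x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]
    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1)))
    causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]
    x = linear(np.hstack(out_heads), **c_proj)  # merge heads, project back to [n_seq, n_embd]
    return x

# toy usage with random weights: n_seq=4, n_embd=8, n_head=2
n_seq, n_embd, n_head = 4, 8, 2
x = np.random.randn(n_seq, n_embd)
c_attn = dict(w=np.random.randn(n_embd, 3 * n_embd), b=np.zeros(3 * n_embd))
c_proj = dict(w=np.random.randn(n_embd, n_embd), b=np.zeros(n_embd))
assert mha(x, c_attn, c_proj, n_head).shape == (n_seq, n_embd)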