mirror of https://github.com/jaymody/picoGPT
Ensure the computation stays in single precision

On my computer this change makes picoGPT several times faster.

branch: pull/12/head
parent: f7dfc78ffa
commit: f1b88a4943

gpt2.py (2 changes)
@@ -46,7 +46,7 @@ def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
 
     # causal mask to hide future inputs from being attended to
-    causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
+    causal_mask = (1 - np.tri(x.shape[0], dtype=np.float32)) * -1e10  # [n_seq, n_seq]
 
     # perform attention over each head
     out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
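Why this one-argument change matters: np.tri defaults to float64, and NumPy's promotion rules then upcast every float32 array the mask is added to. A minimal sketch of the promotion (the scores array below is an illustrative stand-in, not code from the patch):

    import numpy as np

    scores = np.ones((4, 4), dtype=np.float32)   # stand-in for float32 attention scores

    # np.tri defaults to float64, so the unpatched mask upcasts everything it touches
    mask64 = (1 - np.tri(4)) * -1e10
    print(mask64.dtype)             # float64
    print((scores + mask64).dtype)  # float64 -- the whole attention path now runs in double

    # with an explicit dtype, the mask and the sum stay in single precision
    mask32 = (1 - np.tri(4, dtype=np.float32)) * -1e10
    print(mask32.dtype)             # float32
    print((scores + mask32).dtype)  # float32

Everything downstream of the mask inherits its precision, which is why fixing this single line is enough to keep the computation in float32.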
@@ -24,7 +24,7 @@ def attention(q, k, v, mask):
 def mha(x, c_attn, c_proj, n_head):
     x = linear(x, **c_attn)
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1)))
-    causal_mask = (1 - np.tri(x.shape[0])) * -1e10
+    causal_mask = (1 - np.tri(x.shape[0], dtype=np.float32)) * -1e10
     out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]
     x = linear(np.hstack(out_heads), **c_proj)
     return x
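To see the end-to-end effect, a quick dtype probe (a sketch, not part of the commit; softmax and attention below are self-contained stand-ins shaped like picoGPT's, with the scale written as ** 0.5 so no float64 scalar sneaks back in):

    import numpy as np

    def softmax(x):
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def attention(q, k, v, mask):  # stand-in mirroring the shape of picoGPT's attention
        return softmax(q @ k.T / q.shape[-1] ** 0.5 + mask) @ v

    rng = np.random.default_rng(0)
    q = k = v = rng.standard_normal((5, 8)).astype(np.float32)

    old_mask = (1 - np.tri(5)) * -1e10                    # float64 mask
    new_mask = (1 - np.tri(5, dtype=np.float32)) * -1e10  # float32 mask

    print(attention(q, k, v, old_mask).dtype)  # float64
    print(attention(q, k, v, new_mask).dtype)  # float32

With the float64 mask the entire attention computation silently runs in double precision; keeping it in float32 is where the speedup reported in the commit message comes from.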