Ensure the computation stays in single precision

np.tri defaults to float64, so adding the mask silently promoted the float32 attention computation to double precision. Passing dtype=np.float32 keeps everything in single precision; on my computer this change makes picoGPT several times faster.
pull/12/head
Ondřej Čertík 2023-02-16 21:07:50 -07:00
parent f7dfc78ffa
commit f1b88a4943
2 changed files with 2 additions and 2 deletions

View file

@@ -46,7 +46,7 @@ def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
 
     # causal mask to hide future inputs from being attended to
-    causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
+    causal_mask = (1 - np.tri(x.shape[0], dtype=np.float32)) * -1e10  # [n_seq, n_seq]
 
     # perform attention over each head
     out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]

View file

@@ -24,7 +24,7 @@ def attention(q, k, v, mask):
 def mha(x, c_attn, c_proj, n_head):
     x = linear(x, **c_attn)
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1)))
-    causal_mask = (1 - np.tri(x.shape[0])) * -1e10
+    causal_mask = (1 - np.tri(x.shape[0], dtype=np.float32)) * -1e10
     out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]
     x = linear(np.hstack(out_heads), **c_proj)
     return x
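
A minimal sketch (not part of the commit) of the dtype promotion that motivates this change, assuming only NumPy: np.tri defaults to float64, and adding a float64 mask to float32 attention scores promotes the whole result to double precision, while a float32 mask keeps it in single precision.

import numpy as np

scores = np.ones((4, 4), dtype=np.float32)          # stand-in for the float32 attention scores
mask64 = (1 - np.tri(4)) * -1e10                    # np.tri defaults to float64
mask32 = (1 - np.tri(4, dtype=np.float32)) * -1e10  # float32 mask, as in this commit

print((scores + mask64).dtype)  # float64: the addition silently upcasts everything
print((scores + mask32).dtype)  # float32: the computation stays in single precision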