From f1b88a49437dfcd99c5865cc660126ff8a3397d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ond=C5=99ej=20=C4=8Cert=C3=ADk?=
Date: Thu, 16 Feb 2023 21:07:50 -0700
Subject: [PATCH] Ensure the computation stays in single precision

On my computer this change makes picoGPT several times faster.
---
 gpt2.py      | 2 +-
 gpt2_pico.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/gpt2.py b/gpt2.py
index e227ec2..b4fc9a9 100644
--- a/gpt2.py
+++ b/gpt2.py
@@ -46,7 +46,7 @@ def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
 
     # causal mask to hide future inputs from being attended to
-    causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
+    causal_mask = (1 - np.tri(x.shape[0], dtype=np.float32)) * -1e10  # [n_seq, n_seq]
 
     # perform attention over each head
     out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
diff --git a/gpt2_pico.py b/gpt2_pico.py
index 6c4bf80..4e411b6 100644
--- a/gpt2_pico.py
+++ b/gpt2_pico.py
@@ -24,7 +24,7 @@ def attention(q, k, v, mask):
 def mha(x, c_attn, c_proj, n_head):
     x = linear(x, **c_attn)
     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1)))
-    causal_mask = (1 - np.tri(x.shape[0])) * -1e10
+    causal_mask = (1 - np.tri(x.shape[0], dtype=np.float32)) * -1e10
     out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]
     x = linear(np.hstack(out_heads), **c_proj)
     return x
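
Background on why the one-argument change matters (a minimal sketch, not part of the patch; the array names below are illustrative): np.tri defaults to float64, so adding that mask to float32 attention scores silently promotes the rest of the computation to double precision, while passing dtype=np.float32 keeps everything in single precision.

    import numpy as np

    scores = np.ones((4, 4), dtype=np.float32)            # stand-in for q @ k.T attention scores
    mask_f64 = (1 - np.tri(4)) * -1e10                     # np.tri defaults to float64
    mask_f32 = (1 - np.tri(4, dtype=np.float32)) * -1e10   # mask built in float32, as in the patch

    print((scores + mask_f64).dtype)  # float64: the mask upcasts the whole computation
    print((scores + mask_f32).dtype)  # float32: computation stays in single precision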