diff --git a/gpt2.py b/gpt2.py index a4fd175..5fef383 100644 --- a/gpt2.py +++ b/gpt2.py @@ -45,11 +45,11 @@ def mha(x, c_attn, c_proj, n_head): # [n_seq, n_embd] -> [n_seq, n_embd] # split into heads qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv)) # [3, n_seq, n_embd] -> [n_head, 3, n_seq, n_embd/n_head] - # casual mask to hide future inputs from being attended to - casual_mask = (1 - np.tri(x.shape[0])) * -1e10 # [n_seq, n_seq] + # causal mask to hide future inputs from being attended to + causal_mask = (1 - np.tri(x.shape[0])) * -1e10 # [n_seq, n_seq] # perform attention over each head - out_heads = [attention(q, k, v, casual_mask) for q, k, v in zip(*qkv_heads)] # [n_head, 3, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head] + out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)] # [n_head, 3, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head] # merge heads x = np.hstack(out_heads) # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd] @@ -61,7 +61,7 @@ def mha(x, c_attn, c_proj, n_head): # [n_seq, n_embd] -> [n_seq, n_embd] def transformer_block(x, mlp, attn, ln_1, ln_2, n_head): # [n_seq, n_embd] -> [n_seq, n_embd] - # multi-head casual self attention + # multi-head causal self attention x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd] # position-wise feed forward network diff --git a/gpt2_pico.py b/gpt2_pico.py index 42cac00..904a97d 100644 --- a/gpt2_pico.py +++ b/gpt2_pico.py @@ -24,8 +24,8 @@ def attention(q, k, v, mask): def mha(x, c_attn, c_proj, n_head): x = linear(x, **c_attn) qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1))) - casual_mask = (1 - np.tri(x.shape[0])) * -1e10 - out_heads = [attention(q, k, v, casual_mask) for q, k, v in zip(*qkv_heads)] + causal_mask = (1 - np.tri(x.shape[0])) * -1e10 + out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)] x = linear(np.hstack(out_heads), **c_proj) return x