add full support for JAX in utils.py

pull/17/head
aalmarhabi 2023-02-21 16:04:43 -05:00
rodzic dfb5df895a
commit 42fe558d6c
2 zmienionych plików z 2 dodań i 2 usunięć

Wyświetl plik

@ -72,7 +72,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head): # [n_seq, n_embd] -> [
def gpt2(inputs, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab]
# token + positional embeddings
x = wte[inputs] + wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd]
x = wte[np.array(inputs)] + wpe[np.array(range(len(inputs)))] # [n_seq] -> [n_seq, n_embd]
# forward pass through n_layer transformer blocks
for block in blocks:

Wyświetl plik

@ -35,7 +35,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
return x
def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
x = wte[inputs] + wpe[range(len(inputs))]
x = wte[np.array(inputs)] + wpe[np.array(range(len(inputs)))]
for block in blocks:
x = transformer_block(x, **block, n_head=n_head)
return layer_norm(x, **ln_f) @ wte.T