Mirror of https://github.com/jaymody/picoGPT
add full support for JAX in utils.py
parent dfb5df895a
commit 42fe558d6c
gpt2.py | 2
gpt2.py | 2
@@ -72,7 +72,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
 def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):  # [n_seq] -> [n_seq, n_vocab]
     # token + positional embeddings
-    x = wte[inputs] + wpe[range(len(inputs))]  # [n_seq] -> [n_seq, n_embd]
+    x = wte[np.array(inputs)] + wpe[np.array(range(len(inputs)))]  # [n_seq] -> [n_seq, n_embd]

     # forward pass through n_layer transformer blocks
     for block in blocks:
@@ -35,7 +35,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
     return x

 def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
-    x = wte[inputs] + wpe[range(len(inputs))]
+    x = wte[np.array(inputs)] + wpe[np.array(range(len(inputs)))]
     for block in blocks:
         x = transformer_block(x, **block, n_head=n_head)
     return layer_norm(x, **ln_f) @ wte.T
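Context for the change (not part of the commit): picoGPT runs the same code whether `np` is plain NumPy or `jax.numpy`. NumPy happily fancy-indexes the embedding matrices with a plain Python list or a `range` object, but `jax.numpy` rejects those index types and expects an array, so both index expressions are wrapped in `np.array(...)`, which works under either backend. Below is a minimal sketch of the fixed lookup with the JAX backend; the shapes and token ids are dummy stand-ins, not loaded GPT-2 weights.

import jax.numpy as np  # picoGPT convention: the backend is whatever gets imported as `np`

# Dummy embedding tables standing in for GPT-2's wte/wpe (GPT-2-small-sized, zero-filled).
n_vocab, n_ctx, n_embd = 50257, 1024, 768
wte = np.zeros((n_vocab, n_embd))  # token embedding matrix
wpe = np.zeros((n_ctx, n_embd))    # positional embedding matrix

inputs = [0, 1, 2, 3]  # token ids as a plain Python list, as the encoder would produce

# wte[inputs] and wpe[range(len(inputs))] work under NumPy but are rejected by jax.numpy,
# which wants array indices; np.array(...) converts the list/range so both backends accept it.
x = wte[np.array(inputs)] + wpe[np.array(range(len(inputs)))]
print(x.shape)  # (4, 768)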