kopia lustrzana https://github.com/jaymody/picoGPT
Fix gpt2.py to work with Jax (#10)
rodzic
018a1e1796
commit
f7dfc78ffa
4
gpt2.py
4
gpt2.py
|
@ -89,9 +89,9 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
|
||||||
for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
|
for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
|
||||||
logits = gpt2(inputs, **params, n_head=n_head) # model forward pass
|
logits = gpt2(inputs, **params, n_head=n_head) # model forward pass
|
||||||
next_id = np.argmax(logits[-1]) # greedy sampling
|
next_id = np.argmax(logits[-1]) # greedy sampling
|
||||||
inputs = np.append(inputs, [next_id]) # append prediction to input
|
inputs.append(int(next_id)) # append prediction to input
|
||||||
|
|
||||||
return list(inputs[len(inputs) - n_tokens_to_generate :]) # only return generated ids
|
return inputs[len(inputs) - n_tokens_to_generate :] # only return generated ids
|
||||||
|
|
||||||
|
|
||||||
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
|
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
|
||||||
|
|
|
@ -45,8 +45,8 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
|
||||||
for _ in tqdm(range(n_tokens_to_generate), "generating"):
|
for _ in tqdm(range(n_tokens_to_generate), "generating"):
|
||||||
logits = gpt2(inputs, **params, n_head=n_head)
|
logits = gpt2(inputs, **params, n_head=n_head)
|
||||||
next_id = np.argmax(logits[-1])
|
next_id = np.argmax(logits[-1])
|
||||||
inputs = np.append(inputs, [next_id])
|
inputs.append(int(next_id))
|
||||||
return list(inputs[len(inputs) - n_tokens_to_generate :])
|
return inputs[len(inputs) - n_tokens_to_generate :]
|
||||||
|
|
||||||
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
|
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
|
||||||
from utils import load_encoder_hparams_and_params
|
from utils import load_encoder_hparams_and_params
|
||||||
|
|
Ładowanie…
Reference in New Issue