Mirror of https://github.com/jaymody/picoGPT
Fix gpt2.py to work with Jax
The issue was that the encoder returned a plain Python list, and under JAX `np.append()` expects an array, not a list, as its argument. This patch makes the script work with both NumPy and JAX. Fixes #9.

pull/10/head
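For illustration, a minimal sketch of the failure mode (an assumption for this sketch: gpt2.py is run with its NumPy import swapped for `import jax.numpy as np`, as picoGPT allows; the token values are made up):

```python
import jax.numpy as np  # the picoGPT trick: swap NumPy for JAX's NumPy-compatible API

token_ids = [15496, 995]              # encoder.encode() returns a plain Python list
logits_last = np.array([0.1, 0.9])    # stand-in for the last row of logits
next_id = np.argmax(logits_last)      # greedy sampling yields a JAX scalar array

# Under JAX this raises a TypeError: jax.numpy functions do not accept
# plain Python lists as array arguments.
# inputs = np.append(token_ids, [next_id])

# The patched flow: convert the encoder output to an array once,
# then append the scalar directly. This works with both NumPy and JAX.
inputs = np.array(token_ids)
inputs = np.append(inputs, next_id)
print(inputs)  # [15496   995     1]
```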
parent
018a1e1796
commit
c456f120ae
gpt2.py (4 changed lines)
@@ -89,7 +89,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
     for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
         logits = gpt2(inputs, **params, n_head=n_head) # model forward pass
         next_id = np.argmax(logits[-1]) # greedy sampling
-        inputs = np.append(inputs, [next_id]) # append prediction to input
+        inputs = np.append(inputs, next_id) # append prediction to input
 
     return list(inputs[len(inputs) - n_tokens_to_generate :]) # only return generated ids
 
@@ -101,7 +101,7 @@ def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M",
     encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
 
     # encode the input string using the BPE tokenizer
-    input_ids = encoder.encode(prompt)
+    input_ids = np.array(encoder.encode(prompt))
 
     # make sure we are not surpassing the max sequence length of our model
     assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
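A note on backward compatibility (a quick check, not part of the patch): in plain NumPy, `np.append` flattens its arguments, so passing the scalar instead of a one-element list yields the same result, and the NumPy code path is unchanged.

```python
import numpy as np

inputs = np.array([15496, 995])
next_id = np.int64(1)

# np.append flattens both forms to the same result, so dropping the
# list wrapper only matters for the JAX backend.
assert np.array_equal(np.append(inputs, [next_id]), np.append(inputs, next_id))
print(np.append(inputs, next_id))  # [15496   995     1]
```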