Fix gpt2.py to work with JAX

The issue was that the encoder returned a Python list, and under JAX
`np.append()` requires array or scalar arguments; it rejects plain Python lists.

This patch makes gpt2.py work with both NumPy and JAX.
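
The failure is easy to reproduce in isolation. A minimal sketch, assuming the
picoGPT convention of importing the array backend under the alias `np` (the
token id values below are illustrative):

    import jax.numpy as np  # swap for `import numpy as np` to run on NumPy

    inputs = np.array([15496, 11])  # prompt token ids, already an array
    next_id = 50256                 # a plain Python scalar

    # inputs = np.append(inputs, [next_id])  # TypeError: jax.numpy functions
    #                                        # reject plain Python lists
    inputs = np.append(inputs, next_id)      # works under both NumPy and JAX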

Fixes #9.
pull/10/head
Ondřej Čertík 2023-02-16 13:09:06 -07:00
parent 018a1e1796
commit c456f120ae
1 changed file with 2 additions and 2 deletions

@@ -89,7 +89,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate):
 
     for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
         logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass
         next_id = np.argmax(logits[-1])  # greedy sampling
-        inputs = np.append(inputs, [next_id])  # append prediction to input
+        inputs = np.append(inputs, next_id)  # append prediction to input
 
     return list(inputs[len(inputs) - n_tokens_to_generate :])  # only return generated ids
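
A quick sanity check that the patched line behaves the same under both
backends (a sketch; the array values are illustrative):

    import numpy
    import jax.numpy as jnp

    assert numpy.append(numpy.array([1, 2, 3]), 4).tolist() == [1, 2, 3, 4]
    assert jnp.append(jnp.array([1, 2, 3]), 4).tolist() == [1, 2, 3, 4]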
@@ -101,7 +101,7 @@ def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M",
     encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
 
     # encode the input string using the BPE tokenizer
-    input_ids = encoder.encode(prompt)
+    input_ids = np.array(encoder.encode(prompt))
 
     # make sure we are not surpassing the max sequence length of our model
     assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
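
The `np.array(...)` wrapping is needed because `encoder.encode` returns a
plain Python list of token ids, which `generate` would otherwise pass straight
into `np.append`. A usage sketch (the prompt text is illustrative):

    input_ids = np.array(encoder.encode("Alan Turing theorized that"))
    # NumPy backend: a plain ndarray; JAX backend: a device array.
    # Either way, np.append(inputs, next_id) inside generate() now succeeds.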