From 6487b41e0337be1e6da1ab0db51311b2f45cfac3 Mon Sep 17 00:00:00 2001
From: M-Daniyal-123 <55560086+M-Daniyal-123@users.noreply.github.com>
Date: Fri, 24 Feb 2023 19:19:50 +0500
Subject: [PATCH] Added temperature

---
 gpt2_pico.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/gpt2_pico.py b/gpt2_pico.py
index 762ed12..a8ba29d 100644
--- a/gpt2_pico.py
+++ b/gpt2_pico.py
@@ -40,20 +40,21 @@ def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
         x = transformer_block(x, **block, n_head=n_head)
     return layer_norm(x, **ln_f) @ wte.T
 
-def generate(inputs, params, n_head, n_tokens_to_generate):
+def generate(inputs, params, n_head, n_tokens_to_generate, temperature=1.0):
     from tqdm import tqdm
     for _ in tqdm(range(n_tokens_to_generate), "generating"):
         logits = gpt2(inputs, **params, n_head=n_head)
-        next_id = np.argmax(logits[-1])
+        probs = softmax(logits[-1].astype(np.float64) / temperature)  # scale logits, not probabilities
+        next_id = np.random.choice(len(probs), p=probs)  # sample so temperature has an effect
         inputs.append(int(next_id))
     return inputs[len(inputs) - n_tokens_to_generate :]
 
-def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models", temperature: float = 1.0):
     from utils import load_encoder_hparams_and_params
     encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
     input_ids = encoder.encode(prompt)
     assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
-    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
+    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate, temperature)
     output_text = encoder.decode(output_ids)
     return output_text
 
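Note on the sampling line above: the original version of this patch computed `np.argmax(softmax(logits[-1]) / temperature)`, which leaves the output unchanged for every temperature value, since softmax is monotonic and dividing by a positive constant preserves the ordering of its entries. For temperature to matter, it has to divide the logits before the softmax, and the next token has to be sampled from the resulting distribution rather than taken by argmax. Below is a minimal, self-contained sketch of that corrected step; the `sample_with_temperature` helper and the logits are made up for illustration, and `softmax` mirrors the helper already defined in gpt2_pico.py.

```python
import numpy as np

def softmax(x):
    # same softmax as gpt2_pico.py: shift by the max for numerical stability
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def sample_with_temperature(logits, temperature, rng):
    # temperature < 1 sharpens the distribution toward the argmax,
    # temperature > 1 flattens it toward uniform; float64 keeps the
    # probabilities summing to 1 within rng.choice's tolerance
    probs = softmax(np.asarray(logits, dtype=np.float64) / temperature)
    return rng.choice(len(probs), p=probs)

rng = np.random.default_rng(0)
logits = [2.0, 1.0, 0.5, -1.0]  # made-up logits, for illustration only
for t in (0.1, 1.0, 10.0):
    draws = [sample_with_temperature(logits, t, rng) for _ in range(10_000)]
    print(f"temperature={t}: pick frequencies {np.bincount(draws, minlength=4) / 10_000}")
```

At low temperature the frequencies collapse onto index 0, which recovers the old greedy `np.argmax` behaviour; at high temperature they approach uniform. Since gpt2_pico.py dispatches `main` through `fire.Fire`, the new parameter should also be usable as a CLI flag, e.g. `python gpt2_pico.py "some prompt" --temperature 0.8`.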