diff --git a/README.md b/README.md
index db77277..1bba6c7 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ picoGPT features:
 * Fast? ❌ Nah, picoGPT is megaSLOW 🐌
 * Training code? ❌ Error, 4️⃣0️⃣4️⃣ not found
 * Batch inference? ❌ picoGPT is civilized, single file line, one at a time only
-* top-p sampling? ❌ top-k? ❌ temperature? ❌ categorical sampling?! ❌ greedy? ✅
+* top-p sampling? ❌ top-k? ❌ temperature? ✅ categorical sampling?! ❌ greedy? ✅
 * Readable? `gpt2.py` ✅ `gpt2_pico.py` ❌
 * Smol??? ✅✅✅✅✅✅ YESS!!! TEENIE TINY in fact 🤏

diff --git a/gpt2.py b/gpt2.py
index 7c39355..277ff4e 100644
--- a/gpt2.py
+++ b/gpt2.py
@@ -92,7 +92,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate, temperature):

     for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
         logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass
-        next_id = np.argmax(logits[-1])  # greedy sampling
+        next_id = np.argmax(softmax(logits[-1] / temperature))  # greedy sampling over temperature-scaled logits
         inputs.append(int(next_id))  # append prediction to input

     return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids
diff --git a/gpt2_pico.py b/gpt2_pico.py
index 762ed12..a8ba29d 100644
--- a/gpt2_pico.py
+++ b/gpt2_pico.py
@@ -40,20 +40,20 @@ def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
         x = transformer_block(x, **block, n_head=n_head)
     return layer_norm(x, **ln_f) @ wte.T

-def generate(inputs, params, n_head, n_tokens_to_generate):
+def generate(inputs, params, n_head, n_tokens_to_generate, temperature=1.0):
     from tqdm import tqdm
     for _ in tqdm(range(n_tokens_to_generate), "generating"):
         logits = gpt2(inputs, **params, n_head=n_head)
-        next_id = np.argmax(logits[-1])
+        next_id = np.argmax(softmax(logits[-1] / temperature))
         inputs.append(int(next_id))
     return inputs[len(inputs) - n_tokens_to_generate :]

-def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models", temperature: float = 1.0):
     from utils import load_encoder_hparams_and_params
     encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
     input_ids = encoder.encode(prompt)
     assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
-    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
+    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate, temperature)
     output_text = encoder.decode(output_ids)
     return output_text
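
Note on the change above: dividing the logits by a positive temperature and then taking np.argmax still selects the same token, so the new flag only has a visible effect once decoding becomes stochastic. The sketch below is not part of this diff; the function name sample_next_id and the rng argument are illustrative, and softmax mirrors the helper already defined in gpt2.py. It shows how temperature is conventionally combined with categorical sampling:

    import numpy as np

    def softmax(x):
        # numerically stable softmax, matching the helper in gpt2.py
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def sample_next_id(last_logits, temperature=1.0, rng=None):
        # Temperature divides the logits *before* softmax: T < 1 sharpens the
        # distribution, T > 1 flattens it, and T -> 0 approaches greedy argmax.
        if rng is None:
            rng = np.random.default_rng()
        probs = softmax(last_logits / temperature)
        return int(rng.choice(len(probs), p=probs))  # categorical sample

    # inside the decode loop, this would replace the argmax line:
    #   next_id = sample_next_id(logits[-1], temperature)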