Mirror of https://github.com/jaymody/picoGPT
commit e4f1d30c99
@@ -14,7 +14,7 @@ picoGPT features:
 * Fast? ❌ Nah, picoGPT is megaSLOW 🐌
 * Training code? ❌ Error, 4️⃣0️⃣4️⃣ not found
 * Batch inference? ❌ picoGPT is civilized, single file line, one at a time only
-* top-p sampling? ❌ top-k? ❌ temperature? ❌ categorical sampling?! ❌ greedy? ✅
+* top-p sampling? ❌ top-k? ❌ categorical sampling?! ❌ greedy? ✅
 * Readable? `gpt2.py` ✅ `gpt2_pico.py` ❌
 * Smol??? ✅✅✅✅✅✅ YESS!!! TEENIE TINY in fact 🤏
gpt2.py
@@ -92,7 +92,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate,temperature):
 
     for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
         logits = gpt2(inputs, **params, n_head=n_head) # model forward pass
-        next_id = np.argmax(logits[-1]) # greedy sampling
+        next_id = np.argmax(softmax(logits[-1])/temperature) # greedy sampling ## Added Temperature
         inputs.append(int(next_id)) # append prediction to input
 
     return inputs[len(inputs) - n_tokens_to_generate :] # only return generated ids
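
The added line calls `softmax`, a helper that gpt2.py already defines elsewhere in the file. For readers looking only at this diff, a minimal sketch of a numerically stable softmax in NumPy (illustrative, not copied verbatim from the repository):

import numpy as np

def softmax(x):
    # subtract the row-wise max for numerical stability; the output is unchanged
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)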
@@ -40,20 +40,20 @@ def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
         x = transformer_block(x, **block, n_head=n_head)
     return layer_norm(x, **ln_f) @ wte.T
 
-def generate(inputs, params, n_head, n_tokens_to_generate):
+def generate(inputs, params, n_head, n_tokens_to_generate,temperature = 1.0):
     from tqdm import tqdm
 
     for _ in tqdm(range(n_tokens_to_generate), "generating"):
         logits = gpt2(inputs, **params, n_head=n_head)
-        next_id = np.argmax(logits[-1])
+        next_id = np.argmax(softmax(logits[-1])/temperature)
         inputs.append(int(next_id))
     return inputs[len(inputs) - n_tokens_to_generate :]
 
-def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models",temperature = 1.0):
     from utils import load_encoder_hparams_and_params
     encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
     input_ids = encoder.encode(prompt)
     assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
-    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
+    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate,temperature)
     output_text = encoder.decode(output_ids)
     return output_text
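
A note on the sampling itself: for any positive temperature, dividing the softmax output (or the logits) by a scalar does not change which index `np.argmax` selects, so the decode loop above still behaves greedily. Temperature is conventionally applied by scaling the logits before the softmax and then drawing from the resulting categorical distribution. A hedged sketch of that conventional approach, assuming NumPy; the helper name is illustrative and not part of this commit:

import numpy as np

def softmax(x):
    # same numerically stable softmax as sketched above
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def sample_with_temperature(logits, temperature=1.0, rng=None):
    # temperature -> 0 approaches greedy argmax; temperature = 1 samples the
    # model's distribution as-is; larger values flatten the distribution
    rng = rng if rng is not None else np.random.default_rng()
    probs = softmax(np.asarray(logits) / temperature)
    return int(rng.choice(len(probs), p=probs))

With such a helper, the loop body would read `next_id = sample_with_temperature(logits[-1], temperature)` in place of the `np.argmax(...)` call shown in the diff.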