Merge pull request #2 from M-Daniyal-123/daniyal_branch

Daniyal branch
M-Daniyal-123 2023-02-24 19:27:08 +05:00 committed by GitHub
commit e4f1d30c99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 6 deletions

@@ -14,7 +14,7 @@ picoGPT features:
 * Fast? ❌ Nah, picoGPT is megaSLOW 🐌
 * Training code? ❌ Error, 4⃣0⃣4⃣ not found
 * Batch inference? ❌ picoGPT is civilized, single file line, one at a time only
-* top-p sampling? ❌ top-k? ❌ temperature? ❌ categorical sampling?! ❌ greedy? ✅
+* top-p sampling? ❌ top-k? ❌ categorical sampling?! ❌ greedy? ✅
 * Readable? `gpt2.py` ✅ `gpt2_pico.py` ❌
 * Smol??? ✅✅✅✅✅✅ YESS!!! TEENIE TINY in fact 🤏
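
The README change is the user-visible summary of this PR: `temperature` moves off the "not supported" list. For context (not part of the diff), temperature scaling divides the logits by a scalar T before the softmax, so T < 1 sharpens the distribution and T > 1 flattens it toward uniform. A minimal stand-alone sketch, with a local `softmax` written here just for illustration:

```python
import numpy as np

def softmax(x):
    # numerically stable softmax
    e = np.exp(x - np.max(x))
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])
for T in (0.5, 1.0, 2.0):
    print(T, softmax(logits / T))  # T < 1: peakier; T > 1: flatter
```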

@@ -92,7 +92,7 @@ def generate(inputs, params, n_head, n_tokens_to_generate, temperature):
 
     for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
         logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass
-        next_id = np.argmax(logits[-1])  # greedy sampling
+        next_id = np.argmax(softmax(logits[-1]) / temperature)  # greedy sampling ## Added Temperature
         inputs.append(int(next_id))  # append prediction to input
 
     return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids
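
One caveat worth flagging on the changed line: `np.argmax` is invariant under division by a positive scalar, so `np.argmax(softmax(logits[-1]) / temperature)` selects the same token as the old `np.argmax(logits[-1])`. For `temperature` to actually influence the output, the logits would have to be scaled before the softmax and the result sampled rather than argmax'd; a hedged sketch of that variant (not what this commit implements):

```python
import numpy as np

def sample_with_temperature(logits, temperature=1.0, rng=None):
    # illustrative only: scale the logits *before* the softmax,
    # then draw a categorical sample instead of taking the argmax
    if rng is None:
        rng = np.random.default_rng()
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))  # stable softmax
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))
```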

@@ -40,20 +40,20 @@ def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
         x = transformer_block(x, **block, n_head=n_head)
     return layer_norm(x, **ln_f) @ wte.T
 
-def generate(inputs, params, n_head, n_tokens_to_generate):
+def generate(inputs, params, n_head, n_tokens_to_generate, temperature=1.0):
     from tqdm import tqdm
     for _ in tqdm(range(n_tokens_to_generate), "generating"):
         logits = gpt2(inputs, **params, n_head=n_head)
-        next_id = np.argmax(logits[-1])
+        next_id = np.argmax(softmax(logits[-1]) / temperature)
         inputs.append(int(next_id))
     return inputs[len(inputs) - n_tokens_to_generate :]
 
-def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models", temperature=1.0):
     from utils import load_encoder_hparams_and_params
     encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
     input_ids = encoder.encode(prompt)
     assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
-    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
+    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate, temperature)
     output_text = encoder.decode(output_ids)
     return output_text
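
Since `temperature` defaults to `1.0` in both `generate` and `main`, existing callers keep the old behavior. Assuming the file retains its upstream `fire.Fire(main)` entry point (present in the repo but outside this hunk) and the 124M weights have been downloaded into `models/` as the upstream README describes, the new parameter could be exercised like this; the prompt text and token count here are arbitrary:

```python
# hypothetical usage sketch against the patched gpt2_pico.py
from gpt2_pico import main

print(main("Alan Turing theorized that computers would one day become",
           n_tokens_to_generate=8, temperature=0.7))
```

Note that, per the caveat above, varying `temperature` will not change this output as committed, because the argmax is taken after the softmax is divided by it.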