Mirror of https://github.com/jaymody/picoGPT
Refactor.
parent 4cd64933bb
commit 29e78cc52b
README.md

@@ -7,15 +7,14 @@ You've even seen [karpathy/nanoGPT](https://github.com/karpathy/nanogpt)!
But have you seen [picoGPT](https://github.com/jaymody/picoGPT)??!?

-`picoGPT` is an unnecessarily tiny and minimal implementation of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) in plain [NumPy](https://numpy.org). The entire forward pass code is [40 lines of code](https://github.com/jaymody/picoGPT/blob/main/model.py#L3-L41).
+`picoGPT` is an unnecessarily tiny and minimal implementation of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) in plain [NumPy](https://numpy.org). The entire forward pass code is [40 lines of code](https://github.com/jaymody/picoGPT/blob/main/gpt2_pico.py#L3-L41).

So is picoGPT:

* Fast? ❌ Nah, picoGPT is megaSLOW 🐌
* Commented? ❌ You joking? That would add more lines 😤😤😤😤
* Training code? ❌ Error, 4️⃣0️⃣4️⃣ not found
* Batch inference? ❌ picoGPT is civilized, one at a time only, single file line
* Support top-p? ❌ top-k? ❌ temperature? ❌ categorical sampling?! ❌ greedy? ✅ (a rough sketch of what those could look like follows this list)
-* Readable? 🤔 Well actually, it's not too horrible on the eyes, and it's minimal!
+* Readable? 🤔 Well actually, it's not too horrible on the eyes!
* Smol??? ✅✅✅✅✅✅ YESS!!! TEENIE TINY in fact 🤏
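If you're curious, here's a rough, hypothetical sketch (not part of picoGPT; `sample_next_id` is a made-up name) of what temperature, top-k, and categorical sampling could look like if you swapped out the greedy `np.argmax(logits[-1])` in `generate`:

```python
import numpy as np

def sample_next_id(logits, temperature=1.0, top_k=None):
    # hypothetical drop-in for the greedy argmax: temperature + top-k + categorical sampling
    logits = logits / temperature  # <1 sharpens the distribution, >1 flattens it
    if top_k is not None:
        kth_largest = np.sort(logits)[-top_k]  # k-th largest logit
        logits = np.where(logits < kth_largest, -1e10, logits)  # mask everything below it
    probs = np.exp(logits - np.max(logits))  # softmax over the (masked) logits
    probs = probs / probs.sum()
    return np.random.choice(len(probs), p=probs)  # sample instead of argmax
```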
#### Dependencies

@@ -29,15 +28,25 @@ Tested on `Python 3.9.10`.
#### Usage

```bash
-python main.py \
-    --prompt "Alan Turing theorized that computers would one day become" \
-    --model_size "124M" \
-    --n_tokens_to_generate 40
+python gpt2.py "Alan Turing theorized that computers would one day become"
```

Which generates:

```text
the most powerful machines on the planet.

The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.
```

-Use `python main.py --help` for a full list of options.
+You can also control the number of tokens to generate, the model size (one of `["124M", "355M", "774M", "1558M"]`), and the directory to save the models:
+
+```bash
+python gpt2.py \
+    "Alan Turing theorized that computers would one day become" \
+    --n_tokens_to_generate 40 \
+    --model_size "124M" \
+    --models_dir "models"
+```
+`gpt2_pico.py` is the same as `gpt2.py`, but has even fewer lines of code (removed comments, extra whitespace, and combined certain operations into a single line). Why? Because why not.
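For instance, `layer_norm` goes from the commented, step-by-step version in `gpt2.py` to a single return statement in `gpt2_pico.py` (both copied from the files below; the math is identical):

```python
import numpy as np

# gpt2.py version
def layer_norm(x, g, b, eps: float = 1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    x = (x - mean) / np.sqrt(variance + eps)  # normalize x to have mean=0 and var=1 over last axis
    return g * x + b  # scale and offset with gamma/beta params

# gpt2_pico.py version (shadows the one above if run in the same file)
def layer_norm(x, g, b, eps: float = 1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return g * (x - mean) / np.sqrt(variance + eps) + b
```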
gpt2.py

@@ -0,0 +1,121 @@
import numpy as np


def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))


def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)


def layer_norm(x, g, b, eps: float = 1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    x = (x - mean) / np.sqrt(variance + eps)  # normalize x to have mean=0 and var=1 over last axis
    return g * x + b  # scale and offset with gamma/beta params


def linear(x, w, b):  # [m, in], [in, out], [out] -> [m, out]
    return x @ w + b


def ffn(x, c_fc, c_proj):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # project up
    a = gelu(linear(x, **c_fc))  # [n_seq, n_embd] -> [n_seq, 4*n_embd]

    # project back down
    x = linear(a, **c_proj)  # [n_seq, 4*n_embd] -> [n_seq, n_embd]

    return x


def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v


def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # qkv projection
    x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]

    # split into qkv
    qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]

    # split into heads
    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]

    # causal mask to hide future inputs from being attended to
    causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]

    # perform attention over each head
    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]

    # merge heads
    x = np.hstack(out_heads)  # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]

    # out projection
    x = linear(x, **c_proj)  # [n_seq, n_embd] -> [n_seq, n_embd]

    return x


def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # multi-head causal self attention
    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]

    # position-wise feed forward network
    x = x + ffn(layer_norm(x, **ln_2), **mlp)  # [n_seq, n_embd] -> [n_seq, n_embd]

    return x


def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):  # [n_seq] -> [n_seq, n_vocab]
    # token + positional embeddings
    x = wte[inputs] + wpe[range(len(inputs))]  # [n_seq] -> [n_seq, n_embd]

    # forward pass through n_layer transformer blocks
    for block in blocks:
        x = transformer_block(x, **block, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]

    # projection to vocab
    x = layer_norm(x, **ln_f)  # [n_seq, n_embd] -> [n_seq, n_embd]
    return x @ wte.T  # [n_seq, n_embd] -> [n_seq, n_vocab]


def generate(inputs, params, n_head, n_tokens_to_generate):
    from tqdm import tqdm

    for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
        logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass
        next_id = np.argmax(logits[-1])  # greedy sampling
        inputs = np.append(inputs, [next_id])  # append prediction to input

    return list(inputs[len(inputs) - n_tokens_to_generate :])  # only return generated ids


def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
    from utils import load_encoder_hparams_and_params

    # load encoder, hparams, and params from the released open-ai gpt-2 files
    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)

    # encode the input string using the BPE tokenizer
    input_ids = encoder.encode(prompt)

    # make sure we are not surpassing the max sequence length of our model
    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]

    # generate output ids
    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)

    # decode the ids back into a string
    output_text = encoder.decode(output_ids)

    return output_text


if __name__ == "__main__":
    import fire

    fire.Fire(main)
gpt2_pico.py

@@ -0,0 +1,62 @@
import numpy as np

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def layer_norm(x, g, b, eps: float = 1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return g * (x - mean) / np.sqrt(variance + eps) + b

def linear(x, w, b):
    return x @ w + b

def ffn(x, c_fc, c_proj):
    return linear(gelu(linear(x, **c_fc)), **c_proj)

def attention(q, k, v, mask):
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

def mha(x, c_attn, c_proj, n_head):
    x = linear(x, **c_attn)
    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1)))
    causal_mask = (1 - np.tri(x.shape[0])) * -1e10
    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]
    x = linear(np.hstack(out_heads), **c_proj)
    return x

def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)
    x = x + ffn(layer_norm(x, **ln_2), **mlp)
    return x

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
    x = wte[inputs] + wpe[range(len(inputs))]
    for block in blocks:
        x = transformer_block(x, **block, n_head=n_head)
    return layer_norm(x, **ln_f) @ wte.T

def generate(inputs, params, n_head, n_tokens_to_generate):
    from tqdm import tqdm
    for _ in tqdm(range(n_tokens_to_generate), "generating"):
        logits = gpt2(inputs, **params, n_head=n_head)
        next_id = np.argmax(logits[-1])
        inputs = np.append(inputs, [next_id])
    return list(inputs[len(inputs) - n_tokens_to_generate :])

def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
    from utils import load_encoder_hparams_and_params
    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
    input_ids = encoder.encode(prompt)
    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
    output_text = encoder.decode(output_ids)
    return output_text

if __name__ == "__main__":
    import fire
    fire.Fire(main)
model.py

@@ -1,41 +0,0 @@
import numpy as np

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def softmax(x, axis=-1):
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def layer_norm(x, g, b, eps: float = 1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return g * (x - mean) / np.sqrt(variance + eps) + b

def linear(x, w, b):
    return x @ w + b

def ffn(x, c_fc, c_proj):
    return linear(gelu(linear(x, **c_fc)), **c_proj)

def attention(q, k, v, mask):
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

def mha(x, c_attn, c_proj, h):
    x = linear(x, **c_attn)
    qkv = list(map(lambda x: np.split(x, h, axis=-1), np.split(x, 3, axis=-1)))
    causal_mask = (1 - np.tri(x.shape[0])) * -1e10
    heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv)]
    return linear(np.hstack(heads), **c_proj)

def block(x, mlp, attn, ln_1, ln_2, h):
    x = x + mha(layer_norm(x, **ln_1), **attn, h=h)
    x = x + ffn(layer_norm(x, **ln_2), **mlp)
    return x

def gpt2(ids, wte, wpe, blocks, ln_f, h):
    x = wte[ids] + wpe[np.arange(len(ids))]
    for block_params in blocks:
        x = block(x, **block_params, h=h)
    x = layer_norm(x, **ln_f)
    return x @ wte.T
requirements.txt

@@ -1,5 +1,6 @@
-numpy==1.24.1
-regex==2017.4.5 # used by bpe encoder
-requests==2.27.1 # used to download gpt-2 files
-tensorflow==2.11.0 # used to load the gpt-2 weights from the tf checkpoint into numpy
+numpy==1.24.1 # used for the actual model code/weights
+regex==2017.4.5 # used by the bpe tokenizer
+requests==2.27.1 # used to download gpt-2 files from openai
+tensorflow==2.11.0 # used to load the gpt-2 weights from the open-ai tf checkpoint
+tqdm==4.64.0 # progress bar to keep your sanity
+fire==0.5.0 # easy CLI creation
utils.py

@@ -1,4 +1,3 @@
-import argparse
import json
import os
import re

@@ -9,7 +8,6 @@ import tensorflow as tf
from tqdm import tqdm

from encoder import get_encoder
-from model import gpt2


def download_gpt2_files(model_size, model_dir):

@@ -23,7 +21,7 @@ def download_gpt2_files(model_size, model_dir):
"model.ckpt.meta",
|
||||
"vocab.bpe",
|
||||
]:
|
||||
url = "https://openaipublic.blob.core.windows.net/gpt-2/models/"
|
||||
url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
|
||||
r = requests.get(f"{url}/{model_size}/{filename}", stream=True)
|
||||
r.raise_for_status()
|
||||
|
||||
|
@ -67,65 +65,18 @@ def load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams):
|
|||
    return params


-def generate(ids, params, h, n_tokens_to_generate):
-    max_seq_len = params["wpe"].shape[0]
-    assert len(ids) + n_tokens_to_generate < max_seq_len
-
-    for _ in tqdm(range(n_tokens_to_generate), "generating"):
-        logits = gpt2(ids, **params, h=h)
-        next_id = np.argmax(logits[-1])
-        ids = np.append(ids, [next_id])
-
-    return list(ids[len(ids) - n_tokens_to_generate :])
-
-
-def main(prompt, models_dir, model_size, n_tokens_to_generate):
+def load_encoder_hparams_and_params(model_size, models_dir):
    assert model_size in ["124M", "355M", "774M", "1558M"]

    model_dir = os.path.join(models_dir, model_size)
-    if not os.path.isdir(model_dir):
-        os.makedirs(model_dir)
-        download_gpt2_files(model_size, model_dir)
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
+    if not tf_ckpt_path:  # download files if necessary
+        os.makedirs(model_dir, exist_ok=True)
+        download_gpt2_files(model_size, model_dir)
+        tf_ckpt_path = tf.train.latest_checkpoint(model_dir)

-    with open(os.path.join(model_dir, "hparams.json")) as file:
-        hparams = json.load(file)
-
-    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)
    encoder = get_encoder(model_size, models_dir)
-    input_ids = [encoder.encoder["<|endoftext|>"]] if prompt is None else encoder.encode(prompt)
-    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
-    output_text = encoder.decode(output_ids)
+    hparams = json.load(open(os.path.join(model_dir, "hparams.json")))
+    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)

-    return output_text
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser("Generate text with GPT-2.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument(
-        "--prompt",
-        type=str,
-        help="Input text to condition the outputs. If not set, we'll generate unconditioned (i.e. start with <|endoftext|> token).",
-        default=None,
-    )
-    parser.add_argument(
-        "--models_dir",
-        type=str,
-        default="models",
-        help="Base directory for the model directories.",
-    )
-    parser.add_argument(
-        "--model_size",
-        type=str,
-        default="124M",
-        help="Model size. Must be one of ['124M', '355M', '774M', '1558M']",
-    )
-    parser.add_argument(
-        "--n_tokens_to_generate",
-        type=int,
-        default=40,
-        help="Number of tokens to generate.",
-    )
-    args = parser.parse_args()
-
-    print(main(**args.__dict__))
+    return encoder, hparams, params