From 29e78cc52b58ed2c1c483ffea2eb46ff6bdec785 Mon Sep 17 00:00:00 2001
From: jaymody
Date: Sun, 29 Jan 2023 00:52:25 +0100
Subject: [PATCH] Refactor.

---
 README.md           |  31 ++++++++----
 gpt2.py             | 121 ++++++++++++++++++++++++++++++++++++++++++++
 gpt2_pico.py        |  62 +++++++++++++++++++++++
 model.py            |  41 ---------------
 requirements.txt    |   9 ++--
 main.py => utils.py |  67 ++++--------------------
 6 files changed, 217 insertions(+), 114 deletions(-)
 create mode 100644 gpt2.py
 create mode 100644 gpt2_pico.py
 delete mode 100644 model.py
 rename main.py => utils.py (56%)

diff --git a/README.md b/README.md
index 9b35945..0b0ca3a 100644
--- a/README.md
+++ b/README.md
@@ -7,15 +7,14 @@ You've even seen [karpathy/nanoGPT](https://github.com/karpathy/nanogpt)!
 
 But have you seen [picoGPT](https://github.com/jaymody/picoGPT)??!?
 
-`picoGPT` is an unnecessarily tiny and minimal implementation of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) in plain [NumPy](https://numpy.org). The entire forward pass code is [40 lines of code](https://github.com/jaymody/picoGPT/blob/main/model.py#L3-L41).
+`picoGPT` is an unnecessarily tiny and minimal implementation of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) in plain [NumPy](https://numpy.org). The entire forward pass code is [40 lines of code](https://github.com/jaymody/picoGPT/blob/main/gpt2_pico.py#L3-L41).
 
 So is picoGPT:
 * Fast? ❌ Nah, picoGPT is megaSLOW 🐌
-* Commented? ❌ You joking? That would add more lines 😤😤😤😤
 * Training code? ❌ Error, 4️⃣0️⃣4️⃣ not found
 * Batch inference? ❌ picoGPT is civilized, one at a time only, single file line
 * Support top-p? ❌ top-k? ❌ temperature? ❌ categorical sampling?! ❌ greedy? ✅
-* Readable? 🤔 Well actually, it's not too horrible on the eyes, and it's minimal!
+* Readable? 🤔 Well actually, it's not too horrible on the eyes!
 * Smol??? ✅✅✅✅✅✅ YESS!!! TEENIE TINY in fact 🤏
 
 #### Dependencies
@@ -29,15 +28,25 @@ Tested on `Python 3.9.10`.
 #### Usage
 ```bash
-python main.py \
-    --prompt "Alan Turing theorized that computers would one day become" \
-    --model_size "124M" \
-    --n_tokens_to_generate 40
+python gpt2.py "Alan Turing theorized that computers would one day become"
 ```
-Which generates:
-```text
-the most powerful machines on the planet.
+
+Which generates
+
+```
+ the most powerful machines on the planet.
 
 The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.
 ```
 
-Use `python main.py --help` for a full list of options.
+
+You can also control the number of tokens to generate, the model size (one of `["124M", "355M", "774M", "1558M"]`), and the directory to save the models:
+
+```bash
+python gpt2.py \
+    "Alan Turing theorized that computers would one day become" \
+    --n_tokens_to_generate 40 \
+    --model_size "124M" \
+    --models_dir "models"
+```
+
+`gpt2_pico.py` is the same as `gpt2.py`, but has even fewer lines of code (removed comments, extra whitespace, and combined certain operations into a single line). Why? Because why not.
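The README above only documents the `fire`-based CLI. Since `gpt2.py` (added below) exposes a plain `main()` function, it should also be callable directly from Python; a minimal sketch, assuming the dependencies from `requirements.txt` are installed and the weights can be downloaded on first use:

```python
# Hypothetical usage sketch: drives the main() function defined in gpt2.py below.
from gpt2 import main

text = main(
    "Alan Turing theorized that computers would one day become",
    n_tokens_to_generate=40,  # same defaults as the CLI
    model_size="124M",        # one of "124M", "355M", "774M", "1558M"
    models_dir="models",      # where the checkpoint files are downloaded/cached
)
print(text)
```
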
diff --git a/gpt2.py b/gpt2.py
new file mode 100644
index 0000000..a4fd175
--- /dev/null
+++ b/gpt2.py
@@ -0,0 +1,121 @@
+import numpy as np
+
+
+def gelu(x):
+    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
+
+
+def softmax(x):
+    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
+    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
+
+
+def layer_norm(x, g, b, eps: float = 1e-5):
+    mean = np.mean(x, axis=-1, keepdims=True)
+    variance = np.var(x, axis=-1, keepdims=True)
+    x = (x - mean) / np.sqrt(variance + eps)  # normalize x to have mean=0 and var=1 over last axis
+    return g * x + b  # scale and offset with gamma/beta params
+
+
+def linear(x, w, b):  # [m, in], [in, out], [out] -> [m, out]
+    return x @ w + b
+
+
+def ffn(x, c_fc, c_proj):  # [n_seq, n_embd] -> [n_seq, n_embd]
+    # project up
+    a = gelu(linear(x, **c_fc))  # [n_seq, n_embd] -> [n_seq, 4*n_embd]
+
+    # project back down
+    x = linear(a, **c_proj)  # [n_seq, 4*n_embd] -> [n_seq, n_embd]
+
+    return x
+
+
+def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
+    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
+
+
+def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
+    # qkv projection
+    x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]
+
+    # split into qkv
+    qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
+
+    # split into heads
+    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
+
+    # causal mask to hide future inputs from being attended to
+    casual_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
+
+    # perform attention over each head
+    out_heads = [attention(q, k, v, casual_mask) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
+
+    # merge heads
+    x = np.hstack(out_heads)  # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]
+
+    # out projection
+    x = linear(x, **c_proj)  # [n_seq, n_embd] -> [n_seq, n_embd]
+
+    return x
+
+
+def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
+    # multi-head causal self attention
+    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]
+
+    # position-wise feed forward network
+    x = x + ffn(layer_norm(x, **ln_2), **mlp)  # [n_seq, n_embd] -> [n_seq, n_embd]
+
+    return x
+
+
+def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):  # [n_seq] -> [n_seq, n_vocab]
+    # token + positional embeddings
+    x = wte[inputs] + wpe[range(len(inputs))]  # [n_seq] -> [n_seq, n_embd]
+
+    # forward pass through n_layer transformer blocks
+    for block in blocks:
+        x = transformer_block(x, **block, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]
+
+    # projection to vocab
+    x = layer_norm(x, **ln_f)  # [n_seq, n_embd] -> [n_seq, n_embd]
+    return x @ wte.T  # [n_seq, n_embd] -> [n_seq, n_vocab]
+
+
+def generate(inputs, params, n_head, n_tokens_to_generate):
+    from tqdm import tqdm
+
+    for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
+        logits = gpt2(inputs, **params, n_head=n_head)  # model forward pass
+        next_id = np.argmax(logits[-1])  # greedy sampling
+        inputs = np.append(inputs, [next_id])  # append prediction to input
+
+    return list(inputs[len(inputs) - n_tokens_to_generate :])  # only return generated ids
+
+
+def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+    from utils import load_encoder_hparams_and_params
+
+    # load encoder, hparams, and params from the released open-ai gpt-2 files
+    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
+
+    # encode the input string using the BPE tokenizer
+    input_ids = encoder.encode(prompt)
+
+    # make sure we are not surpassing the max sequence length of our model
+    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
+
+    # generate output ids
+    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
+
+    # decode the ids back into a string
+    output_text = encoder.decode(output_ids)
+
+    return output_text
+
+
+if __name__ == "__main__":
+    import fire
+
+    fire.Fire(main)
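The trickiest part of `gpt2.py` above is the causal masking inside `mha`: adding `-1e10` above the diagonal before the softmax forces each position to attend only to itself and earlier positions. A small self-contained sketch with random toy tensors (not real GPT-2 weights) that checks this property:

```python
import numpy as np

def softmax(x):  # same softmax as in gpt2.py
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

n_seq, d_k = 4, 8
rng = np.random.default_rng(0)
q, k = rng.normal(size=(n_seq, d_k)), rng.normal(size=(n_seq, d_k))

# same construction as in mha(): 0 on/below the diagonal, -1e10 above it
mask = (1 - np.tri(n_seq)) * -1e10

weights = softmax(q @ k.T / np.sqrt(d_k) + mask)  # [n_seq, n_seq] attention weights

# every row sums to 1, and all weights above the diagonal are (numerically) zero,
# i.e. position i never attends to positions j > i
assert np.allclose(weights.sum(axis=-1), 1.0)
assert np.allclose(np.triu(weights, k=1), 0.0)
print(np.round(weights, 3))
```
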
diff --git a/gpt2_pico.py b/gpt2_pico.py
new file mode 100644
index 0000000..42cac00
--- /dev/null
+++ b/gpt2_pico.py
@@ -0,0 +1,62 @@
+import numpy as np
+
+def gelu(x):
+    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
+
+def softmax(x):
+    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
+    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
+
+def layer_norm(x, g, b, eps: float = 1e-5):
+    mean = np.mean(x, axis=-1, keepdims=True)
+    variance = np.var(x, axis=-1, keepdims=True)
+    return g * (x - mean) / np.sqrt(variance + eps) + b
+
+def linear(x, w, b):
+    return x @ w + b
+
+def ffn(x, c_fc, c_proj):
+    return linear(gelu(linear(x, **c_fc)), **c_proj)
+
+def attention(q, k, v, mask):
+    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
+
+def mha(x, c_attn, c_proj, n_head):
+    x = linear(x, **c_attn)
+    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1)))
+    casual_mask = (1 - np.tri(x.shape[0])) * -1e10
+    out_heads = [attention(q, k, v, casual_mask) for q, k, v in zip(*qkv_heads)]
+    x = linear(np.hstack(out_heads), **c_proj)
+    return x
+
+def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
+    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)
+    x = x + ffn(layer_norm(x, **ln_2), **mlp)
+    return x
+
+def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
+    x = wte[inputs] + wpe[range(len(inputs))]
+    for block in blocks:
+        x = transformer_block(x, **block, n_head=n_head)
+    return layer_norm(x, **ln_f) @ wte.T
+
+def generate(inputs, params, n_head, n_tokens_to_generate):
+    from tqdm import tqdm
+    for _ in tqdm(range(n_tokens_to_generate), "generating"):
+        logits = gpt2(inputs, **params, n_head=n_head)
+        next_id = np.argmax(logits[-1])
+        inputs = np.append(inputs, [next_id])
+    return list(inputs[len(inputs) - n_tokens_to_generate :])
+
+def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
+    from utils import load_encoder_hparams_and_params
+    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
+    input_ids = encoder.encode(prompt)
+    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
+    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
+    output_text = encoder.decode(output_ids)
+    return output_text
+
+if __name__ == "__main__":
+    import fire
+    fire.Fire(main)
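`gpt2_pico.py` is meant to be a line-for-line compression of `gpt2.py`, not a different model, so the two forward passes should agree exactly. A sketch of an equivalence check, assuming the 124M weights are already available under `models/` (they are downloaded on first use, which takes a while):

```python
import numpy as np

import gpt2
import gpt2_pico
from utils import load_encoder_hparams_and_params

encoder, hparams, params = load_encoder_hparams_and_params("124M", "models")
input_ids = encoder.encode("Alan Turing theorized that computers would one day become")

# run the same inputs/weights through both implementations
logits_a = gpt2.gpt2(input_ids, **params, n_head=hparams["n_head"])
logits_b = gpt2_pico.gpt2(input_ids, **params, n_head=hparams["n_head"])

assert np.allclose(logits_a, logits_b)  # identical math, identical logits
print(logits_a.shape)  # (len(input_ids), n_vocab)
```
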
diff --git a/model.py b/model.py
deleted file mode 100644
index 1c8d223..0000000
--- a/model.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import numpy as np
-
-def gelu(x):
-    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
-
-def softmax(x, axis=-1):
-    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
-    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
-
-def layer_norm(x, g, b, eps: float = 1e-5):
-    mean = np.mean(x, axis=-1, keepdims=True)
-    variance = np.var(x, axis=-1, keepdims=True)
-    return g * (x - mean) / np.sqrt(variance + eps) + b
-
-def linear(x, w, b):
-    return x @ w + b
-
-def ffn(x, c_fc, c_proj):
-    return linear(gelu(linear(x, **c_fc)), **c_proj)
-
-def attention(q, k, v, mask):
-    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
-
-def mha(x, c_attn, c_proj, h):
-    x = linear(x, **c_attn)
-    qkv = list(map(lambda x: np.split(x, h, axis=-1), np.split(x, 3, axis=-1)))
-    casual_mask = (1 - np.tri(x.shape[0])) * -1e10
-    heads = [attention(q, k, v, casual_mask) for q, k, v in zip(*qkv)]
-    return linear(np.hstack(heads), **c_proj)
-
-def block(x, mlp, attn, ln_1, ln_2, h):
-    x = x + mha(layer_norm(x, **ln_1), **attn, h=h)
-    x = x + ffn(layer_norm(x, **ln_2), **mlp)
-    return x
-
-def gpt2(ids, wte, wpe, blocks, ln_f, h):
-    x = wte[ids] + wpe[np.arange(len(ids))]
-    for block_params in blocks:
-        x = block(x, **block_params, h=h)
-    x = layer_norm(x, **ln_f)
-    return x @ wte.T
diff --git a/requirements.txt b/requirements.txt
index 7ce6814..ce42b28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
-numpy==1.24.1
-regex==2017.4.5 # used by bpe encoder
-requests==2.27.1 # used to download gpt-2 files
-tensorflow==2.11.0 # used to load the gpt-2 weights from the tf checkpoint into numpy
+numpy==1.24.1 # used for the actual model code/weights
+regex==2017.4.5 # used by the bpe tokenizer
+requests==2.27.1 # used to download gpt-2 files from openai
+tensorflow==2.11.0 # used to load the gpt-2 weights from the open-ai tf checkpoint
 tqdm==4.64.0 # progress bar to keep your sanity
+fire==0.5.0 # easy CLI creation
diff --git a/main.py b/utils.py
similarity index 56%
rename from main.py
rename to utils.py
index ea521b6..f459079 100644
--- a/main.py
+++ b/utils.py
@@ -1,4 +1,3 @@
-import argparse
 import json
 import os
 import re
@@ -9,7 +8,6 @@ import tensorflow as tf
 from tqdm import tqdm
 
 from encoder import get_encoder
-from model import gpt2
 
 
 def download_gpt2_files(model_size, model_dir):
@@ -23,7 +21,7 @@
         "model.ckpt.meta",
         "vocab.bpe",
     ]:
-        url = "https://openaipublic.blob.core.windows.net/gpt-2/models/"
+        url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
         r = requests.get(f"{url}/{model_size}/{filename}", stream=True)
         r.raise_for_status()
 
@@ -67,65 +65,18 @@
     return params
 
 
-def generate(ids, params, h, n_tokens_to_generate):
-    max_seq_len = params["wpe"].shape[0]
-    assert len(ids) + n_tokens_to_generate < max_seq_len
-
-    for _ in tqdm(range(n_tokens_to_generate), "generating"):
-        logits = gpt2(ids, **params, h=h)
-        next_id = np.argmax(logits[-1])
-        ids = np.append(ids, [next_id])
-
-    return list(ids[len(ids) - n_tokens_to_generate :])
-
-
-def main(prompt, models_dir, model_size, n_tokens_to_generate):
+def load_encoder_hparams_and_params(model_size, models_dir):
     assert model_size in ["124M", "355M", "774M", "1558M"]
 
     model_dir = os.path.join(models_dir, model_size)
-    if not os.path.isdir(model_dir):
-        os.makedirs(model_dir)
-        download_gpt2_files(model_size, model_dir)
     tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
+    if not tf_ckpt_path:  # download files if necessary
+        os.makedirs(model_dir, exist_ok=True)
+        download_gpt2_files(model_size, model_dir)
+        tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
 
-    with open(os.path.join(model_dir, "hparams.json")) as file:
-        hparams = json.load(file)
-
-    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)
     encoder = get_encoder(model_size, models_dir)
-    input_ids = [encoder.encoder["<|endoftext|>"]] if prompt is None else encoder.encode(prompt)
-    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
-    output_text = encoder.decode(output_ids)
+    hparams = json.load(open(os.path.join(model_dir, "hparams.json")))
+    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)
 
-    return output_text
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser("Generate text with GPT-2.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument(
-        "--prompt",
-        type=str,
-        help="Input text to condition the outputs. If not set, we'll generate unconditioned (i.e. start with <|endoftext|> token).",
-        default=None,
-    )
-    parser.add_argument(
-        "--models_dir",
-        type=str,
-        default="models",
-        help="Base directory for the model directories.",
-    )
-    parser.add_argument(
-        "--model_size",
-        type=str,
-        default="124M",
-        help="Model size. Must be one of ['124M', '355M', '774M', '1558M']",
-    )
-    parser.add_argument(
-        "--n_tokens_to_generate",
-        type=int,
-        default=40,
-        help="Number of tokens to generate.",
-    )
-    args = parser.parse_args()
-
-    print(main(**args.__dict__))
+    return encoder, hparams, params
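To make the rename concrete: `load_encoder_hparams_and_params` now returns everything `gpt2.py` needs in one call. A quick interactive sketch (assuming the `models/124M` files exist or can be downloaded) of what comes back: `encoder` is the BPE tokenizer behind `get_encoder`, `hparams` is the dict read from `hparams.json`, and `params` is a nested dict of NumPy arrays shaped the way `gpt2(**params, ...)` consumes them.

```python
from utils import load_encoder_hparams_and_params

encoder, hparams, params = load_encoder_hparams_and_params("124M", "models")

ids = encoder.encode("Alan Turing theorized that computers would one day become")
print(ids)                    # BPE token ids
print(encoder.decode(ids))    # round-trips back to the original string

print(hparams["n_ctx"], hparams["n_head"])  # the two fields gpt2.py actually uses
print(sorted(params.keys()))  # ['blocks', 'ln_f', 'wpe', 'wte'] -- matches gpt2()'s signature
print(params["wte"].shape)    # [n_vocab, n_embd] token embedding matrix
print(len(params["blocks"]))  # one parameter dict per transformer block
```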