Initial commit.

pull/4/head
jaymody 2023-01-21 22:07:30 +01:00
commit bf118a3660
6 changed files with 576 additions and 0 deletions

236
.gitignore vendored 100644

@@ -0,0 +1,236 @@
### Project ###
/models/
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a Python script from a template
# before PyInstaller builds the exe, so as to inject date/other info into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if you have platform-specific dependencies or
# dependencies lacking cross-platform support, pipenv may install dependencies that
# don't work, or fail to install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# The JetBrains-specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### Linux ###
*~
# temporary files which can be created if a process still has an open handle to a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk

43
README.md 100644

@@ -0,0 +1,43 @@
# PicoGPT
You've seen [openai/gpt-2](https://github.com/openai/gpt-2).
You've seen [karpathy/minGPT](https://github.com/karpathy/mingpt).
You've even seen [karpathy/nanoGPT](https://github.com/karpathy/nanogpt)!
But have you seen [picoGPT](https://github.com/jaymody/picoGPT)??!?
`picoGPT` is an unnecessarily tiny and minimal implementation of [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) in plain [NumPy](https://numpy.org). The entire forward pass is [40 lines of code](https://github.com/jaymody/picoGPT/blob/main/model.py#L3-L41).
So, is picoGPT:
* Fast? ❌ Nah, picoGPT is megaSLOW 🐌
* Commented? ❌ You joking? That would add more lines 😤😤😤😤
* Training code? ❌ Error, 4⃣0⃣4⃣ not found
* Batch inference? ❌ picoGPT is civilized, one at a time only, single file line
* Support top-p? ❌ top-k? ❌ temperature? ❌ categorical sampling?! ❌ greedy? ✅
* Readable? 🤔 Well actually, it's not too horrible on the eyes, and it's minimal!
* Smol??? ✅✅✅✅✅✅ YESS!!! TEENIE TINY in fact 🤏
#### Dependencies
```bash
pip install -r requirements.txt
```
If you're using an M1 MacBook, you'll need to replace `tensorflow` with `tensorflow-macos`.
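One way to do the swap (a sketch; assumes a `tensorflow-macos` build matching the pinned version is available on PyPI):
```bash
pip install numpy==1.24.1 regex==2017.4.5 requests==2.27.1 tqdm==4.64.0
pip install tensorflow-macos==2.11.0  # in place of the pinned tensorflow
```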
Tested on `Python 3.9.10`.
#### Usage
```bash
python main.py \
--prompt "Alan Turing theorized that computers would one day become" \
--model_size "124M" \
--n_tokens_to_generate 40
```
Which generates:
```text
the most powerful machines on the planet.
The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.
```
Use `python main.py --help` for a full list of options.

120
encoder.py 100644

@@ -0,0 +1,120 @@
"""Byte pair encoding utilities.
Copied from: https://github.com/openai/gpt-2/blob/master/src/encoder.py.
"""
import json
import os
from functools import lru_cache
import regex as re
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
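# For illustration: printable bytes map to themselves (byte 65 -> "A"), while the
# remaining bytes get shifted past 255, e.g. the space byte 32 maps to chr(288),
# i.e. "Ġ", which is why GPT-2 vocab entries mark a leading space with "Ġ".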
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges, errors="replace"):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
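        # pat splits text into contractions ('s, 't, ...), runs of letters, runs
        # of digits, runs of other symbols, and whitespace, each optionally
        # preceded by a single space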
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
                except ValueError:  # `first` does not occur in the rest of the word
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def get_encoder(model_name, models_dir):
with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f:
encoder = json.load(f)
with open(os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
return Encoder(encoder=encoder, bpe_merges=bpe_merges)
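A minimal usage sketch for the encoder (assumes the `124M` files are already in `models/`, e.g. after a first run of `main.py`):
```python
from encoder import get_encoder

enc = get_encoder("124M", "models")
ids = enc.encode("Alan Turing")  # text -> BPE token ids
assert enc.decode(ids) == "Alan Turing"  # decode round-trips losslessly
```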

131
main.py 100644

@@ -0,0 +1,131 @@
import argparse
import json
import os
import re
import numpy as np
import requests
import tensorflow as tf
from tqdm import tqdm
from encoder import get_encoder
from model import gpt2
def download_gpt2_files(model_size, model_dir):
assert model_size in ["124M", "355M", "774M", "1558M"]
for filename in [
"checkpoint",
"encoder.json",
"hparams.json",
"model.ckpt.data-00000-of-00001",
"model.ckpt.index",
"model.ckpt.meta",
"vocab.bpe",
]:
        url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
        r = requests.get(f"{url}/{model_size}/{filename}", stream=True)
r.raise_for_status()
with open(os.path.join(model_dir, filename), "wb") as f:
file_size = int(r.headers["content-length"])
chunk_size = 1000
with tqdm(
ncols=100,
desc="Fetching " + filename,
total=file_size,
unit_scale=True,
) as pbar:
# 1k for chunk_size, since Ethernet packet size is around 1500 bytes
for chunk in r.iter_content(chunk_size=chunk_size):
f.write(chunk)
                    pbar.update(len(chunk))  # the last chunk may be shorter than chunk_size
def load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams):
def set_in_nested_dict(d, keys, val):
if not keys:
return val
if keys[0] not in d:
d[keys[0]] = {}
d[keys[0]] = set_in_nested_dict(d[keys[0]], keys[1:], val)
return d
init_vars = tf.train.list_variables(tf_ckpt_path)
params = {"blocks": [{} for _ in range(hparams["n_layer"])]}
for name, _ in init_vars:
array = np.squeeze(tf.train.load_variable(tf_ckpt_path, name))
name = name.removeprefix("model/")
if name.startswith("h"):
m = re.match(r"h([0-9]+)/(.*)", name)
n = int(m[1])
sub_name = m[2]
set_in_nested_dict(params["blocks"][n], sub_name.split("/"), array)
else:
set_in_nested_dict(params, name.split("/"), array)
return params
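# For reference, the returned structure mirrors the keyword arguments that
# model.py expects, roughly:
#   {"wte": ..., "wpe": ..., "ln_f": {"g": ..., "b": ...},
#    "blocks": [{"attn": ..., "mlp": ..., "ln_1": ..., "ln_2": ...}, ...]}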
def generate(ids, params, h, n_tokens_to_generate):
max_seq_len = params["wpe"].shape[0]
assert len(ids) + n_tokens_to_generate < max_seq_len
for _ in tqdm(range(n_tokens_to_generate), "generating"):
logits = gpt2(ids, **params, h=h)
        next_id = np.argmax(logits[-1])  # greedy decoding: pick the most likely next token
ids = np.append(ids, [next_id])
return list(ids[len(ids) - n_tokens_to_generate :])
def main(prompt, models_dir, model_size, n_tokens_to_generate):
assert model_size in ["124M", "355M", "774M", "1558M"]
model_dir = os.path.join(models_dir, model_size)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
        download_gpt2_files(model_size, model_dir)
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
with open(os.path.join(model_dir, "hparams.json")) as file:
hparams = json.load(file)
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)
encoder = get_encoder(model_size, models_dir)
input_ids = [encoder.encoder["<|endoftext|>"]] if prompt is None else encoder.encode(prompt)
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
output_text = encoder.decode(output_ids)
return output_text
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate text with GPT-2.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--prompt",
type=str,
help="Input text to condition the outputs. If not set, we'll generate unconditioned (i.e. start with <|endoftext|> token).",
default=None,
)
parser.add_argument(
"--models_dir",
type=str,
default="models",
help="Base directory for the model directories.",
)
parser.add_argument(
"--model_size",
type=str,
default="124M",
help="Model size. Must be one of ['124M', '355M', '774M', '1558M']",
)
parser.add_argument(
"--n_tokens_to_generate",
type=int,
default=40,
help="Number of tokens to generate.",
)
args = parser.parse_args()
    print(main(**vars(args)))

41
model.py 100644

@@ -0,0 +1,41 @@
import numpy as np
def gelu(x):
return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
def softmax(x, axis=-1):
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def layer_norm(x, g, b, eps: float = 1e-5):
mean = np.mean(x, axis=-1, keepdims=True)
variance = np.var(x, axis=-1, keepdims=True)
return g * (x - mean) / np.sqrt(variance + eps) + b
def linear(x, w, b):
return x @ w + b
def ffn(x, c_fc, c_proj):
return linear(gelu(linear(x, **c_fc)), **c_proj)
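# scaled dot-product attention: softmax(q @ k.T / sqrt(d_k) + mask) @ v, where the
# additive mask pushes masked-out scores toward -inf so softmax zeroes them out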
def attention(q, k, v, mask):
return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
def mha(x, c_attn, c_proj, h):
x = linear(x, **c_attn)
    qkv = list(map(lambda t: np.split(t, h, axis=-1), np.split(x, 3, axis=-1)))  # [q_heads, k_heads, v_heads]
    causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # large negatives above the diagonal block attention to future positions
    heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv)]
return linear(np.hstack(heads), **c_proj)
def block(x, mlp, attn, ln_1, ln_2, h):
x = x + mha(layer_norm(x, **ln_1), **attn, h=h)
x = x + ffn(layer_norm(x, **ln_2), **mlp)
return x
def gpt2(ids, wte, wpe, blocks, ln_f, h):
x = wte[ids] + wpe[np.arange(len(ids))]
for block_params in blocks:
x = block(x, **block_params, h=h)
x = layer_norm(x, **ln_f)
return x @ wte.T
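For a quick sanity check of the forward pass, here's a hypothetical smoke test with random toy weights (the dict layout follows the parameters that `main.py` loads from the checkpoint; the hyperparameter values are made up):
```python
import numpy as np
from model import gpt2

# toy hyperparameters; n_embd must be divisible by n_head
n_vocab, n_ctx, n_embd, n_head, n_layer = 50, 16, 8, 2, 2
rng = np.random.default_rng(0)
w = lambda *shape: 0.02 * rng.standard_normal(shape)
ln = lambda: {"g": np.ones(n_embd), "b": np.zeros(n_embd)}

def make_block():
    return {
        "attn": {"c_attn": {"w": w(n_embd, 3 * n_embd), "b": w(3 * n_embd)},
                 "c_proj": {"w": w(n_embd, n_embd), "b": w(n_embd)}},
        "mlp": {"c_fc": {"w": w(n_embd, 4 * n_embd), "b": w(4 * n_embd)},
                "c_proj": {"w": w(4 * n_embd, n_embd), "b": w(n_embd)}},
        "ln_1": ln(), "ln_2": ln(),
    }

params = {"wte": w(n_vocab, n_embd), "wpe": w(n_ctx, n_embd),
          "blocks": [make_block() for _ in range(n_layer)], "ln_f": ln()}
logits = gpt2([1, 2, 3], **params, h=n_head)
print(logits.shape)  # (3, 50): one row of vocab logits per input position
```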

5
requirements.txt 100644

@@ -0,0 +1,5 @@
numpy==1.24.1
regex==2017.4.5 # used by bpe encoder
requests==2.27.1 # used to download gpt-2 files
tensorflow==2.11.0 # used to load the gpt-2 weights from the tf checkpoint into numpy
tqdm==4.64.0 # progress bar to keep your sanity