import numpy as np


def gelu(x):
    # GELU activation, tanh approximation (as used by GPT-2)
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))


def softmax(x):
    # numerically stable softmax over the last axis
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)


def layer_norm(x, g, b, eps: float = 1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    x = (x - mean) / np.sqrt(variance + eps)  # normalize x to have mean=0 and var=1 over last axis
    return g * x + b  # scale and offset with gamma/beta params
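

# Illustrative sketch: `_layer_norm_demo` is a hypothetical helper (not part of
# the model) that checks on random data that layer_norm yields ~zero mean and
# ~unit variance over the last axis before the learned scale/offset are applied.
def _layer_norm_demo(n_seq=3, n_embd=8):
    x = np.random.randn(n_seq, n_embd)
    out = layer_norm(x, g=np.ones(n_embd), b=np.zeros(n_embd))  # identity gamma/beta
    assert np.allclose(out.mean(axis=-1), 0.0, atol=1e-6)
    assert np.allclose(out.var(axis=-1), 1.0, atol=1e-3)
    return out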


def linear(x, w, b):  # [m, in], [in, out], [out] -> [m, out]
    return x @ w + b


def ffn(x, c_fc, c_proj):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # project up
    a = gelu(linear(x, **c_fc))  # [n_seq, n_embd] -> [n_seq, 4*n_embd]

    # project back down
    x = linear(a, **c_proj)  # [n_seq, 4*n_embd] -> [n_seq, n_embd]

    return x


def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
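

# Illustrative sketch: `_attention_demo` is a hypothetical helper (not part of the
# model) that runs scaled dot-product attention on random data. With the causal
# mask, position i can only attend to positions <= i, so the first output row
# equals attention computed over just the first token.
def _attention_demo(n_seq=4, d_k=8, d_v=8):
    q, k, v = np.random.randn(n_seq, d_k), np.random.randn(n_seq, d_k), np.random.randn(n_seq, d_v)
    causal_mask = (1 - np.tri(n_seq)) * -1e10  # [n_seq, n_seq]
    out = attention(q, k, v, causal_mask)  # [n_seq, d_v]
    assert np.allclose(out[0], attention(q[:1], k[:1], v[:1], np.zeros((1, 1))))
    return out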


def mha(x, c_attn, c_proj, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # qkv projection
    x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]

    # split into qkv
    qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]

    if kvcache:
        # when decoding with a cache, x is a single new token: append its new
        # key/value to the cached keys/values of all previous positions
        q, k, v = qkv
        old_k, old_v = kvcache
        k = np.vstack([old_k, k])
        v = np.vstack([old_v, v])
        qkv = [q, k, v]

    current_cache = [qkv[1], qkv[2]]

    # split into heads
    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]

    # causal mask to hide future inputs from being attended to
    if kvcache:
        # the single new query may attend to every cached position, so no masking is needed
        causal_mask = np.zeros((1, k.shape[0]))
    else:
        causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]

    # perform attention over each head
    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]

    # merge heads
    x = np.hstack(out_heads)  # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]

    # out projection
    x = linear(x, **c_proj)  # [n_seq, n_embd] -> [n_seq, n_embd]

    return x, current_cache
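

# Illustrative sketch: `_kvcache_demo` is a hypothetical helper (not part of the
# model) that runs `mha` with random weights to show how the kv cache is built on
# the full prompt pass and then grown by one row per decoded token.
def _kvcache_demo(n_seq=5, n_embd=12, n_head=3):
    rng = np.random.default_rng(0)
    c_attn = {"w": rng.normal(size=(n_embd, 3 * n_embd)), "b": np.zeros(3 * n_embd)}
    c_proj = {"w": rng.normal(size=(n_embd, n_embd)), "b": np.zeros(n_embd)}

    # prompt pass: no cache yet, keys/values for all n_seq positions get cached
    x = rng.normal(size=(n_seq, n_embd))
    out, cache = mha(x, c_attn, c_proj, n_head=n_head, kvcache=None)
    assert cache[0].shape == (n_seq, n_embd)

    # decode pass: a single new token reuses the cache, which grows by one row
    x_new = rng.normal(size=(1, n_embd))
    out, cache = mha(x_new, c_attn, c_proj, n_head=n_head, kvcache=cache)
    assert out.shape == (1, n_embd) and cache[0].shape == (n_seq + 1, n_embd)
    return out, cache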


def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kvcache):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # multi-head causal self attention
    attn_out, kvcache_updated = mha(layer_norm(x, **ln_1), **attn, n_head=n_head, kvcache=kvcache)
    x = x + attn_out  # [n_seq, n_embd] -> [n_seq, n_embd]

    # position-wise feed forward network
    x = x + ffn(layer_norm(x, **ln_2), **mlp)  # [n_seq, n_embd] -> [n_seq, n_embd]

    return x, kvcache_updated


def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, kvcache):  # [n_seq] -> [n_seq, n_vocab]
    if not kvcache:
        kvcache = [None] * len(blocks)
        wpe_out = wpe[range(len(inputs))]
    else:
        # with a cache, only the newest token needs to be fed through the blocks
        wpe_out = wpe[[len(inputs) - 1]]
        inputs = [inputs[-1]]

    # token + positional embeddings
    x = wte[inputs] + wpe_out  # [n_seq] -> [n_seq, n_embd]

    # forward pass through n_layer transformer blocks
    new_kvcache = []
    for block, kvcache_block in zip(blocks, kvcache):
        x, updated_cache = transformer_block(x, **block, n_head=n_head, kvcache=kvcache_block)  # [n_seq, n_embd] -> [n_seq, n_embd]
        new_kvcache.append(updated_cache)  # TODO: inplace extend new cache instead of re-saving whole

    # projection to vocab
    x = layer_norm(x, **ln_f)  # [n_seq, n_embd] -> [n_seq, n_embd]

    return x @ wte.T, new_kvcache  # [n_seq, n_embd] -> [n_seq, n_vocab]
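

# Illustrative sketch: `_gpt2_demo` is a hypothetical smoke test (not part of the
# model) using randomly initialized weights. It shows that the first pass returns
# logits for every prompt position, while a cached decode step feeds only the last
# token and so returns logits of shape [1, n_vocab]. The sizes below are made up.
def _gpt2_demo(n_vocab=50, n_ctx=16, n_embd=12, n_head=3, n_layer=2):
    rng = np.random.default_rng(0)
    dense = lambda n_in, n_out: {"w": rng.normal(size=(n_in, n_out)), "b": np.zeros(n_out)}
    ln = lambda: {"g": np.ones(n_embd), "b": np.zeros(n_embd)}
    block = lambda: {
        "mlp": {"c_fc": dense(n_embd, 4 * n_embd), "c_proj": dense(4 * n_embd, n_embd)},
        "attn": {"c_attn": dense(n_embd, 3 * n_embd), "c_proj": dense(n_embd, n_embd)},
        "ln_1": ln(),
        "ln_2": ln(),
    }
    params = {
        "wte": rng.normal(size=(n_vocab, n_embd)),
        "wpe": rng.normal(size=(n_ctx, n_embd)),
        "blocks": [block() for _ in range(n_layer)],
        "ln_f": ln(),
    }

    inputs = [1, 2, 3]
    logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=None)
    assert logits.shape == (len(inputs), n_vocab)

    logits, kvcache = gpt2(inputs + [4], **params, n_head=n_head, kvcache=kvcache)
    assert logits.shape == (1, n_vocab)
    return logits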


def generate(inputs, params, n_head, n_tokens_to_generate):
    from tqdm import tqdm

    kvcache = None
    for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
        logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=kvcache)  # model forward pass
        next_id = np.argmax(logits[-1])  # greedy sampling
        inputs = np.append(inputs, [next_id])  # append prediction to input

    return list(inputs[len(inputs) - n_tokens_to_generate:])  # only return generated ids


def main(prompt: str = "Alan Turing theorized that computers would one day become", n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
    from utils import load_encoder_hparams_and_params

    # load encoder, hparams, and params from the released open-ai gpt-2 files
    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)

    # encode the input string using the BPE tokenizer
    input_ids = encoder.encode(prompt)

    # make sure we are not surpassing the max sequence length of our model
    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]

    # generate output ids
    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)

    # decode the ids back into a string
    output_text = encoder.decode(output_ids)

    return output_text


if __name__ == "__main__":
    import fire

    fire.Fire(main)