Emergent generative agents
Rev. | 356751b59d378fbe6ea336211e0c3588cd08eccb |
---|---|
Size | 2,451 bytes |
Time | 2023-06-12 03:14:32 |
Author | Corbin |
Log Message | Add WP query access, and fix a few bugs. Most importantly, make choices on RWKV work. The same technique should |
import importlib.util
import os
import sys

import tokenizers

from gens.base import Gen, Yarn

# Monkey-patch to get rwkv available.
RWKV = "@RWKV@"
RWKV_PATH = os.path.join(RWKV, "bin")
def bare_import(path, module_name):
    # Load a module straight from a file path, bypassing the normal import
    # machinery, and register it in sys.modules so its own imports resolve.
    file_path = os.path.join(path, module_name + ".py")
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
rwkv_cpp_shared_library = bare_import(RWKV_PATH, "rwkv_cpp_shared_library")
rwkv_cpp_model = bare_import(RWKV_PATH, "rwkv_cpp_model")
sampling = bare_import(RWKV_PATH, "sampling")

TOKENIZER_PATH = os.path.join(RWKV, "share", "20B_tokenizer.json")

# Upstream recommends temp 0.7, top_p 0.5
# TEMPERATURE = 0.8
# TOP_P = 0.8
TEMPERATURE = 0.9
TOP_P = 0.9
class MawrkovGen(Gen):
    model_name = "The Pile"
    model_arch = "RWKV"

    def __init__(self, model_path, max_new_tokens):
        self.max_new_tokens = max_new_tokens
        # Rough RAM estimate: 1.5x the on-disk model size.
        self.model_size = os.stat(model_path).st_size * 3 // 2
        self.tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
        self.lib = rwkv_cpp_shared_library.load_rwkv_shared_library()
        self.model = rwkv_cpp_model.RWKVModel(self.lib, model_path)

    # XXX wrong
    def footprint(self): return self.model_size
    def contextLength(self): return 8192
    def tokenize(self, s): return self.tokenizer.encode(s).ids
    def decode(self, ts): return self.tokenizer.decode(ts)

    def fork(self):
        return MawrkovYarn(self.max_new_tokens, self.model, self.tokenizer)
class MawrkovYarn(Yarn):
    logits = state = None

    def __init__(self, max_new_tokens, model, tokenizer):
        self.max_new_tokens = max_new_tokens
        self.model = model
        self.tokenizer = tokenizer

    def feedForward(self, tokens):
        # Advance the RWKV state one token at a time, reusing the state and
        # logits tensors as both input and output buffers.
        for t in tokens:
            self.logits, self.state = self.model.eval(t, self.state,
                                                      self.state, self.logits)

    def complete(self):
        # Sample until a newline appears or the token budget runs out.
        tokens = []
        for i in range(self.max_new_tokens):
            token = sampling.sample_logits(self.logits, TEMPERATURE, TOP_P)
            if "\n" in self.tokenizer.decode([token]): break
            tokens.append(token)
            self.feedForward([token])
        return tokens

    # Pick whichever option token currently has the highest logit.
    def force(self, options): return max(options, key=self.logits.__getitem__)
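
For context, a minimal sketch of how these classes compose, assuming `gens.base` is importable and using a hypothetical model path: `fork()` produces an independent `Yarn`, feeding tokens advances its RWKV state, `complete()` samples a line of text, and `force()` decides between candidate tokens by comparing their raw logits.

```python
# Minimal usage sketch; the model path below is hypothetical.
gen = MawrkovGen("/path/to/rwkv-model.bin", max_new_tokens=64)
yarn = gen.fork()

# Feed a prompt, then sample a completion up to the next newline.
yarn.feedForward(gen.tokenize("The agent considered its options and said:"))
print(gen.decode(yarn.complete()))

# Force a choice between two candidate tokens without sampling.
yes, no = gen.tokenize("yes")[0], gen.tokenize("no")[0]
print(gen.decode([yarn.force([yes, no])]))
```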