text-gen-bot/textgen.py

50 lines
1.3 KiB
Python

from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU
from aitextgen import aitextgen
import json
import sys
# Load runtime settings from config.json (resolved against the current working
# directory). The "file" key names the plain-text corpus that train_ai() uses.
# The encoding is explicit so decoding does not depend on the OS locale
# (e.g. cp1252 on Windows would mangle non-ASCII paths written as UTF-8).
with open('config.json', 'r', encoding='utf-8') as file:
    json_object = json.load(file)
try:
    file_name = json_object['file']
except KeyError:
    raise KeyError(
        "config.json is missing the required 'file' key "
        "(path to the training text file)"
    ) from None
# ? generate message using trained model
def generate_message(prompt=None):
    """Generate one message from the trained model and print it to stdout.

    The model is loaded from the ``trained_model`` folder, and the tokenizer
    from ``aitextgen.tokenizer.json`` (both relative to the current working
    directory). The printed text is what the parent JS process reads.

    Args:
        prompt: Optional text to seed generation with. ``None`` or ``""``
            means unconditioned generation.

    Returns:
        The generated text, which is also printed to stdout.
    """
    ai = aitextgen(model_folder="trained_model",
                   tokenizer_file="aitextgen.tokenizer.json")
    # The prompt is a parameter of generate(), not of the aitextgen
    # constructor. Passing it to the constructor meant it was silently
    # ignored. generate()'s own default prompt is "", so None maps to that.
    # return_as_list=True returns the raw text instead of letting aitextgen
    # print it. Its console printing bolds the prompt with ANSI escape codes,
    # which would leak into the stdout stream consumed by the parent process.
    texts = ai.generate(n=1, prompt=prompt or "", return_as_list=True)
    message = texts[0]
    print(message)
    return message
# ? train model using text file
def train_ai(*, source_file=None, batch_size=8, num_steps=50000,
             generate_every=5000, save_every=5000, block_size=64):
    """Train a new CPU-sized GPT-2 model and tokenizer on a plain-text corpus.

    Steps:
    1. Train a tokenizer on the corpus.
    2. Build a model from ``GPT2ConfigCPU``.
    3. Tokenize the corpus into a ``TokenDataset``.
    4. Run ``ai.train``.

    Every keyword argument defaults to the value that used to be hard-coded
    here, so a bare ``train_ai()`` call behaves exactly as before.

    Args:
        source_file: Path of the training text. Defaults to the module-level
            ``file_name`` read from config.json.
        batch_size: Forwarded to ``ai.train``.
        num_steps: Total training steps, forwarded to ``ai.train``.
        generate_every: Forwarded to ``ai.train``; sample output every N steps.
        save_every: Forwarded to ``ai.train``; checkpoint every N steps.
        block_size: Forwarded to ``TokenDataset``; tokens per training sample.
    """
    if source_file is None:
        source_file = file_name
    # Presumably the file train_tokenizer() writes by default (aitextgen's
    # "aitextgen" prefix). It is also the path generate_message() loads.
    # NOTE(review): confirm if train_tokenizer's defaults are ever changed.
    tokenizer_file = "aitextgen.tokenizer.json"
    train_tokenizer(source_file)

    config = GPT2ConfigCPU()
    ai = aitextgen(tokenizer_file=tokenizer_file, config=config)
    data = TokenDataset(
        source_file, tokenizer_file=tokenizer_file, block_size=block_size)
    ai.train(data, batch_size=batch_size, num_steps=num_steps,
             generate_every=generate_every, save_every=save_every)
match sys.argv[1]:
case "generate":
# ? send message to parent JS process
try:
prompt = ' '.join(map(str, sys.argv[2:]))
generate_message(prompt)
except IndexError:
generate_message()
finally:
sys.stdout.flush()
case "train":
train_ai()
case _:
raise NameError("Argument not provided.")