text-gen-bot/textgen.py

26 lines
860 B
Python

from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU
from aitextgen import aitextgen
import json
# Read the settings once, when the module is loaded. The "file" entry is the
# path handed to train_tokenizer() and TokenDataset() in train_ai().
# The encoding is pinned to UTF-8 so that decoding the JSON does not depend on
# the platform's default locale encoding.
with open('config.json', 'r', encoding='utf-8') as file:
    json_object = json.load(file)
file_name = json_object['file']
def generate_message():
    """Generate one text with the saved model and return it.

    Loads the model from the ``trained_model`` folder together with the
    ``aitextgen.tokenizer.json`` tokenizer file and produces a single text.

    Returns:
        str: The generated text.
    """
    ai = aitextgen(
        model_folder="trained_model",
        tokenizer_file="aitextgen.tokenizer.json",
    )
    # generate() prints its output and returns None, so this function used to
    # return None and callers never received the message. generate_one()
    # hands the single generated text back as a string instead.
    return ai.generate_one()
def train_ai():
    """Train a tokenizer and a CPU-configured model on the configured text file.

    Uses the module-level ``file_name`` as the training text for both the
    tokenizer and the dataset, runs 50000 training steps, and prints a
    confirmation line when training returns.
    """
    vocab_path = "aitextgen.tokenizer.json"

    # The tokenizer is trained first; the model and dataset below both load
    # it from vocab_path (presumably the file train_tokenizer() writes by
    # default -- confirm against the aitextgen docs).
    train_tokenizer(file_name)

    model = aitextgen(tokenizer_file=vocab_path, config=GPT2ConfigCPU())
    dataset = TokenDataset(file_name, tokenizer_file=vocab_path, block_size=64)

    model.train(
        dataset,
        batch_size=8,
        num_steps=50000,
        generate_every=5000,
        save_every=5000,
    )
    print("AI has been trained!")
# Call generate_message() and print whatever it returns.
# NOTE(review): indentation was lost in this copy of the file; this line is
# treated as a module-level statement that runs whenever the module is
# loaded -- confirm against the original source.
print(generate_message())