diff --git a/textgen.py b/textgen.py index 549b638..437bd5e 100644 --- a/textgen.py +++ b/textgen.py @@ -2,7 +2,8 @@ from aitextgen.TokenDataset import TokenDataset from aitextgen.tokenizers import train_tokenizer from aitextgen.utils import GPT2ConfigCPU from aitextgen import aitextgen -import json, sys +import json +import sys with open('config.json', 'r') as file: json_object = json.load(file) @@ -10,21 +11,33 @@ with open('config.json', 'r') as file: file_name = json_object['file'] # ? generate message using trained model + + def generate_message(): ai = aitextgen(model_folder="trained_model", - tokenizer_file="aitextgen.tokenizer.json") + tokenizer_file="aitextgen.tokenizer.json") ai.generate() # ? train model using text file + + def train_ai(): train_tokenizer(file_name) tokenizer_file = "aitextgen.tokenizer.json" config = GPT2ConfigCPU() ai = aitextgen(tokenizer_file=tokenizer_file, config=config) - data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64) - ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000) - print("AI has been trained!") + data = TokenDataset( + file_name, tokenizer_file=tokenizer_file, block_size=64) + ai.train(data, batch_size=8, num_steps=50000, + generate_every=5000, save_every=5000) -# ? send message to parent JS process -print(generate_message()) -sys.stdout.flush() \ No newline at end of file + +match sys.argv[1]: + case "generate": + # ? send message to parent JS process + print(generate_message()) + sys.stdout.flush() + case "train": + train_ai() + case _: + raise NameError("Argument not provided.")