From 1b4b769b7a5af81eb538fdd8d79bdb26f983dc1d Mon Sep 17 00:00:00 2001
From: array-in-a-matrix
Date: Tue, 16 Aug 2022 03:46:05 -0400
Subject: [PATCH] check arg passed

---
 textgen.py | 29 +++++++++++++++++++++--------
 1 file changed, 21 insertions(+), 8 deletions(-)

diff --git a/textgen.py b/textgen.py
index 549b638..437bd5e 100644
--- a/textgen.py
+++ b/textgen.py
@@ -2,7 +2,8 @@ from aitextgen.TokenDataset import TokenDataset
 from aitextgen.tokenizers import train_tokenizer
 from aitextgen.utils import GPT2ConfigCPU
 from aitextgen import aitextgen
-import json, sys
+import json
+import sys
 
 with open('config.json', 'r') as file:
     json_object = json.load(file)
@@ -10,21 +11,33 @@ with open('config.json', 'r') as file:
 file_name = json_object['file']
 
 # ? generate message using trained model
+
+
 def generate_message():
     ai = aitextgen(model_folder="trained_model",
-                tokenizer_file="aitextgen.tokenizer.json")
+                   tokenizer_file="aitextgen.tokenizer.json")
     ai.generate()
 
 # ? train model using text file
+
+
 def train_ai():
     train_tokenizer(file_name)
     tokenizer_file = "aitextgen.tokenizer.json"
     config = GPT2ConfigCPU()
     ai = aitextgen(tokenizer_file=tokenizer_file, config=config)
-    data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64)
-    ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000)
-    print("AI has been trained!")
+    data = TokenDataset(
+        file_name, tokenizer_file=tokenizer_file, block_size=64)
+    ai.train(data, batch_size=8, num_steps=50000,
+             generate_every=5000, save_every=5000)
 
-# ? send message to parent JS process
-print(generate_message())
-sys.stdout.flush()
\ No newline at end of file
+
+match sys.argv[1]:
+    case "generate":
+        # ? send message to parent JS process
+        print(generate_message())
+        sys.stdout.flush()
+    case "train":
+        train_ai()
+    case _:
+        raise NameError("Argument not provided.")