diff --git a/textgen.py b/textgen.py index 437bd5e..f26bfbb 100644 --- a/textgen.py +++ b/textgen.py @@ -13,9 +13,9 @@ file_name = json_object['file'] # ? generate message using trained model -def generate_message(): +def generate_message(prompt): ai = aitextgen(model_folder="trained_model", - tokenizer_file="aitextgen.tokenizer.json") + tokenizer_file="aitextgen.tokenizer.json", prompt=prompt) ai.generate() # ? train model using text file @@ -35,7 +35,7 @@ def train_ai(): match sys.argv[1]: case "generate": # ? send message to parent JS process - print(generate_message()) + print(generate_message(sys.argv[2])) sys.stdout.flush() case "train": train_ai()