check arg passed

This commit is contained in:
array-in-a-matrix 2022-08-16 03:46:05 -04:00
parent 8c6ad6d5a3
commit 1b4b769b7a
1 changed files with 21 additions and 8 deletions

View File

@ -2,7 +2,8 @@ from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU
from aitextgen import aitextgen
import json, sys
import json
import sys
with open('config.json', 'r') as file:
json_object = json.load(file)
@ -10,21 +11,33 @@ with open('config.json', 'r') as file:
file_name = json_object['file']
# ? generate message using trained model
def generate_message():
ai = aitextgen(model_folder="trained_model",
tokenizer_file="aitextgen.tokenizer.json")
tokenizer_file="aitextgen.tokenizer.json")
ai.generate()
# ? train model using text file
def train_ai():
train_tokenizer(file_name)
tokenizer_file = "aitextgen.tokenizer.json"
config = GPT2ConfigCPU()
ai = aitextgen(tokenizer_file=tokenizer_file, config=config)
data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64)
ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000)
print("AI has been trained!")
data = TokenDataset(
file_name, tokenizer_file=tokenizer_file, block_size=64)
ai.train(data, batch_size=8, num_steps=50000,
generate_every=5000, save_every=5000)
# ? send message to parent JS process
print(generate_message())
sys.stdout.flush()
match sys.argv[1]:
case "generate":
# ? send message to parent JS process
print(generate_message())
sys.stdout.flush()
case "train":
train_ai()
case _:
raise NameError("Argument not provided.")