text-gen-bot/generate.py

31 lines
811 B
Python

from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU
from aitextgen import aitextgen
import json
import sys
# Load shared settings from config.json in the current working directory.
# JSON text is UTF-8 by definition, so decode it explicitly instead of relying
# on the platform's locale encoding.
with open("config.json", encoding="utf-8") as file:
    json_object = json.load(file)
# NOTE(review): file_name is not used anywhere in this script's visible code;
# presumably kept for parity with the training script's config — confirm
# before removing. Raises KeyError if config.json has no "file" entry.
file_name = json_object['file']
def generate_message(prompt=None):
    """Generate a single text sample from the locally trained model.

    The model is loaded from the ``trained_model`` folder and the tokenizer
    from ``aitextgen.tokenizer.json`` on every call.

    Args:
        prompt: Optional text the sample should continue from. ``None`` or
            an empty string means unconditional generation.

    Returns:
        The generated text as a string.
    """
    ai = aitextgen(model_folder="trained_model",
                   tokenizer_file="aitextgen.tokenizer.json")
    # The prompt belongs to generate(), not to the constructor. Without
    # return_as_list=True, generate() prints its output and returns None, so
    # request a one-element list and return that element instead.
    return ai.generate(n=1, prompt=prompt or "", return_as_list=True)[0]
match sys.argv[1]:
case "prompt":
try:
msg = generate_message(' '.join(map(str, sys.argv[1:])))
print(msg)
except IndexError:
print(generate_message())
finally:
sys.stdout.flush()
case _:
msg = generate_message()
print(msg)