31 lines
811 B
Python
31 lines
811 B
Python
|
from aitextgen.TokenDataset import TokenDataset
|
||
|
from aitextgen.tokenizers import train_tokenizer
|
||
|
from aitextgen.utils import GPT2ConfigCPU
|
||
|
from aitextgen import aitextgen
|
||
|
import json
|
||
|
import sys
|
||
|
|
||
|
# Load the runtime settings once at start-up. JSON text is UTF-8 by
# specification, so the encoding is pinned instead of being left to the
# platform's locale default (which is not UTF-8 on every system).
with open('config.json', 'r', encoding='utf-8') as file:
    json_object = json.load(file)

# Value stored under the "file" key of config.json.
file_name = json_object['file']
|
||
|
|
||
|
def generate_message(prompt=None):
    """Generate a single text sample from the locally trained model.

    The model is loaded from the ``trained_model`` folder together with the
    ``aitextgen.tokenizer.json`` tokenizer file.

    Args:
        prompt: Optional text the generated message should start from.
            ``None`` or an empty string generates without a prompt.

    Returns:
        The generated text as a string.
    """
    ai = aitextgen(model_folder="trained_model",
                   tokenizer_file="aitextgen.tokenizer.json")
    # The prompt is a generation-time argument, not a model-loading option,
    # and generate() with default arguments only prints its output and
    # returns None. generate_one() returns the text so the caller can print
    # or forward it.
    return ai.generate_one(prompt=prompt or "")
|
||
|
|
||
|
# Command-line entry point.
#
#   script.py                    -> generate a message without a prompt
#   script.py prompt WORD...     -> generate a message starting from WORD...
#   script.py <anything else>    -> generate a message without a prompt
#
# The sub-command is read defensively: indexing sys.argv[1] directly raises
# IndexError when the script is started without arguments.
command = sys.argv[1] if len(sys.argv) > 1 else None

try:
    if command == "prompt":
        # Everything after the "prompt" keyword forms the prompt text; the
        # keyword itself (sys.argv[1]) must not leak into it.
        prompt_text = ' '.join(sys.argv[2:])
        # "prompt" with no words behaves like an unprompted request.
        msg = generate_message(prompt_text) if prompt_text else generate_message()
    else:
        msg = generate_message()
    print(msg)
finally:
    # Flush on every path so a parent process reading stdout receives the
    # output promptly, even if generation raised.
    sys.stdout.flush()
|