diff --git a/generate.py b/generate.py new file mode 100644 index 0000000..c57576c --- /dev/null +++ b/generate.py @@ -0,0 +1,30 @@ +from aitextgen.TokenDataset import TokenDataset +from aitextgen.tokenizers import train_tokenizer +from aitextgen.utils import GPT2ConfigCPU +from aitextgen import aitextgen +import json +import sys + +with open('config.json', 'r') as file: + json_object = json.load(file) + +file_name = json_object['file'] + +def generate_message(prompt=None): + ai = aitextgen(model_folder="trained_model", + tokenizer_file="aitextgen.tokenizer.json") + return ai.generate_one(prompt=prompt) + +match sys.argv[1]: + case "prompt": + try: + msg = generate_message(' '.join(map(str, sys.argv[2:]))) + print(msg) + except IndexError: + print(generate_message()) + finally: + sys.stdout.flush() + + case _: + msg = generate_message() + print(msg) diff --git a/index.js b/index.js index fb2ece5..c8aa6d3 100644 --- a/index.js +++ b/index.js @@ -2,75 +2,59 @@ import config from './config.json' assert {type: "json"}; import { MatrixClient, SimpleFsStorageProvider, AutojoinRoomsMixin } from "matrix-bot-sdk"; import fs from "fs"; import { PythonShell } from 'python-shell'; -import { type } from 'os'; +// import { type } from 'os'; const storage = new SimpleFsStorageProvider("storage.json"); const client = new MatrixClient(config.homeserver, config.token, storage); -const pyFile = "textgen.py"; +const pyFile = "generate.py"; AutojoinRoomsMixin.setupOnClient(client); client.start().then(() => console.log(`Client has started!\n`)); const messageCounters = new Map(); // room ID, message count -let trainingCounter = 0; client.on("room.message", (roomId, event) => { - if (!event["content"] || event["sender"] === config.user) return; + if (!event["content"] || event["sender"] === config.user) return; // ? ignore if message sent by bot itself or is empty - ++trainingCounter; - messageCounters.set(roomId, (messageCounters.get(roomId) ?? 
0) + 1); - let userMessage = event["content"]["body"].split(" "); + let messageArray = event["content"]["body"].split(" "); - console.log(`COUNTER:\t${messageCounters.get(roomId)}\t${roomId}\t${userMessage.join(" ")}`); - - - if (userMessage[0].startsWith(config.prefix)) { - userMessage[0] = userMessage[0].replace(config.prefix, '').toLowerCase(); - } else { - fs.appendFile(config.file, userMessage.join(" ") + "\n", function (err) { + if (!(messageArray[0] === config.prefix)) { + messageCounters.set(roomId, (messageCounters.get(roomId) ?? 0) + 1); + console.log(`COUNTER:\t${messageCounters.get(roomId)}\t${roomId}\t${event["content"]["body"]}`); + fs.appendFile(config.file, event["content"]["body"] + "\n", function (err) { if (err) throw err; }); - }; + return; + } // ? if message does not start with prefix log it for training + + + messageArray.shift() // ? remove bot's prefix from array // ? send message if: // ? - enough messages have been sent // ? - commanded - if (!(messageCounters.get(roomId) % config.frequency) || userMessage[0] === "speak") { + if (!(messageCounters.get(roomId) % config.frequency) || messageArray[0].toLowerCase() === "speak") { console.log("Generating message..."); - - userMessage.shift() - userMessage = userMessage.join(" ") - fs.appendFile(config.file, userMessage + "\n", function (err) { - if (err) throw err; - }); - const options = { args: ['generate', userMessage] }; + + const options = { args: ["", ""] }; PythonShell.run(pyFile, options, (err, message) => { if (err) throw err; client.sendText(roomId, message.toString()); console.log("Message sent!"); }); // ? send generated message to room - }; - // ? retrain if: - // ? - enough message have been sent - // ? 
- commanded - if (trainingCounter >= config.retrain || userMessage[0] === "train") { - console.log("Retraining the AI..."); - client.sendText(roomId, "Retraining the AI..."); - - trainingCounter = 0; - const options = { args: ['train'] }; + } else if (messageArray[0] === "prompt") { + console.log("prompted to generate...") + const options = { args: ['prompt', messageArray.slice(1).join(" ")] }; PythonShell.run(pyFile, options, (err, message) => { if (err) throw err; - console.log(message.toString()); - }); - console.log("Training finished!"); - client.sendText(roomId, "Training finished!"); - }; + client.sendText(roomId, message.toString()); + console.log("Message sent!"); + }); // ? send prompted message to room + + } else { + console.log("Invalid command") + } }); -function lineCount(text) { - return fs.readFileSync(text).toString().split("\n").length - 1; -}; - diff --git a/textgen.py b/textgen.py deleted file mode 100644 index 8356164..0000000 --- a/textgen.py +++ /dev/null @@ -1,49 +0,0 @@ -from aitextgen.TokenDataset import TokenDataset -from aitextgen.tokenizers import train_tokenizer -from aitextgen.utils import GPT2ConfigCPU -from aitextgen import aitextgen -import json -import sys - -with open('config.json', 'r') as file: - json_object = json.load(file) - -file_name = json_object['file'] - -# ? generate message using trained model - - -def generate_message(prompt=None): - # ai = aitextgen(prompt=prompt) - ai = aitextgen(model_folder="trained_model", - tokenizer_file="aitextgen.tokenizer.json", prompt=prompt) - return ai.generate() - -# ? train model using text file - - -def train_ai(): - train_tokenizer(file_name) - tokenizer_file = "aitextgen.tokenizer.json" - config = GPT2ConfigCPU() - ai = aitextgen(tokenizer_file=tokenizer_file, config=config) - data = TokenDataset( - file_name, tokenizer_file=tokenizer_file, block_size=64) - ai.train(data, batch_size=8, num_steps=50000, - generate_every=5000, save_every=5000) - - -match sys.argv[1]: - case "generate": - # ? 
send message to parent JS process - try: - prompt = ' '.join(map(str, sys.argv[2:])) - generate_message(prompt) - except IndexError: - generate_message() - finally: - sys.stdout.flush() - case "train": - train_ai() - case _: - raise NameError("Argument not provided.") diff --git a/train.py b/train.py new file mode 100644 index 0000000..d4cbadb --- /dev/null +++ b/train.py @@ -0,0 +1,20 @@ +from aitextgen.TokenDataset import TokenDataset +from aitextgen.tokenizers import train_tokenizer +from aitextgen.utils import GPT2ConfigCPU +from aitextgen import aitextgen +import json + +with open('config.json', 'r') as file: + json_object = json.load(file) + +file_name = json_object['file'] + +train_tokenizer(file_name) +tokenizer_file = "aitextgen.tokenizer.json" +config = GPT2ConfigCPU() + +ai = aitextgen(tokenizer_file=tokenizer_file, config=config) +data = TokenDataset( + file_name, tokenizer_file=tokenizer_file, block_size=64) +ai.train(data, batch_size=8, num_steps=50000, + generate_every=5000, save_every=5000) \ No newline at end of file