Compare commits
4 commits
55b9a2693e
...
1b4b769b7a
Author | SHA1 | Date | |
---|---|---|---|
1b4b769b7a | |||
8c6ad6d5a3 | |||
5075827551 | |||
651a2b0729 |
10
README.md
10
README.md
|
@ -46,11 +46,13 @@ Before a bot can be used the fields in the `config.json` file must be populated
|
|||
|
||||
► user* ⇢ Account's User ID.
|
||||
|
||||
► file ⇢ Path of file used for training the AI (.txt file only).
|
||||
► file ⇢ Path of file used for training the AI (.txt file only).
|
||||
|
||||
► prefix ⇢ Bot listens to commands that start with this prefix.
|
||||
► prefix ⇢ Bot listens to commands that start with this prefix.
|
||||
|
||||
► frequency ⇢ How often the bot sends a message (keep low to prevent spam).
|
||||
► frequency ⇢ How often the bot sends a message (keep low to prevent spam).
|
||||
|
||||
► size ⇢ Bot starts generating messages when the number of lines in the training file is equal to this. The greater the size, the longer bot waits before messaging but might increase message quality.
|
||||
► size ⇢ Bot starts generating messages when the number of lines in the training file is equal to this. The greater the size, the longer bot waits before messaging but might increase message quality.
|
||||
|
||||
► retrain ⇢ The bot retrains itself after this many extra lines of messages are recorded in the text file.
|
||||
```
|
||||
|
|
|
@ -5,5 +5,6 @@
|
|||
"file": "training-matrix.txt",
|
||||
"prefix": "!",
|
||||
"frequency": "25",
|
||||
"size": "5000"
|
||||
"size": "5000",
|
||||
"retrain": "10000"
|
||||
}
|
||||
|
|
8
index.js
8
index.js
|
@ -21,21 +21,21 @@ client.on("room.message", (roomId, event) => {
|
|||
});
|
||||
|
||||
if (lineCount(config.file) < config.size) return; // ? don't start generating messages until a big enough dataset is present
|
||||
|
||||
// TODO: train AI every Nth message?
|
||||
// ? send message every N messages using the training data
|
||||
if (!(messageCounter % config.frequency)) {
|
||||
console.log("Generating message...");
|
||||
|
||||
const python = spawn('python', ["textgen.py"]);
|
||||
const python = spawn('python', ["textgen.py", "generate"]);
|
||||
|
||||
python.stdout.on('data', function (message) {
|
||||
message = message.toString();
|
||||
console.log("bot:\t" + message);
|
||||
client.sendText(roomId, message);
|
||||
});
|
||||
python.on('close'); // ? close python process when finished
|
||||
};
|
||||
|
||||
// TODO: train AI every Nth message?
|
||||
|
||||
});
|
||||
|
||||
function lineCount(text) {
|
||||
|
|
29
textgen.py
29
textgen.py
|
@ -2,7 +2,8 @@ from aitextgen.TokenDataset import TokenDataset
|
|||
from aitextgen.tokenizers import train_tokenizer
|
||||
from aitextgen.utils import GPT2ConfigCPU
|
||||
from aitextgen import aitextgen
|
||||
import json, sys
|
||||
import json
|
||||
import sys
|
||||
|
||||
with open('config.json', 'r') as file:
|
||||
json_object = json.load(file)
|
||||
|
@ -10,21 +11,33 @@ with open('config.json', 'r') as file:
|
|||
file_name = json_object['file']
|
||||
|
||||
# ? generate message using trained model
|
||||
|
||||
|
||||
def generate_message():
|
||||
ai = aitextgen(model_folder="trained_model",
|
||||
tokenizer_file="aitextgen.tokenizer.json")
|
||||
tokenizer_file="aitextgen.tokenizer.json")
|
||||
ai.generate()
|
||||
|
||||
# ? train model using text file
|
||||
|
||||
|
||||
def train_ai():
|
||||
train_tokenizer(file_name)
|
||||
tokenizer_file = "aitextgen.tokenizer.json"
|
||||
config = GPT2ConfigCPU()
|
||||
ai = aitextgen(tokenizer_file=tokenizer_file, config=config)
|
||||
data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64)
|
||||
ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000)
|
||||
print("AI has been trained!")
|
||||
data = TokenDataset(
|
||||
file_name, tokenizer_file=tokenizer_file, block_size=64)
|
||||
ai.train(data, batch_size=8, num_steps=50000,
|
||||
generate_every=5000, save_every=5000)
|
||||
|
||||
# ? send message to parent JS process
|
||||
print(generate_message())
|
||||
sys.stdout.flush()
|
||||
|
||||
match sys.argv[1]:
|
||||
case "generate":
|
||||
# ? send message to parent JS process
|
||||
print(generate_message())
|
||||
sys.stdout.flush()
|
||||
case "train":
|
||||
train_ai()
|
||||
case _:
|
||||
raise NameError("Argument not provided.")
|
||||
|
|
Loading…
Reference in a new issue