import re

pattern = r'\[(.*?)\] (.*?): (.*)'
matches = re.findall(pattern, text)
text = [(x1, x2.lower()) for x0, x1, x2 in matches]
[
    (2018-03-12 16:03:59, "Alice", "Hi, how are you guys?"),
    (2018-03-12 16:05:36, "Tom", "I am good thanks!"),
    ...
]
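To make the parsing step concrete, here is a minimal, self-contained sketch of the regex applied to a raw export. The sample lines are hypothetical; the real export format varies by platform and locale, so the pattern may need adjusting.

import re

# hypothetical excerpt from an exported WhatsApp chat
text = (
    "[2018-03-12 16:03:59] Alice: Hi, how are you guys?\n"
    "[2018-03-12 16:05:36] Tom: I am good thanks!"
)

# one capture group each for timestamp, sender, and message
pattern = r'\[(.*?)\] (.*?): (.*)'
matches = re.findall(pattern, text)

# drop the timestamp, keep (sender, lowercased message)
pairs = [(x1, x2.lower()) for x0, x1, x2 in matches]
print(pairs)  # [('Alice', 'hi, how are you guys?'), ('Tom', 'i am good thanks!')]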
import json
import torch
from config import eval_interval, learn_rate, max_iters
from src.model import GPTLanguageModel
from src.utils import current_time, estimate_loss, get_batch
def model_training(update: bool) -> None:
    """
    Trains or updates a GPTLanguageModel using pre-loaded data.

    This function either initializes a new model or loads an existing model based
    on the `update` parameter. It then trains the model using the AdamW optimizer
    on the training and validation data sets. Finally, the trained model is saved.

    :param update: Boolean flag to indicate whether to update an existing model.
    """
    # LOAD DATA -----------------------------------------------------------------
    train_data = torch.load("assets/output/train.pt")
    valid_data = torch.load("assets/output/valid.pt")

    with open("assets/output/vocab.txt", "r", encoding="utf-8") as f:
        vocab = json.loads(f.read())

    # INITIALIZE / LOAD MODEL ---------------------------------------------------
    if update:
        try:
            model = torch.load("assets/models/model.pt")
            print("Loaded existing model to continue training.")
        except FileNotFoundError:
            print("No existing model found. Initializing a new model.")
            model = GPTLanguageModel(vocab_size=len(vocab))
    else:
        print("Initializing a new model.")
        model = GPTLanguageModel(vocab_size=len(vocab))

    # initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learn_rate)

    # number of model parameters
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters to be optimized: {n_params}\n")

    # MODEL TRAINING ------------------------------------------------------------
    for i in range(max_iters):

        # evaluate the loss on train and valid sets every 'eval_interval' steps
        if i % eval_interval == 0 or i == max_iters - 1:
            train_loss = estimate_loss(model, train_data)
            valid_loss = estimate_loss(model, valid_data)
            time = current_time()
            print(f"{time} | step {i}: train loss {train_loss:.4f}, valid loss {valid_loss:.4f}")

        # sample a batch of data
        x_batch, y_batch = get_batch(train_data)

        # evaluate the loss and take an optimizer step
        logits, loss = model(x_batch, y_batch)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    torch.save(model, "assets/models/model.pt")
    print("Model saved")
7. Chat-Mode
To interact with the trained model, I created a function that lets you select a contact name via a dropdown menu and type a message for the model to respond to. The parameter n_chats determines how many responses the model generates at once. The model concludes a generated message when it predicts the <END> token as the next token.
import json
import random
import torch
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from config import end_token, n_chats
from src.utils import custom_tokenizer, decode, encode, print_delayed
def conversation() -> None:
    """
    Emulates chat conversations by sampling from a pre-trained GPTLanguageModel.

    This function loads a trained GPTLanguageModel along with the vocabulary and
    the list of special tokens. It then enters a loop in which the user specifies
    a contact. Given this input, the model generates a sample response. The
    conversation continues until the user inputs the end token.
    """
    with open("assets/output/vocab.txt", "r", encoding="utf-8") as f:
        vocab = json.loads(f.read())

    with open("assets/output/contacts.txt", "r", encoding="utf-8") as f:
        contacts = json.loads(f.read())

    spec_tokens = contacts + [end_token]
    model = torch.load("assets/models/model.pt")
    completer = WordCompleter(spec_tokens, ignore_case=True)

    input = prompt("message >> ", completer=completer, default="")
    output = torch.tensor([], dtype=torch.long)
    print()

    while input != end_token:
        for _ in range(n_chats):

            # tokenize and encode the new input, then append it to the context
            add_tokens = custom_tokenizer(input, spec_tokens)
            add_context = encode(add_tokens, vocab)
            context = torch.cat((output, add_context)).unsqueeze(1).T

            n0 = len(output)
            output = model.generate(context, vocab)
            n1 = len(output)

            # print only the newly generated tokens
            print_delayed(decode(output[n0-n1:], vocab))
            input = random.choice(contacts)

        input = prompt("\nresponse >> ", completer=completer, default="")
        print()
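The generate method itself lives in src/model.py and is not shown in this section. Conceptually, it extends the context one token at a time and stops once the <END> token is sampled. Here is a minimal sketch of that loop; it assumes the model's forward pass works without targets, that vocab is a list of token strings, and it simplifies the signature, so the real method may differ.

import torch
from config import end_token

@torch.no_grad()
def generate(model, context: torch.Tensor, vocab: list, max_new_tokens: int = 200) -> torch.Tensor:
    """Sketch of GPT-style sampling that stops at the end token (assumptions noted above)."""
    end_id = vocab.index(end_token)  # ASSUMPTION: vocab is a list mapping ids to token strings
    for _ in range(max_new_tokens):
        logits, _ = model(context)                        # ASSUMPTION: forward pass works without targets
        probs = torch.softmax(logits[:, -1, :], dim=-1)   # distribution over the next token
        next_id = torch.multinomial(probs, num_samples=1)
        context = torch.cat((context, next_id), dim=1)
        if next_id.item() == end_id:                      # message is complete
            break
    return context.squeeze(0)                             # 1-D tensor, as conversation() expects

Returning a 1-D tensor is what lets conversation() use output[n0-n1:] to pick out just the newly generated tokens for printing.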