Eat System - Machine Learning Model
File: cogs/eat/train/train.py
The Train class encapsulates the machine learning component of the "Eat What" system. It is responsible for training a personalized recommendation model for each server and using that model to predict food preferences.
Train Class
__init__(self, db: DB, ...)
Initializes the training manager.
- Parameters:
  - db (DB): An instance of the database class, used to fetch training data.
  - Other parameters control the model's architecture and training hyperparameters (embedding dimensions, hidden layers, learning rate, etc.).
genModel(self, discord_id: str)
This is the core training method. It builds and trains a recommendation model based on a server's historical data.
- Parameters:
  - discord_id (str): The ID of the server for which to train the model.
- Process:
  - Data Loading: It uses a DataLoader to fetch all search records for the given discord_id from the database.
  - Data Processing: It processes the raw data, extracting features like restaurant titles, tags, and user-provided keywords.
  - Vocabulary Creation: It builds a vocabulary of all unique words from the processed data.
  - Tensor Conversion: The data is converted into numerical tensors that the model can understand.
  - Model Training: It initializes a Net model (a neural network defined in cogs/eat/train/model.py) and trains it on the tensor data for a set number of epochs. The user's ratings (self_rate from the database) implicitly influence the training data distribution.
  - Saving: After training, the model's state (.model file) and the vocabulary mapping (.pickle file) are saved to the models/ directory, named after the discord_id.
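The vocabulary, conversion, and saving steps above can be sketched with the standard library alone. This is a minimal illustration, not the code in train.py: the function names (build_vocab, encode, save_vocab), the sample records, and the exact "<discord_id>.pickle" file name are assumptions, and plain integer lists stand in for the tensors the real model consumes.

```python
import pickle
from pathlib import Path


def build_vocab(records):
    """Map every unique token in the records to an integer index."""
    vocab = {}
    for tokens in records:
        for token in tokens:
            if token not in vocab:
                # Indices are assigned in order of first appearance.
                vocab[token] = len(vocab)
    return vocab


def encode(records, vocab):
    """Replace each token with its vocabulary index."""
    return [[vocab[token] for token in tokens] for tokens in records]


def save_vocab(vocab, directory, discord_id):
    """Pickle the vocabulary to <directory>/<discord_id>.pickle (assumed naming)."""
    path = Path(directory) / f"{discord_id}.pickle"
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as fh:
        pickle.dump(vocab, fh)
    return path


# Hypothetical processed search records: one token list per record.
records = [
    ["ramen", "noodles", "japanese"],
    ["pizza", "italian"],
    ["ramen", "spicy"],
]
vocab = build_vocab(records)
encoded = encode(records, vocab)
```

Saving the vocabulary next to the model state matters because the indices are only meaningful together with the mapping that produced them; a model loaded without its matching vocabulary cannot translate its output back into keywords.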
predict(self, discord_id: str)
Uses a pre-trained model to generate a food recommendation.
- Parameters:
  - discord_id (str): The ID of the server to generate a prediction for.
- Process:
  - It loads the saved model and vocabulary files for the specified discord_id.
  - It feeds a random seed from the server's vocabulary into the model.
  - The model outputs a probability distribution over the entire vocabulary.
  - It uses this distribution to randomly select and return a predicted keyword.
- Returns: A string containing the predicted food keyword (e.g., "ramen"). Returns None if no model exists for the server.
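The final selection step, drawing a keyword from the model's probability distribution, might look like the sketch below. It assumes the model output is already available as a list of probabilities aligned with the vocabulary indices; the names sample_keyword and predict_keyword are hypothetical, and an empty vocabulary stands in here for the "no model exists" case.

```python
import random


def sample_keyword(probabilities, index_to_word, rng=random):
    """Draw one keyword, weighted by the given probabilities."""
    indices = list(range(len(probabilities)))
    chosen = rng.choices(indices, weights=probabilities, k=1)[0]
    return index_to_word[chosen]


def predict_keyword(probabilities, vocab, rng=random):
    """Return a sampled keyword, or None when no vocabulary is available."""
    if not vocab:
        # Stand-in for the documented behavior when no model has been trained.
        return None
    # Invert the word -> index mapping so indices can be turned back into words.
    index_to_word = {index: word for word, index in vocab.items()}
    return sample_keyword(probabilities, index_to_word, rng)


# Hypothetical vocabulary and model output for a three-word server.
example_vocab = {"ramen": 0, "pizza": 1, "sushi": 2}
example_probabilities = [0.7, 0.2, 0.1]
keyword = predict_keyword(example_probabilities, example_vocab)
```

Sampling from the distribution, rather than always taking the most probable word, keeps recommendations varied: with the weights above, "ramen" is the most likely result, but "pizza" and "sushi" still appear occasionally.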