Update chatbot/train.py and chatbot/app.py to improve model performance #200
base: main
Changes to chatbot/train.py (from all commits):

```diff
@@ -1,23 +1,34 @@
 from operator import index
 import numpy as np
 import random
 import json
 
 import torch
 import torch.nn as nn
 from torch.utils.data import Dataset, DataLoader
 import nltk
 from nltk.stem.porter import PorterStemmer
 
+nltk.download('punkt')
+
+# Configuration
+INTENTS_FILE = 'intents.json'
+MODEL_SAVE_FILE = "data.pth"
+
+# Initialize the stemmer globally
+stemmer = PorterStemmer()
 
-# Define a simple tokenizer and stemmer
+# Define a tokenizer and stemmer
 def tokenize(sentence):
-    return sentence.split()  # Tokenize by splitting on spaces
+    return nltk.word_tokenize(sentence)
 
 def stem(word):
-    return word.lower()  # Simple stemming by converting to lowercase
+    return stemmer.stem(word.lower())
 
 def bag_of_words(tokenized_sentence, words):
-    bag = [1 if stem(word) in [stem(w) for w in tokenized_sentence] else 0 for word in words]
+    sentence_words = [stem(word) for word in tokenized_sentence]
+
+    bag = [1.0 if word in sentence_words else 0.0 for word in words]
+
     return torch.tensor(bag, dtype=torch.float32)
 
 class NeuralNet(nn.Module):
@@ -36,12 +47,13 @@ def forward(self, x):
 
-with open('intents.json', 'r') as f:
+with open(INTENTS_FILE, 'r') as f:
     intents = json.load(f)
 
 all_words = []
 tags = []
 xy = []
 
 # loop through each sentence in our intents patterns
 for intent in intents['intents']:
     tag = intent['tag']
@@ -55,15 +67,13 @@ def forward(self, x):
         # add to xy pair
         xy.append((w, tag))
 
 # stem and lower each word
-ignore_words = ['?', '.', '!']
-all_words = [stem(w) for w in all_words if w not in ignore_words]
+all_words = [stem(w) for w in all_words]
 # remove duplicates and sort
-all_words = sorted(set(all_words))
-tags = sorted(set(tags))
+all_words = sorted(list(set(all_words)))
+tags = sorted(list(set(tags)))
 
 print(len(xy), "patterns")
-print(len(tags), "tags:", tags)
+print(len(tags), "unique tags:", tags)
 print(len(all_words), "unique stemmed words:", all_words)
 
 # create training data
@@ -98,7 +108,7 @@ def __init__(self):
 
     # support indexing such that dataset[i] can be used to get i-th sample
     def __getitem__(self, index):
-        return self.x_data[index], self.y_data[index]
+        return torch.from_numpy(self.x_data[index]), torch.tensor(self.y_data[index])
 
     # we can call len(dataset) to return the size
     def __len__(self):
@@ -120,6 +130,9 @@ def __len__(self):
 
 # Train the model
 for epoch in range(num_epochs):
+
+    total_loss = 0  # for tracking loss
+
     for (words, labels) in train_loader:
         words = words.to(device)
         labels = labels.to(dtype=torch.long).to(device)
@@ -134,12 +147,16 @@ def __len__(self):
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
+        total_loss += loss.item() * words.size(0)  # Accumulate weighted loss
+
+    epoch_loss = total_loss / len(dataset)
+
     if (epoch+1) % 100 == 0:
-        print (f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
+        print (f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {epoch_loss:.4f}')
 
 
-print(f'final loss: {loss.item():.4f}')
+print(f'final average loss: {epoch_loss:.4f}')
 
 data = {
     "model_state": model.state_dict(),
@@ -150,7 +167,6 @@ def __len__(self):
     "tags": tags
 }
 
-FILE = "data.pth"
-torch.save(data, FILE)
+torch.save(data, MODEL_SAVE_FILE)
 
-print(f'training complete. file saved to {FILE}')
+print(f'training complete. file saved to {MODEL_SAVE_FILE}')
```

Review comment on `from operator import index` (train.py, line 1):

Remove unused import: `from operator import index` is not used in this file.

📝 Committable suggestion:

```diff
-from operator import index
```
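As a quick illustration of what the reworked preprocessing in the train.py diff above produces, here is a minimal, self-contained sketch that exercises the new `tokenize`/`stem`/`bag_of_words` helpers. The sample sentence and vocabulary are made up for the example, and it assumes NLTK and its `punkt` data are available:

```python
import nltk
import torch
from nltk.stem.porter import PorterStemmer

# Newer NLTK releases may additionally require the 'punkt_tab' resource.
nltk.download('punkt', quiet=True)

stemmer = PorterStemmer()

def tokenize(sentence):
    # NLTK's word tokenizer splits off punctuation and contractions.
    return nltk.word_tokenize(sentence)

def stem(word):
    # Porter stemming after lowercasing, as in the updated train.py.
    return stemmer.stem(word.lower())

def bag_of_words(tokenized_sentence, words):
    # 1.0 where a (stemmed) vocabulary word occurs in the sentence, else 0.0.
    sentence_words = [stem(word) for word in tokenized_sentence]
    bag = [1.0 if word in sentence_words else 0.0 for word in words]
    return torch.tensor(bag, dtype=torch.float32)

vocab = ['book', 'flight', 'hello', 'thank']  # illustrative, already stemmed
tokens = tokenize("Hello, I'd like to book a flight!")
print(tokens)                       # roughly: ['Hello', ',', 'I', "'d", 'like', 'to', 'book', 'a', 'flight', '!']
print(bag_of_words(tokens, vocab))  # tensor([1., 1., 1., 0.])
```

One visible difference from the old `sentence.split()` tokenizer is that punctuation and contraction fragments now come back as separate tokens, so they can end up in `all_words` unless they are filtered out.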
Review comment on chatbot/app.py (the `app.run(...)` call flagged below):
Development server configuration flagged by static analysis.
`debug=True` and binding to `0.0.0.0` are appropriate for development but pose security risks in production:

- `debug=True` enables the interactive debugger, which can execute arbitrary code
- `0.0.0.0` exposes the service to all network interfaces

Consider using environment variables or a configuration flag:
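A minimal sketch of that approach, with safe defaults (the environment variable names `FLASK_DEBUG`, `HOST`, and `PORT` are illustrative, not something defined in this repo):

```python
import os

from flask import Flask

app = Flask(__name__)

if __name__ == "__main__":
    # Opt in to debug mode and external binding explicitly via the environment.
    debug = os.environ.get("FLASK_DEBUG", "0") == "1"
    host = os.environ.get("HOST", "127.0.0.1")
    port = int(os.environ.get("PORT", "5000"))
    app.run(host=host, port=port, debug=debug)
```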
For production, use a WSGI server (e.g., Gunicorn) instead of the Flask development server.
🧰 Tools
🪛 ast-grep (0.40.0)
[warning] 95-95: Running flask app with host 0.0.0.0 could expose the server publicly.
Context: app.run(host="0.0.0.0", port=5000, debug=True)
Note: [CWE-668] Exposure of Resource to Wrong Sphere. [OWASP A01:2021] Broken Access Control.
References: https://owasp.org/Top10/A01_2021-Broken_Access_Control
(avoid_app_run_with_bad_host-python)
[warning] 95-95: Detected Flask app with debug=True. Do not deploy to production with this flag enabled as it will leak sensitive information. Instead, consider using Flask configuration variables or setting 'debug' using system environment variables.
Context: app.run(host="0.0.0.0", port=5000, debug=True)
Note: [CWE-489] Active Debug Code.
References: https://labs.detectify.com/2015/10/02/how-patreon-got-hacked-publicly-exposed-werkzeug-debugger/
(debug-enabled-python)
🪛 Ruff (0.14.8)
96-96: Possible binding to all interfaces (S104)

96-96: Use of `debug=True` in Flask app detected (S201)