Skip to content

Commit

Permalink
מיון מחדש ושיפור מודל
Browse files Browse the repository at this point in the history
  • Loading branch information
NHLOCAL committed Aug 31, 2024
1 parent f96729d commit 5615906
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 4 deletions.
Binary file removed machine-learn/find_typename/trained_model.pkl
Binary file not shown.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@

# הגדרת רשת פרמטרים לבדיקה
param_grid = {
'logisticregression__max_iter': [200, 500, 1000], # בדיקת ערכים גדולים יותר
'tfidfvectorizer__ngram_range': [(1, 1), (1, 2)],
'logisticregression__max_iter': [200], # בדיקת ערכים גדולים יותר
'tfidfvectorizer__ngram_range': [(1, 2)],
}

# יצירת אובייקט GridSearchCV
Expand All @@ -55,11 +55,11 @@
print(f'דיוק המודל הטוב ביותר: {accuracy * 100:.2f}%')

# שמירת המודל
with open('trained_model.pkl', 'wb') as f:
with open('music_classifier.pkl', 'wb') as f:
pickle.dump(best_model, f)

# טעינת המודל (לא חובה, רק להדגמה)
with open('trained_model.pkl', 'rb') as f:
with open('music_classifier.pkl', 'rb') as f:
loaded_model = pickle.load(f)

# דוגמה לשימוש במודל לחיזוי על טקסט חדש
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pickle
import csv
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn as sns

# מיפוי תוויות מספריות לשמות קטגוריות
label_mapping = {0: "ARTIST", 1: "ALBUM", 2: "SONG", 3: "RANDOM"}

# טעינת המודל
with open('music_classifier.pkl', 'rb') as f:
loaded_model = pickle.load(f)

# רשימות ריקות לטעינת הדאטה
texts = []
labels = []

# קריאת הדאטה מקובץ CSV
with open('dataset.csv', newline='', encoding='utf-8') as csvfile:
reader = csv.DictReader(csvfile)

for row in reader:
try:
texts.append(row['text'])
labels.append(int(row['label'])) # המרת תוויות למספרים שלמים
except Exception as e:
print(f"שגיאה בקריאת שורה: {e}, דילוג על שורה")

# חיזוי על כל הדאטה
predicted = loaded_model.predict(texts)

# יצירת Confusion Matrix
cm = metrics.confusion_matrix(labels, predicted)

# ויזואליזציה של Confusion Matrix
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=label_mapping.values(), yticklabels=label_mapping.values())
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()

# הדפסת דוח סיווג
print(metrics.classification_report(labels, predicted, target_names=label_mapping.values()))

# חישוב דיוק כללי
accuracy = metrics.accuracy_score(labels, predicted)
print(f'דיוק כללי: {accuracy * 100:.2f}%')
Binary file not shown.

0 comments on commit 5615906

Please sign in to comment.