Skip to content

Commit

Permalink
fix doc bug & slides & paper
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexhaoge committed Dec 28, 2020
1 parent 32e5fb1 commit 4b4e332
Show file tree
Hide file tree
Showing 14 changed files with 15 additions and 0 deletions.
4 changes: 4 additions & 0 deletions MLSR/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
def plot_confusion_matrix(cm, classes, filename, title='Confusion matrix', cmap=plt.cm.Blues):
"""
绘制混淆矩阵,如果本地有图形界面或
Args:
cm: 混淆矩阵,numpy.ndarray
classes: 类名
Expand Down Expand Up @@ -43,6 +44,7 @@ def plot_confusion_matrix(cm, classes, filename, title='Confusion matrix', cmap=
def plot_roc(model, X, y, filename):
"""
画roc图,不过sklearn只支持二分类roc,三分类画不了
Args:
model: 模型
X: 特征
Expand All @@ -63,6 +65,7 @@ def plot_roc(model, X, y, filename):
def plot_tsne(data: DataSet, filename: str, n_iter: int = 1000):
"""
tSNE降维绘图
Args:
data: 数据集
filename: 图片保存路径
Expand Down Expand Up @@ -92,6 +95,7 @@ def plot_tsne(data: DataSet, filename: str, n_iter: int = 1000):
def plot_tsne_ssl(data: DataSet, filename: str, n_iter: int = 1000):
"""
半监督数据集的tSNE降维绘图
Args:
data: 数据集
filename: 图片保存路径
Expand Down
9 changes: 9 additions & 0 deletions MLSR/primary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def lower_bound(cv_results: dict):
Calculate the lower bound within 1 standard deviation
of the best `mean_test_scores`.
Author: Wenhao Zhang <[email protected]>
Args:
cv_results: dict of numpy(masked) ndarrays
See attribute cv_results_ of `GridSearchCV`
Expand All @@ -39,6 +40,7 @@ def best_low_complexity(cv_results: dict):
"""
Balance model complexity with cross-validated score.
Author: Wenhao Zhang <[email protected]>
Args:
cv_results: dict of numpy(masked) ndarrays
See attribute cv_results_ of `GridSearchCV`.
Expand Down Expand Up @@ -72,6 +74,7 @@ def grid_search_and_result(
fit_params: dict = None):
"""
交叉验证网格搜索,测试集和训练集得分,混淆矩阵和ROC曲线绘制
Args:
Xtrain: 训练集特征
ytrain: 训练集标签
Expand Down Expand Up @@ -137,6 +140,7 @@ def grid_search_and_result(
def do_decision_tree(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
"""
训练决策树
Args:
grid: 超参数搜索空间的网格,不填则使用默认搜索空间
dataset: 输入数据集,将会按照0.7, 0.3比例分为训练集和测试集
Expand Down Expand Up @@ -167,6 +171,7 @@ def do_decision_tree(dataset: DataSet, log_dir: str = '../log', grid: dict = Non
def do_random_forest(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
"""
训练随机森林
Args:
grid: 超参数搜索空间的网格,不填则使用默认搜索空间
dataset: 输入数据集,将会按照0.7, 0.3比例分为训练集和测试集
Expand Down Expand Up @@ -212,6 +217,7 @@ def do_random_forest(dataset: DataSet, log_dir: str = '../log', grid: dict = Non
def do_svm(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
"""
训练支持向量机
Args:
grid: 超参数搜索空间的网格,不填则使用默认搜索空间
dataset: 输入数据集,将会按照0.7, 0.3比例分为训练集和测试集
Expand Down Expand Up @@ -255,6 +261,7 @@ def do_svm(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
def do_logistic(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
"""
训练逻辑回归
Args:
grid: 超参数搜索空间的网格,不填则使用默认搜索空间
dataset: 输入数据集,将会按照0.7, 0.3比例分为训练集和测试集
Expand Down Expand Up @@ -286,6 +293,7 @@ def do_logistic(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
def do_naive_bayes(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
"""
训练朴素贝叶斯
Args:
grid: 超参数搜索空间的网格,不填则使用默认搜索空间
dataset: 输入数据集,将会按照0.7, 0.3比例分为训练集和测试集
Expand All @@ -310,6 +318,7 @@ def do_naive_bayes(dataset: DataSet, log_dir: str = '../log', grid: dict = None)
def do_xgb(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
"""
训练Xgboost
Args:
grid: 超参数搜索空间的网格,不填则使用默认搜索空间
dataset: 输入数据集,将会按照0.7, 0.3比例分为训练集和测试集
Expand Down
2 changes: 2 additions & 0 deletions MLSR/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def grid_search_and_result_ssl(
k: int = 5):
"""
交叉验证网格搜索,测试集和训练集得分,混淆矩阵和ROC曲线绘制
Args:
Xtrain: 训练集特征
ytrain: 训练集标签
Expand Down Expand Up @@ -84,6 +85,7 @@ def grid_search_and_result_ssl(
def do_tsvm(data: DataSet, log_dir: str = '../log', grid: dict = None):
"""
Transductive Support Vector Machine
Args:
data: 输入数据DataSet对象
grid: 超参数搜索空间的网格,不填则使用默认搜索空间
Expand Down
Binary file modified log/1_tsne.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 4b4e332

Please sign in to comment.