# Python基于決策樹算法的分類預測實現
決策樹是機器學習中經典的分類與回歸方法,因其直觀易懂、可解釋性強而廣受歡迎。本文將詳細介紹如何使用Python的scikit-learn庫實現基于決策樹的分類預測,涵蓋數據準備、模型構建、評估優化等全流程。
## 一、決策樹算法基礎
### 1.1 算法原理
決策樹通過遞歸地將數據集劃分為更純凈的子集來構建樹形結構,核心概念包括:
- **節點**:包含屬性測試條件的分支點
- **葉節點**:最終的分類結果
- **信息增益/基尼系數**:劃分標準的衡量指標
常用算法:
- ID3(使用信息增益)
- C4.5(使用信息增益率)
- CART(使用基尼系數)
### 1.2 數學基礎
**信息熵**:
$$ H(D) = -\sum_{k=1}^{K}p_k\log_2p_k $$
**基尼系數**:
$$ Gini(D) = 1-\sum_{k=1}^{K}p_k^2 $$
## 二、環境準備
```python
# 基礎庫安裝
pip install numpy pandas scikit-learn matplotlib
# 導入必要庫
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
以經典的鳶尾花數據集為例:
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
class_names = iris.target_names
# 轉換為DataFrame方便查看
df = pd.DataFrame(X, columns=feature_names)
df['target'] = y
print(df.describe())
print("\n類別分布:\n", df['target'].value_counts())
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42, stratify=y)
# 創建決策樹分類器
clf = DecisionTreeClassifier(
criterion='gini', # 分裂標準
max_depth=3, # 最大深度
min_samples_split=2, # 分裂所需最小樣本數
random_state=42
)
# 模型訓練
clf.fit(X_train, y_train)
參數 | 說明 | 典型值 |
---|---|---|
criterion | 分裂標準 | ‘gini’或’entropy’ |
max_depth | 樹的最大深度 | 整數或None |
min_samples_split | 節點分裂最小樣本數 | 2-10 |
min_samples_leaf | 葉節點最小樣本數 | 1-5 |
max_features | 考慮的最大特征數 | ‘auto’, ‘sqrt’等 |
# 測試集預測
y_pred = clf.predict(X_test)
# 評估指標
print("準確率:", accuracy_score(y_test, y_pred))
print("\n分類報告:\n", classification_report(y_test, y_pred, target_names=class_names))
文本表示:
tree_rules = export_text(clf, feature_names=feature_names)
print("決策樹規則:\n", tree_rules)
圖形化展示:
plt.figure(figsize=(12,8))
plot_tree(clf,
feature_names=feature_names,
class_names=class_names,
filled=True,
rounded=True)
plt.show()
使用GridSearchCV進行網格搜索:
from sklearn.model_selection import GridSearchCV
param_grid = {
'max_depth': [3, 5, 7, None],
'min_samples_split': [2, 5, 10],
'criterion': ['gini', 'entropy']
}
grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42),
param_grid,
cv=5,
scoring='accuracy')
grid_search.fit(X_train, y_train)
print("最優參數:", grid_search.best_params_)
print("最優分數:", grid_search.best_score_)
importances = clf.feature_importances_
indices = np.argsort(importances)[::-1]
plt.title('Feature Importance')
plt.bar(range(X.shape[1]), importances[indices], align='center')
plt.xticks(range(X.shape[1]), [feature_names[i] for i in indices])
plt.show()
# 數據加載與預處理
titanic = pd.read_csv('titanic.csv')
titanic = titanic[['Survived', 'Pclass', 'Sex', 'Age', 'Fare']]
titanic['Sex'] = titanic['Sex'].map({'male':0, 'female':1})
titanic = titanic.dropna()
# 特征工程與建模
X = titanic.drop('Survived', axis=1)
y = titanic['Survived']
clf = DecisionTreeClassifier(max_depth=4)
clf.fit(X, y)
# 可視化決策路徑
plt.figure(figsize=(15,10))
plot_tree(clf, feature_names=X.columns, class_names=['Died','Survived'], filled=True)
plt.show()
# 使用class_weight參數
clf = DecisionTreeClassifier(class_weight='balanced')
本文完整演示了Python中使用決策樹進行分類預測的流程: 1. 數據準備與探索 2. 模型構建與訓練 3. 可視化與解釋 4. 評估與優化
決策樹作為基礎算法,雖然簡單但功能強大,是理解更復雜集成方法的重要基礎。實際應用中需要根據數據特點調整參數,并結合業務場景進行解釋。
注:本文代碼基于Python 3.8和scikit-learn 1.0.2版本實現,不同版本可能需要適當調整。 “`
本文共約1750字,涵蓋決策樹分類的完整實現流程,采用Markdown格式編寫,包含代碼塊、數學公式、表格等元素,可直接用于技術文檔或博客發布。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。