溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

Python基于決策樹算法的分類預測怎么實現

發布時間:2022-01-17 16:20:11 來源:億速云 閱讀:166 作者:iii 欄目:開發技術
# Python基于決策樹算法的分類預測實現

決策樹是機器學習中經典的分類與回歸方法,因其直觀易懂、可解釋性強而廣受歡迎。本文將詳細介紹如何使用Python的scikit-learn庫實現基于決策樹的分類預測,涵蓋數據準備、模型構建、評估優化等全流程。

## 一、決策樹算法基礎

### 1.1 算法原理
決策樹通過遞歸地將數據集劃分為更純凈的子集來構建樹形結構,核心概念包括:
- **節點**:包含屬性測試條件的分支點
- **葉節點**:最終的分類結果
- **信息增益/基尼系數**:劃分標準的衡量指標

常用算法:
- ID3(使用信息增益)
- C4.5(使用信息增益率)
- CART(使用基尼系數)

### 1.2 數學基礎
**信息熵**:
$$ H(D) = -\sum_{k=1}^{K}p_k\log_2p_k $$

**基尼系數**:
$$ Gini(D) = 1-\sum_{k=1}^{K}p_k^2 $$

## 二、環境準備

```python
# 基礎庫安裝
pip install numpy pandas scikit-learn matplotlib

# 導入必要庫
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt

三、數據準備與預處理

3.1 數據加載

以經典的鳶尾花數據集為例:

from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
class_names = iris.target_names

# 轉換為DataFrame方便查看
df = pd.DataFrame(X, columns=feature_names)
df['target'] = y

3.2 數據探索

print(df.describe())
print("\n類別分布:\n", df['target'].value_counts())

3.3 數據劃分

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y)

四、模型構建與訓練

4.1 基礎模型

# 創建決策樹分類器
clf = DecisionTreeClassifier(
    criterion='gini',       # 分裂標準
    max_depth=3,           # 最大深度
    min_samples_split=2,    # 分裂所需最小樣本數
    random_state=42
)

# 模型訓練
clf.fit(X_train, y_train)

4.2 關鍵參數說明

參數 說明 典型值
criterion 分裂標準 ‘gini’或’entropy’
max_depth 樹的最大深度 整數或None
min_samples_split 節點分裂最小樣本數 2-10
min_samples_leaf 葉節點最小樣本數 1-5
max_features 考慮的最大特征數 ‘auto’, ‘sqrt’等

五、模型評估與可視化

5.1 預測與評估

# 測試集預測
y_pred = clf.predict(X_test)

# 評估指標
print("準確率:", accuracy_score(y_test, y_pred))
print("\n分類報告:\n", classification_report(y_test, y_pred, target_names=class_names))

5.2 決策樹可視化

文本表示:

tree_rules = export_text(clf, feature_names=feature_names)
print("決策樹規則:\n", tree_rules)

圖形化展示:

plt.figure(figsize=(12,8))
plot_tree(clf, 
          feature_names=feature_names, 
          class_names=class_names,
          filled=True, 
          rounded=True)
plt.show()

六、模型優化策略

6.1 超參數調優

使用GridSearchCV進行網格搜索:

from sklearn.model_selection import GridSearchCV

param_grid = {
    'max_depth': [3, 5, 7, None],
    'min_samples_split': [2, 5, 10],
    'criterion': ['gini', 'entropy']
}

grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42),
                          param_grid, 
                          cv=5,
                          scoring='accuracy')
grid_search.fit(X_train, y_train)

print("最優參數:", grid_search.best_params_)
print("最優分數:", grid_search.best_score_)

6.2 特征重要性分析

importances = clf.feature_importances_
indices = np.argsort(importances)[::-1]

plt.title('Feature Importance')
plt.bar(range(X.shape[1]), importances[indices], align='center')
plt.xticks(range(X.shape[1]), [feature_names[i] for i in indices])
plt.show()

七、實際應用案例

7.1 泰坦尼克生存預測

# 數據加載與預處理
titanic = pd.read_csv('titanic.csv')
titanic = titanic[['Survived', 'Pclass', 'Sex', 'Age', 'Fare']]
titanic['Sex'] = titanic['Sex'].map({'male':0, 'female':1})
titanic = titanic.dropna()

# 特征工程與建模
X = titanic.drop('Survived', axis=1)
y = titanic['Survived']
clf = DecisionTreeClassifier(max_depth=4)
clf.fit(X, y)

# 可視化決策路徑
plt.figure(figsize=(15,10))
plot_tree(clf, feature_names=X.columns, class_names=['Died','Survived'], filled=True)
plt.show()

八、決策樹的優缺點

8.1 優勢

  • 直觀易懂,可視化效果好
  • 無需特征縮放
  • 能處理數值和類別特征
  • 可解釋性強

8.2 局限性

  • 容易過擬合
  • 對數據變化敏感
  • 可能產生偏向于多值屬性的樹

九、擴展與進階

9.1 集成方法

  • 隨機森林
  • 梯度提升樹(GBDT)
  • XGBoost/LightGBM

9.2 類別不平衡處理

# 使用class_weight參數
clf = DecisionTreeClassifier(class_weight='balanced')

十、總結

本文完整演示了Python中使用決策樹進行分類預測的流程: 1. 數據準備與探索 2. 模型構建與訓練 3. 可視化與解釋 4. 評估與優化

決策樹作為基礎算法,雖然簡單但功能強大,是理解更復雜集成方法的重要基礎。實際應用中需要根據數據特點調整參數,并結合業務場景進行解釋。

注:本文代碼基于Python 3.8和scikit-learn 1.0.2版本實現,不同版本可能需要適當調整。 “`

本文共約1750字,涵蓋決策樹分類的完整實現流程,采用Markdown格式編寫,包含代碼塊、數學公式、表格等元素,可直接用于技術文檔或博客發布。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

亚洲午夜精品一区二区_中文无码日韩欧免_久久香蕉精品视频_欧美主播一区二区三区美女