当前位置: 首页 > AI > 文章内容页

数据挖掘:决策树

时间:2025-07-20    作者:游乐小编    

本文介绍决策树算法,包括其通过规则分类数据、分分类树和回归树的原理,以及计算复杂度低等优缺点。还以UCI的Adult数据集为例,展示手动实现决策树分类算法和使用sklearn库实现的过程,包括数据加载、处理、模型构建、可视化及测试,两者在该数据集上分类准确度相同。

数据挖掘:决策树 - 游乐网

1. 算法原理

决策树是通过一系列规则对数据进行分类的过程。它提供一种在什么条件下会得到什么值的类似规则的方法。决策树分为分类树和回归树两种,分类树对离散变量做决策树,回归树对连续变量做决策树。近来的调查表明决策树也是最经常使用的数据挖掘算法,它的概念非常简单。决策树算法之所以如此流行,一个很重要的原因就是使用者基本上不用了解机器学习算法,也不用深究它是如何工作的。直观看上去,决策树分类器就像判断模块和终止块组成的流程图,终止块表示分类结果(也就是树的叶子)。判断模块表示对一个特征取值的判断(该特征有几个值,判断模块就有几个分支)。

如果不考虑效率等,那么样本所有特征的判断级联起来终会将某一个样本分到一个类终止块上。实际上,样本所有特征中有一些特征在分类时起到决定性作用,决策树的构造过程就是找到这些具有决定性作用的特征,根据其决定性程度来构造一个倒立的树–-决定性作用最大的那个特征作为根节点,然后递归找到各分支下子数据集中次大的决定性特征,直至子数据集中所有数据都属于同一类。所以,构造决策树的过程本质上就是根据数据特征将数据集分类的递归过程,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据分类时起决定性作用。

2. 优缺点分析

决策树适用于数值型和标称型(离散型数据,变量的结果只在有限目标集中取值),能够读取数据集合,提取一些列数据中蕴含的规则。在分类问题中使用决策树模型有很多的优点,决策树计算复杂度不高、便于使用、而且高效,决策树可处理具有不相关特征的数据、可很容易地构造出易于理解的规则,而规则通常易于解释和理解。

决策树模型也有一些缺点,比如处理缺失数据时的困难、过度拟合以及忽略数据集中属性之间的相关性等。

3. Adult数据集

数据集使用 UCI 数据集中的 Adult 数据集。数据集下载链接:https://archive.ics.uci.edu/ml/datasets/Adult ,我们提供了已经下载好的数据集:【Data Mining】Adult。源数据集共有 14 个属性,分类有两种 (50K)。这 14 个属性中,有 6 个是连续类型的,有 8 个是离散类型的。

4. 代码展示

4.1 导入依赖

由于AI Studio平台没有提供sklearn_pandas库,因此我们手动安装一下,只要运行下方代码块即可:

In [1]
# 升级pip!pip install --upgrade pip# 安装sklearn_pandas!pip install sklearn_pandas
登录后复制
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simpleRequirement already satisfied: pip in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (23.0)Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simpleRequirement already satisfied: sklearn_pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2.2.0)Requirement already satisfied: scipy>=1.5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from sklearn_pandas) (1.7.3)Requirement already satisfied: numpy>=1.18.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from sklearn_pandas) (1.21.6)Requirement already satisfied: pandas>=1.1.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from sklearn_pandas) (1.1.5)Requirement already satisfied: scikit-learn>=0.23.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from sklearn_pandas) (1.0.2)Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas>=1.1.4->sklearn_pandas) (2.8.2)Requirement already satisfied: pytz>=2017.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas>=1.1.4->sklearn_pandas) (2019.3)Requirement already satisfied: joblib>=0.11 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.23.0->sklearn_pandas) (0.14.1)Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.23.0->sklearn_pandas) (3.1.0)Requirement already satisfied: six>=1.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas>=1.1.4->sklearn_pandas) (1.16.0)
登录后复制In [2]
import pandas as pd import numpy as np from sklearn_pandas import DataFrameMapper #sklearn-pandas模块提供了Scikit-Learn的机器学习方法和pandas风格的数据框架之间的桥梁。from sklearn.preprocessing import LabelEncoderfrom sklearn.tree import DecisionTreeClassifier #sklearn提供的决策树分类器from sklearn.tree import export_graphviz # 决策树可视化import graphviz # 用于绘制DOT语言脚本描述的图形from matplotlib import pyplot as pltfrom pylab import *from collections import defaultdict,Counter from tqdm import tqdm # 进度条
登录后复制

4.2 加载数据集

In [3]
# 14个属性+类别columns=['age','workclass','fnlwgt','education','education_num','marital_status','occupation','relationship',                 'race','sex','capital_gain','capital_loss','hours_per_week','native_country','annual_salary']# 加载训练集adult_train_path = 'data/data87314/adult.data'adult_train = pd.read_csv(adult_train_path,header=None,names=columns)# 加载测试集adult_test_path = 'data/data87314/adult.test'adult_test = pd.read_csv(adult_test_path,header=None,names=columns)
登录后复制

数据展示:

In [4]
adult_train.head()
登录后复制登录后复制
   age          workclass  fnlwgt   education  education_num  \0   39          State-gov   77516   Bachelors             13   1   50   Self-emp-not-inc   83311   Bachelors             13   2   38            Private  215646     HS-grad              9   3   53            Private  234721        11th              7   4   28            Private  338409   Bachelors             13           marital_status          occupation    relationship    race      sex  \0        Never-married        Adm-clerical   Not-in-family   White     Male   1   Married-civ-spouse     Exec-managerial         Husband   White     Male   2             Divorced   Handlers-cleaners   Not-in-family   White     Male   3   Married-civ-spouse   Handlers-cleaners         Husband   Black     Male   4   Married-civ-spouse      Prof-specialty            Wife   Black   Female      capital_gain  capital_loss  hours_per_week  native_country annual_salary  0          2174             0              40   United-States         <=50K  1             0             0              13   United-States         <=50K  2             0             0              40   United-States         <=50K  3             0             0              40   United-States         <=50K  4             0             0              40            Cuba         <=50K
登录后复制

4.3 数据向量化

使用LabelEncoder、DataFrameMapper将非数值列的数据转化为数值,也即向量化。

处理训练集:

In [5]
# 获取非数值列的列名train_dtype=adult_train.dtypes#print(train_dtype)train_list=[train_dtype.index[i] for i in range(len(train_dtype)) if train_dtype[i]=='object']# 使用LabelEncoder、DataFrameMapper将非数值列的数据转化为数值,也即向量化。 # 列的顺序会发生变化mapper=DataFrameMapper([(i, LabelEncoder()) for i in train_list], df_out=True, default=None)adult_train = mapper.fit_transform(adult_train.copy()).astype(dtype='int64')
登录后复制

处理测试集:

In [6]
test_dtype=adult_test.dtypestest_list=[test_dtype.index[i] for i in range(len(test_dtype)) if test_dtype[i]=='object']mapper=DataFrameMapper([(i, LabelEncoder()) for i in test_list], df_out=True, default=None)adult_test = mapper.fit_transform(adult_test.copy()).astype(dtype='int64')
登录后复制

向量化后的数据展示:

In [7]
adult_train.head()
登录后复制登录后复制
   workclass  education  marital_status  occupation  relationship  race  sex  \0          7          9               4           1             1     4    1   1          6          9               2           4             0     4    1   2          4         11               0           6             1     4    1   3          4          1               2           6             0     2    1   4          4          9               2          10             5     2    0      native_country  annual_salary  age  fnlwgt  education_num  capital_gain  \0              39              0   39   77516             13          2174   1              39              0   50   83311             13             0   2              39              0   38  215646              9             0   3              39              0   53  234721              7             0   4               5              0   28  338409             13             0      capital_loss  hours_per_week  0             0              40  1             0              13  2             0              40  3             0              40  4             0              40
登录后复制

4.4 属性与标签划分

fnlwgt属性,由于数据过于分散,在生成决策树的过程中会很耗时,在数据处理过程中删除该属性。最后得到的训练集大小为 (32561, 13) (32561,);测试集大小为(16281, 13) (16281,)。In [8]
col=list(adult_train.columns)label='annual_salary'col.remove(label)col.remove('fnlwgt')x_train,y_train=adult_train[col].values,adult_train[label].valuesx_test,y_test=adult_test[col].values,adult_test[label].valuesprint("训练集shape: ",x_train.shape,y_train.shape,"\n测试集shape: ",x_test.shape,y_test.shape)
登录后复制
训练集shape:  (32561, 13) (32561,) 测试集shape:  (16281, 13) (16281,)
登录后复制

4.5 手动实现决策树分类算法并可视化

4.5.1 定义决策树类:

In [9]
class Branch:    no=0 # 决策树节点的编号    column=0 # 该节点的属性    entropy=0 # 交叉熵    samples=0 # 该节点下的数据数目    value=[] # 记录由该节点划分的不同类别的数据    split=0 # 分类临界值    clss=-1 # 该节点的分类        branch_positive=None # 左分支    branch_negative=None # 右分支
登录后复制

4.5.2 构造决策树分类器

主要进行的操作有:

定义计算熵的函数定义根据指定属性进行分类的函数定义在指定数据范围内查找最佳分类属性的函数定义递归的构造决策树分类器的函数

定义计算熵的函数:

In [10]
def entroy(y):    counter=Counter(y)    res=0.0    for num in counter.values():        p=num/len(y) #每个类别的占比        res+=-p*np.log2(p)    return res
登录后复制

定义数据划分的函数:

In [11]
def split(x,y,d,value):    # x 数据集属性    # y 数据集标签    # d 划分的维度    # value 划分的参考值    left=(x[:,d]<=value)        right=(x[:,d]>value)    return x[left],x[right],y[left],y[right]
登录后复制

定义选取最好分类特征的函数,在当前的数据下(x,y)选取最合适的分类特征,并返回分类后的左右分支数据:

In [12]
def find_best_fearture(x,y):    best_entroy=entroy(y) # 熵初始化    best_v=None # 分类临界值    best_d='' # 分类属性    x_r=None    x_l=None    y_r=None    y_l=None    # 逐个属性进行比较    for d in range(x.shape[1]):        # 每个属性中,寻找最好的切分点。        # 因为有的属性本身是数值类型的,需要进行更细致的查找,确定最好的切分点。        sorted_index=np.argsort(x[:,d])# 根据d维度进行排序        for i in range(1,len(x)):#遍历每个样本            if x[sorted_index[i-1],d]!=x[sorted_index[i],d]:                v=(x[sorted_index[i-1],d]+x[sorted_index[i],d])/2.0                # 调用split函数进行划分                xl,xr,yl,yr=split(x,y,d,v)                n1=len(yl)                n2=len(yr)                n=n1+n2                                # 计算基尼系数                e=n1/n*entroy(yl)+n2/n*entroy(yr)                                if e登录后复制

定义决策树的构造函数,通过该函数递归生成一颗决策树:

In [13]
number=0def decison_tree_in(x,y,depth,max_depth=3):    global number    branch=Branch()    branch.no=number    number+=1    ddepth=depth # 记录分支的深度        branch.samples=len(y) # 记录该结点所包含数据的数量    n_positive=y[y==1].shape[0]    branch.value=[branch.samples-n_positive,n_positive] # 该结点下,0与1类别数目列表    if branch.value[0]>branch.value[1]:        branch.clss=0    else :        branch.clss=1    branch.entropy=entroy(y) # 计算该节点下的信息熵    best_feature=find_best_fearture(x,y)    branch.column=best_feature[0]    branch.split=best_feature[1]        if ddepth==max_depth or branch.column=='':        return branch    else:        x_l,y_l=best_feature[3],best_feature[5]        branch.branch_positive=decison_tree_in(x_l,y_l,ddepth+1,max_depth)        x_r,y_r=best_feature[4],best_feature[6]        branch.branch_negative=decison_tree_in(x_r,y_r,ddepth+1,max_depth)            return branch
登录后复制In [14]
tree=decison_tree_in(x_train,y_train,0,max_depth=4)
登录后复制

4.5.3 可视化构造好的决策树分类器模型

使用graphviz(使用DOT语言脚本绘制图形)可视化决策树。

In [15]
def get_dot_data_innner(branch:Branch, dot_data):       if branch.branch_positive:        dot_data=dot_data+'{} [label=<{}≤{}
entropy = {:.3f}
samples = {}
value = {}
class = {}> , fillcolor="#FFFFFFFF"] ;\r\n'.format( branch.no, col[branch.column],branch.split, branch.entropy, branch.samples, branch.value,branch.clss) else: dot_data=dot_data+'{} [label=<{}
entropy = {:.3f}
samples = {}
value = {}
class = {}> , fillcolor="#FFFFFFFF"] ;\r\n'.format( branch.no,branch.column, branch.entropy, branch.samples, branch.value,branch.clss) if branch.branch_positive: dot_data=dot_data+'{} -> {} [labeldistance=2.5, labelangle=45, headlabel="True"]; \r\n'.format(branch.no, branch.branch_positive.no) dot_data=get_dot_data_innner(branch.branch_positive, dot_data) if branch.branch_negative: dot_data=dot_data+'{} -> {} [labeldistance=2.5, labelangle=45, headlabel="False"]; \r\n'.format(branch.no, branch.branch_negative.no) dot_data=get_dot_data_innner(branch.branch_negative, dot_data) return dot_data
登录后复制In [16]
def get_dot_data(branch:Branch):    dot_data="""digraph Tree {node [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;edge [fontname=helvetica] ;"""    dot_data=get_dot_data_innner(branch,  dot_data)    dot_data=dot_data+'\r\n}'    return dot_data
登录后复制In [17]
dot_data=get_dot_data(tree)
登录后复制In [18]
graph = graphviz.Source(dot_data) graph.render('./data/my_dt', format='png')graph
登录后复制relationship≤0.5entropy = 0.796samples = 32561value = [24720, 7841]class = 0education_num≤12.5entropy = 0.992samples = 13193value = [7275, 5918]class = 0Truecapital_gain≤7073.5entropy = 0.467samples = 19368value = [17445, 1923]class = 0Falsecapital_gain≤5095.5entropy = 0.915samples = 9224value = [6178, 3046]class = 0Truecapital_gain≤5095.5entropy = 0.850samples = 3969value = [1097, 2872]class = 1Falseeducation_num≤8.5entropy = 0.877samples = 8766value = [6170, 2596]class = 0Trueage≤61.5entropy = 0.127samples = 458value = [8, 450]class = 1False11 entropy = 0.480samples = 1459value = [1308, 151]class = 0Trueentropy = 0.920samples = 7307value = [4862, 2445]class = 0Falseentropy = 0.000samples = 410value = [0, 410]class = 1True10 entropy = 0.650samples = 48value = [8, 40]class = 1Falsecapital_loss≤1782.5entropy = 0.911samples = 3356value = [1094, 2262]class = 1Trueage≤62.5entropy = 0.045samples = 613value = [3, 610]class = 1False12 entropy = 0.944samples = 2999value = [1083, 1916]class = 1True11 entropy = 0.198samples = 357value = [11, 346]class = 1Falseentropy = 0.000samples = 542value = [0, 542]class = 1Trueentropy = 0.253samples = 71value = [3, 68]class = 1Falserelationship≤4.5entropy = 0.400samples = 18932value = [17431, 1501]class = 0Trueeducation_num≤10.5entropy = 0.205samples = 436value = [14, 422]class = 1Falseeducation_num≤12.5entropy = 0.286samples = 17482value = [16610, 872]class = 0Trueeducation_num≤10.5entropy = 0.987samples = 1450value = [821, 629]class = 0Falseentropy = 0.172samples = 14036value = [13677, 359]class = 0Trueentropy = 0.607samples = 3446value = [2933, 513]class = 0Falseentropy = 0.895samples = 902value = [621, 281]class = 0True11 entropy = 0.947samples = 548value = [200, 348]class = 1Falseage≤20.5entropy = 0.454samples = 147value = [14, 133]class = 1Trueentropy = 0.000samples = 289value = [0, 289]class = 1Falseentropy = 0.722samples = 5value = [4, 1]class = 0True10 entropy = 0.367samples = 142value = [10, 132]class = 1False
登录后复制

4.5.4 用测试集进行验证,计算模型分类准确性得分

In [20]
def cl(branch:Branch, x):        # 纯的数据集,不需要继续划分    if branch.split==None:        return branch.clss        # 继续划分,直至最大深度    if x[branch.column]<=branch.split:        if branch.branch_positive is not None:            return cl(branch.branch_positive,x)        else:            return branch.clss            if x[branch.column]>branch.split:        if branch.branch_negative is not None:            return cl(branch.branch_negative,x)        else:            return branch.clss
登录后复制In [21]
def compute_score(branch:Branch,x,y):    re=[]    for i in range(len(x)):        re.append(cl(branch,x[i]))            if len(re)!=len(y):        print("预测结果与实际结果数量不同,请检查程序。")        exit(0)            a= re==y    score=a[a==1].shape[0]/a.shape[0]        return score
登录后复制In [22]
score=compute_score(tree,x_test,y_test)print('自己搭建的决策树分类准确度得分:', score)
登录后复制
自己搭建的决策树分类准确度得分: 0.8443584546403784
登录后复制

4.6 使用sklearn提供的决策树分类器进行实验

4.6.1 实例化决策树分类器

In [23]
treeClassifier = DecisionTreeClassifier(max_depth=4,criterion='entropy')
登录后复制

4.6.2 决策树分类器训练

In [24]
treeClassifier.fit(x_train, y_train)
登录后复制
DecisionTreeClassifier(criterion='entropy', max_depth=4)
登录后复制

4.6.3 可视化训练好的决策树分类器模型

In [25]
export_graphviz(treeClassifier, out_file="dt_clf.pdf",feature_names=col)with open('dt_clf.pdf','r') as f:    dot_graph = f.read()graphviz.Source(dot_graph)
登录后复制relationship entropy = 0.796samples = 32561value = [24720, 7841]education_num entropy = 0.992samples = 13193value = [7275, 5918]Truecapital_gain entropy = 0.467samples = 19368value = [17445, 1923]Falsecapital_gain entropy = 0.915samples = 9224value = [6178, 3046]capital_gain entropy = 0.85samples = 3969value = [1097, 2872]education_num entropy = 0.877samples = 8766value = [6170, 2596]age entropy = 0.127samples = 458value = [8, 450]entropy = 0.48samples = 1459value = [1308, 151]entropy = 0.92samples = 7307value = [4862, 2445]entropy = 0.0samples = 410value = [0, 410]entropy = 0.65samples = 48value = [8, 40]capital_loss entropy = 0.911samples = 3356value = [1094, 2262]age entropy = 0.045samples = 613value = [3, 610]entropy = 0.944samples = 2999value = [1083, 1916]entropy = 0.198samples = 357value = [11, 346]entropy = 0.0samples = 542value = [0, 542]entropy = 0.253samples = 71value = [3, 68]relationship entropy = 0.4samples = 18932value = [17431, 1501]education_num entropy = 0.205samples = 436value = [14, 422]education_num entropy = 0.286samples = 17482value = [16610, 872]education_num entropy = 0.987samples = 1450value = [821, 629]entropy = 0.172samples = 14036value = [13677, 359]entropy = 0.607samples = 3446value = [2933, 513]entropy = 0.895samples = 902value = [621, 281]entropy = 0.947samples = 548value = [200, 348]age entropy = 0.454samples = 147value = [14, 133]entropy = 0.0samples = 289value = [0, 289]entropy = 0.722samples = 5value = [4, 1]entropy = 0.367samples = 142value = [10, 132]
登录后复制

4.6.4 用测试集进行验证,计算模型分类准确性得分

In [26]
score = treeClassifier.score(x_test, y_test)print('使用sklearn提供的决策树分类准确度得分:', score)
登录后复制
使用sklearn提供的决策树分类准确度得分: 0.8443584546403784
登录后复制

热门推荐

更多

热门文章

更多

首页  返回顶部

本站所有软件都由网友上传,如有侵犯您的版权,请发邮件youleyoucom@outlook.com