#!/usr/bin/python
# -*- coding: UTF-8 -*-
import operator
from collections import Counter
from math import log

import decisionTreePlot as dtPlot
def calcShannonEnt(dataset):
    """Compute the Shannon entropy of the class labels in *dataset*.

    Args:
        dataset: A sequence of rows. The last element of every row is
            treated as that row's class label; labels must be hashable.

    Returns:
        The entropy in bits, ``-sum(p * log2(p))`` over the label
        frequencies ``p``. An empty dataset yields ``0.0``.
    """
    numEntries = len(dataset)
    # Tally how often each label (last item of a row) occurs.
    labelCounts = Counter(featVec[-1] for featVec in dataset)

    shannonEnt = 0.0
    for count in labelCounts.values():
        # Relative frequency of the label, used as its probability.
        prob = float(count) / numEntries
        # Shannon's formula: subtract p * log2(p) for every label.
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
def majorityCnt(classList):
    """Return the label that occurs most often in *classList*.

    Args:
        classList: A sequence of hashable class labels.

    Returns:
        The most frequent label. When several labels share the highest
        count, the one that appears first in *classList* is returned.

    Raises:
        IndexError: If *classList* is empty.
    """
    # most_common(1) gives [(label, count)] for the top label, or [] when
    # there is nothing to count (hence the IndexError on empty input).
    return Counter(classList).most_common(1)[0][0]
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column whose split gives the largest information gain.

    Args:
        dataSet: A non-empty sequence of rows; every column except the
            last is treated as a feature and the last one as the label.

    Returns:
        The index of the best feature column, or ``-1`` when no column
        yields an information gain greater than zero.
    """
    # Feature count is taken from the first row; the last column is the label.
    featureCount = len(dataSet[0]) - 1
    # Entropy of the labels before any split.
    baseEntropy = calcShannonEnt(dataSet)
    winner = -1
    winnerGain = 0.0

    for col in range(featureCount):
        # Distinct values taken by this feature across all rows.
        distinctValues = {row[col] for row in dataSet}

        # Weighted sum of the entropies of the partitions produced by
        # splitting on this column.
        # NOTE(review): splitDataSet is defined outside this block; it is
        # presumably the subset of rows whose column `col` equals `val` -
        # confirm against its definition.
        splitEntropy = 0.0
        for val in distinctValues:
            part = splitDataSet(dataSet, col, val)
            weight = len(part) / float(len(dataSet))
            splitEntropy += weight * calcShannonEnt(part)

        # Information gain = reduction in entropy achieved by the split.
        gain = baseEntropy - splitEntropy
        print('infoGain=', gain, 'bestFeature=', col, baseEntropy, splitEntropy)
        if gain > winnerGain:
            winnerGain, winner = gain, col

    return winner
def createTree(dataset, lables):
    """Recursively build a decision tree represented as nested dicts.

    Args:
        dataset: A non-empty sequence of rows; the last element of each
            row is the class label, the preceding elements are features.
        lables: Feature names, one per feature column of *dataset*.
            The list is not modified.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.

    Raises:
        IndexError: If *dataset* is empty.
    """
    classList = [example[-1] for example in dataset]

    # Stop condition 1: every row carries the same label.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop condition 2: all features are used up (only the label column
    # is left) but the labels still disagree - fall back to majority vote.
    if len(dataset[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataset)
    # -1 means no feature produced a positive information gain. Indexing
    # with -1 would select the label column, so vote instead.
    if bestFeat < 0:
        return majorityCnt(classList)

    bestFeatLabel = lables[bestFeat]
    myTree = {bestFeatLabel: {}}

    # Build the remaining feature names without touching the caller's
    # list; deleting from `lables` in place would leak out of the call.
    subLabels = lables[:bestFeat] + lables[bestFeat + 1:]

    # One branch per distinct value of the chosen feature.
    uniqueVals = set(example[bestFeat] for example in dataset)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataset, bestFeat, value), subLabels)
    return myTree
def ContactLensesTest():
    """Build a decision tree from the contact-lens data, print it and plot it.

    Reads the tab-separated file ``lenses.txt`` from a path relative to the
    current working directory, builds the tree with ``createTree``, prints
    the nested-dict result and hands it to ``dtPlot.createPlot``.
    """
    # NOTE(review): the same data set appears to be published at
    # https://raw.githubusercontent.com/pbharrin/machinelearninginaction/master/Ch03/lenses.txt
    # Use a context manager so the file handle is closed even on error.
    with open('../../../../data/3.DecisionTree/lenses.txt') as fr:
        # Each line is one sample; columns are separated by tabs.
        lenses = [inst.strip().split('\t') for inst in fr]

    # Names of the feature columns: age, prescription, astigmatism and
    # tear production rate.
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

    # Build the decision tree with the code above.
    lensesTree = createTree(lenses, lensesLabels)
    print(lensesTree)
    # Visualise the tree.
    dtPlot.createPlot(lensesTree)
# Run the contact-lens demo only when this file is executed as a script.
if __name__ == "__main__":
    ContactLensesTest()