你好,游客 登录
背景:
阅读新闻

决策分类树算法之ID3,C4.5算法系列

[日期:2017-10-29] 来源:csdn  作者:Android路上的人 [字体: ]

一、引言

在最开始的时候,我本来准备学习的是C4.5算法,后来发现C4.5算法的核心还是ID3算法,所以又辗转回到学习ID3算法了,因为C4.5是他的一个改进。至于是什么改进,在后面的描述中我会提到。

二、ID3算法

ID3算法是一种分类决策树算法。他通过一系列的规则,将数据最后分类成决策树的形式。分类的根据是用到了熵这个概念。熵在物理这门学科中就已经出现过,表示是一个物质的稳定度,在这里就是分类的纯度的一个概念。公式为:


在ID3算法中,是采用Gain信息增益来作为一个分类的判定标准的。他的定义为:


每次选择属性中信息增益最大作为划分属性,在这里本人实现了一个java版本的ID3算法,为了模拟数据的可操作性,就把数据写到一个input.txt文件中,作为数据源,格式如下:

 

  1. Day OutLook Temperature Humidity Wind PlayTennis 
  2. 1 Sunny Hot High Weak No 
  3. 2 Sunny Hot High Strong No 
  4. 3 Overcast Hot High Weak Yes 
  5. 4 Rainy Mild High Weak Yes 
  6. 5 Rainy Cool Normal Weak Yes 
  7. 6 Rainy Cool Normal Strong No 
  8. 7 Overcast Cool Normal Strong Yes 
  9. 8 Sunny Mild High Weak No 
  10. 9 Sunny Cool Normal Weak Yes 
  11. 10 Rainy Mild Normal Weak Yes 
  12. 11 Sunny Mild Normal Strong Yes 
  13. 12 Overcast Mild High Strong Yes 
  14. 13 Overcast Hot Normal Weak Yes 
  15. 14 Rainy Mild High Strong No 

PalyTennis 属性为结构属性,是作为类标识用的,中间的OutLool,Temperature,Humidity,Wind才是划分属性,通过将源数据与执行程序分 类,这样可以模拟巨大的数据量了。下面是ID3的主程序类,本人将ID3的算法进行了包装,对外只开放了一个构建决策树的方法,在构造函数时候,只需传入 一个数据路径文件即可:

 

  1. package DataMing_ID3; 
  2.  
  3. import java.io.BufferedReader; 
  4. import java.io.File; 
  5. import java.io.FileReader; 
  6. import java.io.IOException; 
  7. import java.util.ArrayList; 
  8. import java.util.HashMap; 
  9. import java.util.Iterator; 
  10. import java.util.Map; 
  11. import java.util.Map.Entry; 
  12. import java.util.Set; 
  13.  
  14. /** 
  15.  * ID3算法实现类 
  16.  *  
  17.  * @author lyq 
  18.  *  
  19.  */ 
  20. public class ID3Tool { 
  21.     // 类标号的值类型 
  22.     private final String YES = "Yes"
  23.     private final String NO = "No"
  24.  
  25.     // 所有属性的类型总数,在这里就是data源数据的列数 
  26.     private int attrNum; 
  27.     private String filePath; 
  28.     // 初始源数据,用一个二维字符数组存放模仿表格数据 
  29.     private String[][] data; 
  30.     // 数据的属性行的名字 
  31.     private String[] attrNames; 
  32.     // 每个属性的值所有类型 
  33.     private HashMap<String, ArrayList<String>> attrValue; 
  34.  
  35.     public ID3Tool(String filePath) { 
  36.         this.filePath = filePath; 
  37.         attrValue = new HashMap<>(); 
  38.     } 
  39.  
  40.     /** 
  41.      * 从文件中读取数据 
  42.      */ 
  43.     private void readDataFile() { 
  44.         File file = new File(filePath); 
  45.         ArrayList<String[]> dataArray = new ArrayList<String[]>(); 
  46.  
  47.         try { 
  48.             BufferedReader in = new BufferedReader(new FileReader(file)); 
  49.             String str; 
  50.             String[] tempArray; 
  51.             while ((str = in.readLine()) != null) { 
  52.                 tempArray = str.split(" "); 
  53.                 dataArray.add(tempArray); 
  54.             } 
  55.             in.close(); 
  56.         } catch (IOException e) { 
  57.             e.getStackTrace(); 
  58.         } 
  59.  
  60.         data = new String[dataArray.size()][]; 
  61.         dataArray.toArray(data); 
  62.         attrNum = data[0].length; 
  63.         attrNames = data[0]; 
  64.  
  65.         /* 
  66.          * for(int i=0; i<data.length;i++){ for(int j=0; j<data[0].length; j++){ 
  67.          * System.out.print(" " + data[i][j]); } 
  68.          *  
  69.          * System.out.print("\n"); } 
  70.          */ 
  71.     } 
  72.  
  73.     /** 
  74.      * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用 
  75.      */ 
  76.     private void initAttrValue() { 
  77.         ArrayList<String> tempValues; 
  78.  
  79.         // 按照列的方式,从左往右找 
  80.         for (int j = 1; j < attrNum; j++) { 
  81.             // 从一列中的上往下开始寻找值 
  82.             tempValues = new ArrayList<>(); 
  83.             for (int i = 1; i < data.length; i++) { 
  84.                 if (!tempValues.contains(data[i][j])) { 
  85.                     // 如果这个属性的值没有添加过,则添加 
  86.                     tempValues.add(data[i][j]); 
  87.                 } 
  88.             } 
  89.  
  90.             // 一列属性的值已经遍历完毕,复制到map属性表中 
  91.             attrValue.put(data[0][j], tempValues); 
  92.         } 
  93.  
  94.         /* 
  95.          * for(Map.Entry entry : attrValue.entrySet()){ 
  96.          * System.out.println("key:value " + entry.getKey() + ":" + 
  97.          * entry.getValue()); } 
  98.          */ 
  99.     } 
  100.  
  101.     /** 
  102.      * 计算数据按照不同方式划分的熵 
  103.      *  
  104.      * @param remainData 
  105.      *            剩余的数据 
  106.      * @param attrName 
  107.      *            待划分的属性,在算信息增益的时候会使用到 
  108.      * @param attrValue 
  109.      *            划分的子属性值 
  110.      * @param isParent 
  111.      *            是否分子属性划分还是原来不变的划分 
  112.      */ 
  113.     private double computeEntropy(String[][] remainData, String attrName, 
  114.             String value, boolean isParent) { 
  115.         // 实例总数 
  116.         int total = 0
  117.         // 正实例数 
  118.         int posNum = 0
  119.         // 负实例数 
  120.         int negNum = 0
  121.  
  122.         // 还是按列从左往右遍历属性 
  123.         for (int j = 1; j < attrNames.length; j++) { 
  124.             // 找到了指定的属性 
  125.             if (attrName.equals(attrNames[j])) { 
  126.                 for (int i = 1; i < remainData.length; i++) { 
  127.                     // 如果是父结点直接计算熵或者是通过子属性划分计算熵,这时要进行属性值的过滤 
  128.                     if (isParent 
  129.                             || (!isParent && remainData[i][j].equals(value))) { 
  130.                         if (remainData[i][attrNames.length - 1].equals(YES)) { 
  131.                             // 判断此行数据是否为正实例 
  132.                             posNum++; 
  133.                         } else { 
  134.                             negNum++; 
  135.                         } 
  136.                     } 
  137.                 } 
  138.             } 
  139.         } 
  140.  
  141.         total = posNum + negNum; 
  142.         double posProbobly = (double) posNum / total; 
  143.         double negProbobly = (double) negNum / total; 
  144.  
  145.         if (posProbobly == 1 || posProbobly == 0) { 
  146.             // 如果数据全为同种类型,则熵为0,否则带入下面的公式会报错 
  147.             return 0
  148.         } 
  149.  
  150.         double entropyValue = -posProbobly * Math.log(posProbobly) 
  151.                 / Math.log(2.0) - negProbobly * Math.log(negProbobly) 
  152.                 / Math.log(2.0); 
  153.  
  154.         // 返回计算所得熵 
  155.         return entropyValue; 
  156.     } 
  157.  
  158.     /** 
  159.      * 为某个属性计算信息增益 
  160.      *  
  161.      * @param remainData 
  162.      *            剩余的数据 
  163.      * @param value 
  164.      *            待划分的属性名称 
  165.      * @return 
  166.      */ 
  167.     private double computeGain(String[][] remainData, String value) { 
  168.         double gainValue = 0
  169.         // 源熵的大小将会与属性划分后进行比较 
  170.         double entropyOri = 0
  171.         // 子划分熵和 
  172.         double childEntropySum = 0
  173.         // 属性子类型的个数 
  174.         int childValueNum = 0
  175.         // 属性值的种数 
  176.         ArrayList<String> attrTypes = attrValue.get(value); 
  177.         // 子属性对应的权重比 
  178.         HashMap<String, Integer> ratioValues = new HashMap<>(); 
  179.  
  180.         for (int i = 0; i < attrTypes.size(); i++) { 
  181.             // 首先都统一计数为0 
  182.             ratioValues.put(attrTypes.get(i), 0); 
  183.         } 
  184.  
  185.         // 还是按照一列,从左往右遍历 
  186.         for (int j = 1; j < attrNames.length; j++) { 
  187.             // 判断是否到了划分的属性列 
  188.             if (value.equals(attrNames[j])) { 
  189.                 for (int i = 1; i <= remainData.length - 1; i++) { 
  190.                     childValueNum = ratioValues.get(remainData[i][j]); 
  191.                     // 增加个数并且重新存入 
  192.                     childValueNum++; 
  193.                     ratioValues.put(remainData[i][j], childValueNum); 
  194.                 } 
  195.             } 
  196.         } 
  197.  
  198.         // 计算原熵的大小 
  199.         entropyOri = computeEntropy(remainData, value, nulltrue); 
  200.         for (int i = 0; i < attrTypes.size(); i++) { 
  201.             double ratio = (double) ratioValues.get(attrTypes.get(i)) 
  202.                     / (remainData.length - 1); 
  203.             childEntropySum += ratio 
  204.                     * computeEntropy(remainData, value, attrTypes.get(i), false); 
  205.  
  206.             // System.out.println("ratio:value: " + ratio + " " + 
  207.             // computeEntropy(remainData, value, 
  208.             // attrTypes.get(i), false)); 
  209.         } 
  210.  
  211.         // 二者熵相减就是信息增益 
  212.         gainValue = entropyOri - childEntropySum; 
  213.         return gainValue; 
  214.     } 
  215.  
  216.     /** 
  217.      * 计算信息增益比 
  218.      *  
  219.      * @param remainData 
  220.      *            剩余数据 
  221.      * @param value 
  222.      *            待划分属性 
  223.      * @return 
  224.      */ 
  225.     private double computeGainRatio(String[][] remainData, String value) { 
  226.         double gain = 0
  227.         double spiltInfo = 0
  228.         int childValueNum = 0
  229.         // 属性值的种数 
  230.         ArrayList<String> attrTypes = attrValue.get(value); 
  231.         // 子属性对应的权重比 
  232.         HashMap<String, Integer> ratioValues = new HashMap<>(); 
  233.  
  234.         for (int i = 0; i < attrTypes.size(); i++) { 
  235.             // 首先都统一计数为0 
  236.             ratioValues.put(attrTypes.get(i), 0); 
  237.         } 
  238.  
  239.         // 还是按照一列,从左往右遍历 
  240.         for (int j = 1; j < attrNames.length; j++) { 
  241.             // 判断是否到了划分的属性列 
  242.             if (value.equals(attrNames[j])) { 
  243.                 for (int i = 1; i <= remainData.length - 1; i++) { 
  244.                     childValueNum = ratioValues.get(remainData[i][j]); 
  245.                     // 增加个数并且重新存入 
  246.                     childValueNum++; 
  247.                     ratioValues.put(remainData[i][j], childValueNum); 
  248.                 } 
  249.             } 
  250.         } 
  251.  
  252.         // 计算信息增益 
  253.         gain = computeGain(remainData, value); 
  254.         // 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀): 
  255.         for (int i = 0; i < attrTypes.size(); i++) { 
  256.             double ratio = (double) ratioValues.get(attrTypes.get(i)) 
  257.                     / (remainData.length - 1); 
  258.             spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0); 
  259.         } 
  260.  
  261.         // 计算机信息增益率 
  262.         return gain / spiltInfo; 
  263.     } 
  264.  
  265.     /** 
  266.      * 利用源数据构造决策树 
  267.      */ 
  268.     private void buildDecisionTree(AttrNode node, String parentAttrValue, 
  269.             String[][] remainData, ArrayList<String> remainAttr, boolean isID3) { 
  270.         node.setParentAttrValue(parentAttrValue); 
  271.  
  272.         String attrName = ""
  273.         double gainValue = 0
  274.         double tempValue = 0
  275.  
  276.         // 如果只有1个属性则直接返回 
  277.         if (remainAttr.size() == 1) { 
  278.             System.out.println("attr null"); 
  279.             return
  280.         } 
  281.  
  282.         // 选择剩余属性中信息增益最大的作为下一个分类的属性 
  283.         for (int i = 0; i < remainAttr.size(); i++) { 
  284.             // 判断是否用ID3算法还是C4.5算法 
  285.             if (isID3) { 
  286.                 // ID3算法采用的是按照信息增益的值来比 
  287.                 tempValue = computeGain(remainData, remainAttr.get(i)); 
  288.             } else { 
  289.                 // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足 
  290.                 tempValue = computeGainRatio(remainData, remainAttr.get(i)); 
  291.             } 
  292.  
  293.             if (tempValue > gainValue) { 
  294.                 gainValue = tempValue; 
  295.                 attrName = remainAttr.get(i); 
  296.             } 
  297.         } 
  298.  
  299.         node.setAttrName(attrName); 
  300.         ArrayList<String> valueTypes = attrValue.get(attrName); 
  301.         remainAttr.remove(attrName); 
  302.  
  303.         AttrNode[] childNode = new AttrNode[valueTypes.size()]; 
  304.         String[][] rData; 
  305.         for (int i = 0; i < valueTypes.size(); i++) { 
  306.             // 移除非此值类型的数据 
  307.             rData = removeData(remainData, attrName, valueTypes.get(i)); 
  308.  
  309.             childNode[i] = new AttrNode(); 
  310.             boolean sameClass = true
  311.             ArrayList<String> indexArray = new ArrayList<>(); 
  312.             for (int k = 1; k < rData.length; k++) { 
  313.                 indexArray.add(rData[k][0]); 
  314.                 // 判断是否为同一类的 
  315.                 if (!rData[k][attrNames.length - 1
  316.                         .equals(rData[1][attrNames.length - 1])) { 
  317.                     // 只要有1个不相等,就不是同类型的 
  318.                     sameClass = false
  319.                     break
  320.                 } 
  321.             } 
  322.  
  323.             if (!sameClass) { 
  324.                 // 创建新的对象属性,对象的同个引用会出错 
  325.                 ArrayList<String> rAttr = new ArrayList<>(); 
  326.                 for (String str : remainAttr) { 
  327.                     rAttr.add(str); 
  328.                 } 
  329.  
  330.                 buildDecisionTree(childNode[i], valueTypes.get(i), rData, 
  331.                         rAttr, isID3); 
  332.             } else { 
  333.                 // 如果是同种类型,则直接为数据节点 
  334.                 childNode[i].setParentAttrValue(valueTypes.get(i)); 
  335.                 childNode[i].setChildDataIndex(indexArray); 
  336.             } 
  337.  
  338.         } 
  339.         node.setChildAttrNode(childNode); 
  340.     } 
  341.  
  342.     /** 
  343.      * 属性划分完毕,进行数据的移除 
  344.      *  
  345.      * @param srcData 
  346.      *            源数据 
  347.      * @param attrName 
  348.      *            划分的属性名称 
  349.      * @param valueType 
  350.      *            属性的值类型 
  351.      */ 
  352.     private String[][] removeData(String[][] srcData, String attrName, 
  353.             String valueType) { 
  354.         String[][] desDataArray; 
  355.         ArrayList<String[]> desData = new ArrayList<>(); 
  356.         // 待删除数据 
  357.         ArrayList<String[]> selectData = new ArrayList<>(); 
  358.         selectData.add(attrNames); 
  359.  
  360.         // 数组数据转化到列表中,方便移除 
  361.         for (int i = 0; i < srcData.length; i++) { 
  362.             desData.add(srcData[i]); 
  363.         } 
  364.  
  365.         // 还是从左往右一列列的查找 
  366.         for (int j = 1; j < attrNames.length; j++) { 
  367.             if (attrNames[j].equals(attrName)) { 
  368.                 for (int i = 1; i < desData.size(); i++) { 
  369.                     if (desData.get(i)[j].equals(valueType)) { 
  370.                         // 如果匹配这个数据,则移除其他的数据 
  371.                         selectData.add(desData.get(i)); 
  372.                     } 
  373.                 } 
  374.             } 
  375.         } 
  376.  
  377.         desDataArray = new String[selectData.size()][]; 
  378.         selectData.toArray(desDataArray); 
  379.  
  380.         return desDataArray; 
  381.     } 
  382.  
  383.     /** 
  384.      * 开始构建决策树 
  385.      *  
  386.      * @param isID3 
  387.      *            是否采用ID3算法构架决策树 
  388.      */ 
  389.     public void startBuildingTree(boolean isID3) { 
  390.         readDataFile(); 
  391.         initAttrValue(); 
  392.  
  393.         ArrayList<String> remainAttr = new ArrayList<>(); 
  394.         // 添加属性,除了最后一个类标号属性 
  395.         for (int i = 1; i < attrNames.length - 1; i++) { 
  396.             remainAttr.add(attrNames[i]); 
  397.         } 
  398.  
  399.         AttrNode rootNode = new AttrNode(); 
  400.         buildDecisionTree(rootNode, "", data, remainAttr, isID3); 
  401.         showDecisionTree(rootNode, 1); 
  402.     } 
  403.  
  404.     /** 
  405.      * 显示决策树 
  406.      *  
  407.      * @param node 
  408.      *            待显示的节点 
  409.      * @param blankNum 
  410.      *            行空格符,用于显示树型结构 
  411.      */ 
  412.     private void showDecisionTree(AttrNode node, int blankNum) { 
  413.         System.out.println(); 
  414.         for (int i = 0; i < blankNum; i++) { 
  415.             System.out.print("\t"); 
  416.         } 
  417.         System.out.print("--"); 
  418.         // 显示分类的属性值 
  419.         if (node.getParentAttrValue() != null 
  420.                 && node.getParentAttrValue().length() > 0) { 
  421.             System.out.print(node.getParentAttrValue()); 
  422.         } else { 
  423.             System.out.print("--"); 
  424.         } 
  425.         System.out.print("--"); 
  426.  
  427.         if (node.getChildDataIndex() != null 
  428.                 && node.getChildDataIndex().size() > 0) { 
  429.             String i = node.getChildDataIndex().get(0); 
  430.             System.out.print("类别:" 
  431.                     + data[Integer.parseInt(i)][attrNames.length - 1]); 
  432.             System.out.print("["); 
  433.             for (String index : node.getChildDataIndex()) { 
  434.                 System.out.print(index + ", "); 
  435.             } 
  436.             System.out.print("]"); 
  437.         } else { 
  438.             // 递归显示子节点 
  439.             System.out.print("【" + node.getAttrName() + "】"); 
  440.             for (AttrNode childNode : node.getChildAttrNode()) { 
  441.                 showDecisionTree(childNode, 2 * blankNum); 
  442.             } 
  443.         } 
  444.  
  445.     } 
  446.  

他的场景调用实现的方式为:

 

  1. /** 
  2.  * ID3决策树分类算法测试场景类 
  3.  * @author lyq 
  4.  * 
  5.  */ 
  6. public class Client { 
  7.     public static void main(String[] args){ 
  8.         String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"
  9.          
  10.         ID3Tool tool = new ID3Tool(filePath); 
  11.         tool.startBuildingTree(true); 
  12.     } 

最终的结果为:

 

  1. ------【OutLook】 
  2.     --Sunny--【Humidity】 
  3.             --High--类别:No[128, ] 
  4.             --Normal--类别:Yes[911, ] 
  5.     --Overcast--类别:Yes[371213, ] 
  6.     --Rainy--【Wind】 
  7.             --Weak--类别:Yes[4510, ] 
  8.             --Strong--类别:No[614, ] 


请从左往右观察这棵决策树,【】里面的是分类属性,---XXX----,XXX为属性的值,在叶子节点处为类标记。

对应的分类结果图:


这里的构造决策树和显示决策树采用的DFS的方法,所以可能会比较难懂,希望读者能细细体会,可以调试一下代码,一步步的跟踪会更加容易理解的。

三、C4.5算法

如果你已经理解了上面ID3算法的实现,那么理解C4.5也很容易了,C4.5与 ID3在核心的算法是一样的,但是有一点所采用的办法是不同的,C4.5采用了信息增益率作为划分的根据,克服了ID3算法中采用信息增益划分导致属性选 择偏向取值多的属性。信息增益率的公式为:


分母的位置是分裂因子,他的计算公式为:


和熵的计算公式比较像,具体的信息增益率的算法也在上面的代码中了,请关注着2个方法:

 

  1. // 选择剩余属性中信息增益最大的作为下一个分类的属性 
  2. for (int i = 0; i < remainAttr.size(); i++) { 
  3.     // 判断是否用ID3算法还是C4.5算法 
  4.     if (isID3) { 
  5.         // ID3算法采用的是按照信息增益的值来比 
  6.         tempValue = computeGain(remainData, remainAttr.get(i)); 
  7.     } else { 
  8.         // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足 
  9.         tempValue = computeGainRatio(remainData, remainAttr.get(i)); 
  10.     } 
  11.  
  12.     if (tempValue > gainValue) { 
  13.         gainValue = tempValue; 
  14.         attrName = remainAttr.get(i); 
  15.     } 

在补充一下C4.5在其他方面对ID3的补充和改进:

1、在构造决策树的过程中能对树进行剪枝。

2、能对连续性的值进行离散化的操作。

四、编码时遇到的一些问题

为了实现ID3算法,从理解阅读他的原理就已经用掉了比较多的时间,然后再尝试阅读别 人写的C++版本的代码,又是看了几天,好不容易实现了2个算法,最后在构造树的过程中遇到了最大了麻烦,因为用到了递归构造树,对于其中节点的设计就显 得至关重要了,也许我自己目前的设计也不是最优秀的。下面盘点一下我的程序的遇到的一些问题和存在的潜在的问题:

1、在构建决策树的时候,出现了remainAttr值缺少的情况,就是递归的时候 remainAttr的属性划分移除掉之后,对于上次的递归操作的属性时受到影响了,后来发现是因为我remainAttr采用的是ArrayList, 他是一个引用对象,通过引用传入的方式,对象用的还是同一个,所以果断重新建了一个ArrayList对象,问题就OK了。

 

  1. // 创建新的对象属性,对象的同个引用会出错 
  2. ArrayList<String> rAttr = new ArrayList<>(); 
  3. for (String str : remainAttr) { 
  4.     rAttr.add(str); 
  5.  
  6. buildDecisionTree(childNode[i], valueTypes.get(i), rData, 
  7.         rAttr, isID3); 

2、第二个问题是当程序划分到最后一个属性时,如果出现了数据的类标识并不是同一个类的时候,我的处理操作时直接不处理,直接返回,会造成节点没有数据属性,也没有数据索引。

 

  1. private void buildDecisionTree(AttrNode node, String parentAttrValue, 
  2.         String[][] remainData, ArrayList<String> remainAttr, boolean isID3) { 
  3.     node.setParentAttrValue(parentAttrValue); 
  4.  
  5.     String attrName = ""
  6.     double gainValue = 0
  7.     double tempValue = 0
  8.  
  9.     // 如果只有1个属性则直接返回 
  10.     if (remainAttr.size() == 1) { 
  11.         System.out.println("attr null"); 
  12.         return
  13.     } 
  14.     ..... 

在这种情况下的处理不是很恰当个人觉得是这样。

收藏 推荐 打印 | 阅读:
相关新闻       C4.5  ID3  大数据算法实现