Introduction

  • Here is the method that builds the classifier:

protected void buildInternal(MultiLabelInstances aTrain)
throws Exception {
 super.buildInternal(aTrain);
 
 if (cvkSelection == true) {
 crossValidate();
 }
 }

aTrain is the training dataset.

The buildInternal method belongs to the classifier class; it first delegates to the superclass implementation (super.buildInternal).

The crossValidate() method is shown below:

protected void crossValidate() throws Exception {
        try {
            // the performance for each different k
                       // 不同的k可以造成不同的表现
            double[] hammingLoss = new double[cvMaxK];
                        // 记录每个模型的Hamming Loss
            for (int i = 0; i < cvMaxK; i++) {
                hammingLoss[i] = 0;
            }
 
            Instances dataSet = train;//nothing
            Instance instance; // the hold out instance
            Instances neighbours; // the neighboring instances
            double[] origDistances, convertedDistances;//表示orig距离
            for (int i = 0; i < dataSet.numInstances(); i++) {
                if (getDebug() && (i % 50 == 0)) {
                    debug("Cross validating " + i + "/"
                            + dataSet.numInstances() + "\r");
                }
                instance = dataSet.instance(i);
                neighbours = lnn.kNearestNeighbours(instance, cvMaxK);
                                       //instance's k  neighbours
                origDistances = lnn.getDistances();
 
                // gathering the true labels for the instance
                boolean[] trueLabels = new boolean[numLabels];
                for (int counter = 0; counter < numLabels; counter++) {
                    int classIdx = labelIndices[counter];
                    String classValue = instance.attribute(classIdx).value(
                            (int) instance.value(classIdx));
                    trueLabels[counter] = classValue.equals("1");
                }
                // calculate the performance metric for each different k
                for (int j = cvMaxK; j > 0; j--) {
                    convertedDistances = new double[origDistances.length];
                    System.arraycopy(origDistances, 0, convertedDistances, 0,
                            origDistances.length);
                    double[] confidences = this.getConfidences(neighbours,
                            convertedDistances);
                    boolean[] bipartition = null;
 
                    switch (extension) {
                    case NONE: // BRknn
                        /**那么选择默认的方法得到最终分类结果
                            bipartition = results.getBipartition();
                         */
                        MultiLabelOutput results;
                        results = new MultiLabelOutput(confidences, 0.5);
                        bipartition = results.getBipartition();
                        break;
                    case EXTA: // BRknn-a
                       /**那么选择如果最终预测结果不输出任何结果时输出最大置信度的类标的
                            方法得到最终分类结果
                            bipartition =labelsFromConfidences2(confidences);
                         */
                        bipartition = labelsFromConfidences2(confidences);
                        break;
                    case EXTB: // BRknn-b
                               /*选择输出固定类标数目的方法得到预测结果
                               */
                        bipartition = labelsFromConfidences3(confidences);
                        break;
                    }
 
                    double symmetricDifference = 0; // |Y xor Z|
                    for (int labelIndex = 0; labelIndex < numLabels;
                                                 labelIndex++) {
                         //统计每个样例的预测结果和实际结果的差别
                        boolean actual = trueLabels[labelIndex];
                        boolean predicted = bipartition[labelIndex];
 
                        if (predicted != actual) {
                            symmetricDifference++;
                        }
                    }
                    hammingLoss[j - 1] += (symmetricDifference / numLabels);
 
                    neighbours = new IBk().pruneToK(neighbours,
                            convertedDistances, j - 1);
                }
            }
 
            // Display the results of the cross-validation
            if (getDebug()) {
                for (int i = cvMaxK; i > 0; i--) {
                    debug("Hold-one-out performance of " + (i) + " neighbors ");
                    debug("(Hamming Loss) = " + hammingLoss[i - 1]
                            / dataSet.numInstances());
                }
            }
 
            // Check through the performance stats and select the best
            // k value (or the lowest k if more than one best)
            double[] searchStats = hammingLoss;
 
            double bestPerformance = Double.NaN;
            int bestK = 1;
            for (int i = 0; i < cvMaxK; i++) {
                if (Double.isNaN(bestPerformance)
                        || (bestPerformance > searchStats[i])) {
                    bestPerformance = searchStats[i];
                    bestK = i + 1;
                }
            }
            numOfNeighbors = bestK;
            if (getDebug()) {
                System.err.println("Selected k = " + bestK);
            }
 
        } catch (Exception ex) {
            throw new Error("Couldn't optimize by cross-validation: "
                    + ex.getMessage());
        }
    }

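A note on the code above: the switch statement in the middle selects one of three different methods (BRknn, BRknn-a, BRknn-b) for turning the confidences into the final prediction.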