/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Summarizable;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/**
 * Meta-classifier that chooses values for one or more options of a base
 * classifier by cross-validation.
 *
 * <p>Each option to optimise is described by a {@link CVParameter} string of the
 * form {@code "char lower upper steps [R]"}. Every combination of values on the
 * resulting grid is evaluated with {@link #m_NumFolds}-fold cross-validation on
 * the training data. The option set with the lowest {@link Evaluation#errorRate()}
 * is kept, and the base classifier is finally rebuilt on all training data with
 * those options.
 */
public class CVParameterSelection
extends RandomizableSingleClassifierEnhancer
implements Drawable,
Summarizable,
TechnicalInformationHandler {
    static final long serialVersionUID = -6529603380876641265L;

    /** Upper-bound code produced by an 'A' upper bound (number of attributes). */
    private static final int UPPER_BOUND_NUM_ATTRIBUTES = 1;
    /** Upper-bound code produced by an 'I' upper bound (training fold size). */
    private static final int UPPER_BOUND_NUM_INSTANCES = 2;
    /** Sentinel stored in {@link #m_BestPerformance} before any evaluation. */
    private static final double NO_PERFORMANCE_YET = -99.0;

    /** Base classifier options with the optimised flags taken out by Utils.getOption. */
    protected String[] m_ClassifierOptions;
    /** Option set that achieved the lowest cross-validated error. */
    protected String[] m_BestClassifierOptions;
    /** Base classifier options as they were when buildClassifier() started. */
    protected String[] m_InitOptions;
    /** Lowest cross-validated error seen so far; -99 means "nothing evaluated yet". */
    protected double m_BestPerformance;
    /** Parameters to optimise, stored as {@link CVParameter} instances. */
    protected FastVector m_CVParams = new FastVector();
    /** Number of attributes in the training data (substituted for an 'A' upper bound). */
    protected int m_NumAttributes;
    /** Number of instances in a training fold (substituted for an 'I' upper bound). */
    protected int m_TrainFoldSize;
    /** Number of cross-validation folds. */
    protected int m_NumFolds = 10;

    /**
     * Builds the option array for the base classifier. It combines the current
     * value of every CV parameter with the static options in
     * {@link #m_ClassifierOptions}.
     *
     * @return "-flag value" pairs for the CV parameters, followed by the static
     *         options; pairs flagged "add at end" are written at the tail instead
     */
    protected String[] createOptions() {
        int numCVParams = m_CVParams.size();
        String[] options = new String[m_ClassifierOptions.length + 2 * numCVParams];
        int head = 0;
        int tail = options.length;
        for (int i = 0; i < numCVParams; ++i) {
            CVParameter cvParam = (CVParameter) m_CVParams.elementAt(i);
            String flag = "-" + cvParam.m_ParamChar;
            String value = formatParamValue(cvParam);
            if (cvParam.m_AddAtEnd) {
                // Filled backwards from the end, so write the value first to keep "-flag value" order.
                options[--tail] = value;
                options[--tail] = flag;
            } else {
                options[head++] = flag;
                options[head++] = value;
            }
        }
        System.arraycopy(m_ClassifierOptions, 0, options, head, m_ClassifierOptions.length);
        return options;
    }

    /**
     * Renders the current value of a CV parameter as an option string.
     * Rounded or integral values are formatted with Utils.doubleToString to 4
     * decimals; other values use Java's default double formatting of the
     * unrounded value.
     *
     * @param cvParam the parameter whose m_ParamValue is rendered
     * @return the textual option value
     */
    private String formatParamValue(CVParameter cvParam) {
        double value = cvParam.m_ParamValue;
        if (cvParam.m_RoundParam) {
            value = Math.rint(value);
        }
        boolean isInt = value - (double) ((int) value) == 0.0;
        if (cvParam.m_RoundParam || isInt) {
            return Utils.doubleToString(value, 4);
        }
        return String.valueOf(cvParam.m_ParamValue);
    }

    /**
     * Returns the effective upper bound of a parameter. The 'A' and 'I' markers
     * are replaced by the attribute count and training fold size.
     *
     * @param cvParam the parameter to resolve
     * @return the numeric upper bound to search up to
     */
    private double resolveUpperBound(CVParameter cvParam) {
        switch (cvParam.upperBoundCode()) {
            case UPPER_BOUND_NUM_ATTRIBUTES:
                return m_NumAttributes;
            case UPPER_BOUND_NUM_INSTANCES:
                return m_TrainFoldSize;
            default:
                return cvParam.m_Upper;
        }
    }

    /**
     * Recursively enumerates the value grid of the CV parameters. Each complete
     * combination is evaluated by cross-validation, and the best one is
     * recorded in {@link #m_BestPerformance} / {@link #m_BestClassifierOptions}.
     *
     * @param depth     index of the parameter whose value is varied at this level
     * @param trainData the (already randomised) training data
     * @param random    passed down the recursion; the fold split itself uses a
     *                  fixed {@code new Random(1L)}
     * @throws Exception if setting options, training or evaluation fails
     */
    protected void findParamsByCrossValidation(int depth, Instances trainData, Random random) throws Exception {
        if (depth >= m_CVParams.size()) {
            evaluateCurrentOptions(trainData);
            return;
        }
        CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth);
        double upper = resolveUpperBound(cvParam);
        double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1.0);
        cvParam.m_ParamValue = cvParam.m_Lower;
        if (!(increment > 0.0)) {
            // Zero, negative or NaN step: the accumulating loop below would never
            // get past `upper` (e.g. upper == lower, or steps < 1), so evaluate the
            // lower bound a single time instead of looping forever.
            if (cvParam.m_ParamValue <= upper) {
                findParamsByCrossValidation(depth + 1, trainData, random);
            }
            return;
        }
        while (cvParam.m_ParamValue <= upper) {
            findParamsByCrossValidation(depth + 1, trainData, random);
            cvParam.m_ParamValue += increment;
        }
    }

    /**
     * Cross-validates the base classifier with the options built from the
     * current parameter values. Keeps them if they beat the best error so far.
     *
     * @param trainData the training data to cross-validate on
     * @throws Exception if setting options, training or evaluation fails
     */
    private void evaluateCurrentOptions(Instances trainData) throws Exception {
        Evaluation evaluation = new Evaluation(trainData);
        String[] options = createOptions();
        if (m_Debug) {
            System.err.print("Setting options for " + m_Classifier.getClass().getName() + ":");
            for (int i = 0; i < options.length; ++i) {
                System.err.print(" " + options[i]);
            }
            System.err.println("");
        }
        m_Classifier.setOptions(options);
        for (int fold = 0; fold < m_NumFolds; ++fold) {
            // A fresh Random(1) per fold, so every candidate sees identically shuffled folds.
            Instances train = trainData.trainCV(m_NumFolds, fold, new Random(1L));
            Instances test = trainData.testCV(m_NumFolds, fold);
            m_Classifier.buildClassifier(train);
            evaluation.setPriors(train);
            evaluation.evaluateModel(m_Classifier, test, new Object[0]);
        }
        double error = evaluation.errorRate();
        if (m_Debug) {
            System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4));
        }
        if (m_BestPerformance == NO_PERFORMANCE_YET || error < m_BestPerformance) {
            m_BestPerformance = error;
            // Rebuilt rather than reusing `options`, which setOptions() may have modified.
            m_BestClassifierOptions = createOptions();
        }
    }

    /**
     * Returns a description of this classifier for the GUI.
     *
     * @return the description, including the technical reference
     */
    public String globalInfo() {
        return "Class for performing parameter selection by cross-validation for any classifier.\n\nFor more information, see:\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for this method.
     *
     * @return the technical information (Kohavi, 1995, PhD thesis)
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.PHDTHESIS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "R. Kohavi");
        result.setValue(TechnicalInformation.Field.YEAR, "1995");
        result.setValue(TechnicalInformation.Field.TITLE, "Wrappers for Performance Enhancement and Oblivious Decision Graphs");
        result.setValue(TechnicalInformation.Field.SCHOOL, "Stanford University");
        result.setValue(TechnicalInformation.Field.ADDRESS, "Department of Computer Science, Stanford University");
        return result;
    }

    /**
     * Lists the -X and -P options followed by the options of the superclass.
     *
     * @return an enumeration of {@link Option}s
     */
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(2);
        newVector.addElement(new Option("\tNumber of folds used for cross validation (default 10).", "X", 1, "-X <number of folds>"));
        newVector.addElement(new Option("\tClassifier parameter options.\n\teg: \"N 1 5 10\" Sets an optimisation parameter for the\n\tclassifier with name -N, with lower bound 1, upper bound\n\t5, and 10 optimisation steps. The upper bound may be the\n\tcharacter 'A' or 'I' to substitute the number of\n\tattributes or instances in the training data,\n\trespectively. This parameter may be supplied more than\n\tonce to optimise over several classifier options\n\tsimultaneously.", "P", 1, "-P <classifier parameter>"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement((Option) enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses -X (folds, default 10) and any number of -P parameter
     * specifications, then hands the remaining options to the superclass.
     *
     * @param options the option array
     * @throws Exception if an option is invalid
     */
    public void setOptions(String[] options) throws Exception {
        String foldsString = Utils.getOption('X', options);
        setNumFolds(foldsString.length() != 0 ? Integer.parseInt(foldsString) : 10);
        m_CVParams = new FastVector();
        String cvParam = Utils.getOption('P', options);
        while (cvParam.length() != 0) {
            addCVParameter(cvParam);
            cvParam = Utils.getOption('P', options);
        }
        super.setOptions(options);
    }

    /**
     * Returns the current options. After a build, the base classifier's
     * pre-build options are reported, not the CV-selected ones.
     *
     * @return -P specs, -X folds, then the superclass options
     * @throws RuntimeException if the base classifier's options cannot be swapped
     */
    public String[] getOptions() {
        String[] superOptions;
        if (m_InitOptions != null) {
            try {
                // Temporarily restore the initial options so super.getOptions() reports them.
                m_Classifier.setOptions((String[]) m_InitOptions.clone());
                superOptions = super.getOptions();
                if (m_BestClassifierOptions != null) {
                    m_Classifier.setOptions((String[]) m_BestClassifierOptions.clone());
                }
            } catch (Exception e) {
                throw new RuntimeException("CVParameterSelection: could not set options in getOptions().", e);
            }
        } else {
            superOptions = super.getOptions();
        }
        String[] options = new String[superOptions.length + m_CVParams.size() * 2 + 2];
        int current = 0;
        for (int i = 0; i < m_CVParams.size(); ++i) {
            options[current++] = "-P";
            options[current++] = "" + getCVParameter(i);
        }
        options[current++] = "-X";
        options[current++] = "" + getNumFolds();
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        return options;
    }

    /**
     * Returns a copy of the option set selected by the last build.
     *
     * @return the best options (throws NullPointerException if none were recorded yet)
     */
    public String[] getBestClassifierOptions() {
        return (String[]) m_BestClassifierOptions.clone();
    }

    /**
     * Returns the superclass capabilities, requiring at least as many
     * instances as there are folds.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.setMinimumNumberInstances(m_NumFolds);
        return result;
    }

    /**
     * Selects the best parameter values by cross-validation and then builds the
     * base classifier on all training data with them.
     *
     * @param instances the training data
     * @throws IllegalArgumentException if the base classifier is not an OptionHandler
     * @throws Exception if the data is unsuitable, no parameter combination could
     *                   be evaluated, or training fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        Instances trainData = new Instances(instances);
        trainData.deleteWithMissingClass();
        if (!(m_Classifier instanceof OptionHandler)) {
            throw new IllegalArgumentException("Base classifier should be OptionHandler.");
        }
        m_InitOptions = m_Classifier.getOptions();
        m_BestPerformance = NO_PERFORMANCE_YET;
        m_NumAttributes = trainData.numAttributes();
        Random random = new Random(m_Seed);
        trainData.randomize(random);
        m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();
        if (m_CVParams.size() == 0) {
            // Nothing to optimise: just train with the user's options.
            m_Classifier.buildClassifier(trainData);
            m_BestClassifierOptions = m_InitOptions;
            return;
        }
        if (trainData.classAttribute().isNominal()) {
            trainData.stratify(m_NumFolds);
        }
        m_BestClassifierOptions = null;
        m_ClassifierOptions = m_Classifier.getOptions();
        // Take the optimised flags out of the static options; createOptions() re-adds them.
        for (int i = 0; i < m_CVParams.size(); ++i) {
            Utils.getOption(((CVParameter) m_CVParams.elementAt(i)).m_ParamChar, m_ClassifierOptions);
        }
        findParamsByCrossValidation(0, trainData, random);
        if (m_BestClassifierOptions == null) {
            // Happens when a resolved upper bound lies below its lower bound (e.g. 'A' with few attributes).
            throw new Exception("CVParameterSelection: no parameter combination could be evaluated; "
                    + "check that no lower bound exceeds its (resolved) upper bound.");
        }
        m_Classifier.setOptions((String[]) m_BestClassifierOptions.clone());
        m_Classifier.buildClassifier(trainData);
    }

    /**
     * Delegates prediction to the base classifier built with the selected options.
     *
     * @param instance the instance to classify
     * @return the class distribution predicted by the base classifier
     * @throws Exception if the base classifier fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        return m_Classifier.distributionForInstance(instance);
    }

    /**
     * Parses a parameter specification and appends it to the search.
     *
     * @param cvParam specification such as {@code "N 1 5 10"}
     * @throws Exception if the specification is malformed
     */
    public void addCVParameter(String cvParam) throws Exception {
        m_CVParams.addElement(new CVParameter(cvParam));
    }

    /**
     * Returns the specification string of a parameter.
     *
     * @param index position of the parameter
     * @return its specification, or "" if index is past the end
     */
    public String getCVParameter(int index) {
        if (index >= m_CVParams.size()) {
            return "";
        }
        return ((CVParameter) m_CVParams.elementAt(index)).toString();
    }

    /**
     * Returns the tip text for the CVParameters property.
     *
     * @return tip text for the GUI
     */
    public String CVParametersTipText() {
        return "Sets the scheme parameters which are to be set by cross-validation.\nThe format for each string should be:\nparam_char lower_bound upper_bound number_of_steps\neg to search a parameter -P from 1 to 10 by increments of 1:\n    \"P 1 10 10\" ";
    }

    /**
     * Returns the specification strings of all parameters.
     *
     * @return a String[] (typed as Object[]) of specifications
     */
    public Object[] getCVParameters() {
        Object[] cvParams = m_CVParams.toArray();
        Object[] params = new String[cvParams.length];
        for (int i = 0; i < cvParams.length; ++i) {
            params[i] = cvParams[i].toString();
        }
        return params;
    }

    /**
     * Replaces all parameters. The previous list is restored if any
     * specification is invalid.
     *
     * @param params specification strings
     * @throws Exception if a specification cannot be parsed
     */
    public void setCVParameters(Object[] params) throws Exception {
        FastVector backup = m_CVParams;
        m_CVParams = new FastVector();
        for (int i = 0; i < params.length; ++i) {
            try {
                addCVParameter((String) params[i]);
            } catch (Exception ex) {
                m_CVParams = backup;
                throw ex;
            }
        }
    }

    /**
     * Returns the tip text for the numFolds property.
     *
     * @return tip text for the GUI
     */
    public String numFoldsTipText() {
        return "Get the number of folds used for cross-validation.";
    }

    /**
     * Returns the number of cross-validation folds.
     *
     * @return the number of folds
     */
    public int getNumFolds() {
        return m_NumFolds;
    }

    /**
     * Sets the number of cross-validation folds.
     *
     * @param numFolds the number of folds; must be positive
     * @throws IllegalArgumentException if numFolds is zero or negative
     */
    public void setNumFolds(int numFolds) throws Exception {
        if (numFolds < 1) {
            throw new IllegalArgumentException("CVParameterSelection: Number of cross-validation folds must be positive.");
        }
        m_NumFolds = numFolds;
    }

    /**
     * Returns the graph type of the base classifier, if it is Drawable.
     *
     * @return the base classifier's graph type, or 0 if it is not Drawable
     */
    public int graphType() {
        if (m_Classifier instanceof Drawable) {
            return ((Drawable) m_Classifier).graphType();
        }
        return 0;
    }

    /**
     * Returns the graph of the base classifier.
     *
     * @return the graph description
     * @throws Exception if the base classifier is not Drawable
     */
    public String graph() throws Exception {
        if (m_Classifier instanceof Drawable) {
            return ((Drawable) m_Classifier).graph();
        }
        throw new Exception("Classifier: " + m_Classifier.getClass().getName() + " " + Utils.joinOptions(m_BestClassifierOptions) + " cannot be graphed");
    }

    /**
     * Describes the search ranges, the selected options and the final model.
     *
     * @return a textual description, or a "no model" message before a successful build
     */
    public String toString() {
        if (m_InitOptions == null || m_BestClassifierOptions == null) {
            return "CVParameterSelection: No model built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("Cross-validated Parameter selection.\nClassifier: ")
            .append(m_Classifier.getClass().getName()).append("\n");
        try {
            for (int i = 0; i < m_CVParams.size(); ++i) {
                CVParameter cvParam = (CVParameter) m_CVParams.elementAt(i);
                text.append("Cross-validation Parameter: '-").append(cvParam.m_ParamChar).append("'")
                    .append(" ranged from ").append(cvParam.m_Lower).append(" to ");
                switch (cvParam.upperBoundCode()) {
                    case UPPER_BOUND_NUM_ATTRIBUTES:
                        text.append(m_NumAttributes);
                        break;
                    case UPPER_BOUND_NUM_INSTANCES:
                        text.append(m_TrainFoldSize);
                        break;
                    default:
                        text.append(cvParam.m_Upper);
                }
                text.append(" with ").append(cvParam.m_Steps).append(" steps\n");
            }
        } catch (Exception ex) {
            text.append(ex.getMessage());
        }
        text.append("Classifier Options: ").append(Utils.joinOptions(m_BestClassifierOptions))
            .append("\n\n").append(m_Classifier.toString());
        return text.toString();
    }

    /**
     * Returns a one-line summary of the selected options.
     * Requires a previous build (m_BestClassifierOptions must be set).
     *
     * @return the selected option values followed by a newline
     */
    public String toSummaryString() {
        String result = "Selected values: " + Utils.joinOptions(m_BestClassifierOptions);
        return result + '\n';
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8180 $");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the options
     */
    public static void main(String[] argv) {
        CVParameterSelection.runClassifier(new CVParameterSelection(), argv);
    }

    /**
     * One option of the base classifier to optimise, parsed from
     * {@code "char lower upper steps [R]"}. The upper bound may be 'A' or 'I'
     * instead of a number. These are encoded as lower-1 and lower-2
     * respectively and resolved at build time.
     */
    protected class CVParameter
    implements Serializable,
    RevisionHandler {
        static final long serialVersionUID = -4668812017709421953L;
        /** Option flag of the base classifier, without the leading '-'. */
        private String m_ParamChar;
        /** Lower bound of the search range. */
        private double m_Lower;
        /** Upper bound; lower-1 encodes 'A' and lower-2 encodes 'I'. */
        private double m_Upper;
        /** Number of values to try between the bounds. */
        private double m_Steps;
        /** Value currently being evaluated. */
        private double m_ParamValue;
        /** Whether to place this option at the end of the option array (not set by this parser). */
        private boolean m_AddAtEnd;
        /** Whether values are rounded to integers ('R' suffix). */
        private boolean m_RoundParam;

        /**
         * Parses a specification such as {@code "N 1 5 10"} or {@code "N 1 A 10 R"}.
         * Components may be separated by any run of whitespace.
         *
         * @param param the specification
         * @throws Exception if the specification is malformed
         */
        public CVParameter(String param) throws Exception {
            String[] parts = param.trim().split("\\s+");
            if (parts.length < 4 || parts.length > 5) {
                throw new Exception("CVParameter " + param + ": four or five components expected!");
            }
            if (parsesAsDouble(parts[0])) {
                throw new Exception("CVParameter " + param + ": Character parameter identifier expected");
            }
            m_ParamChar = parts[0];
            try {
                m_Lower = Double.parseDouble(parts[1]);
            } catch (NumberFormatException e) {
                throw new Exception("CVParameter " + param + ": Numeric lower bound expected");
            }
            if (parts[2].equals("A")) {
                m_Upper = m_Lower - 1.0;
            } else if (parts[2].equals("I")) {
                m_Upper = m_Lower - 2.0;
            } else {
                try {
                    m_Upper = Double.parseDouble(parts[2]);
                } catch (NumberFormatException e) {
                    throw new Exception("CVParameter " + param + ": Upper bound must be numeric, or 'A' or 'I'");
                }
                if (m_Upper < m_Lower) {
                    throw new Exception("CVParameter " + param + ": Upper bound is less than lower bound");
                }
            }
            try {
                m_Steps = Double.parseDouble(parts[3]);
            } catch (NumberFormatException e) {
                throw new Exception("CVParameter " + param + ": Numeric number of steps expected");
            }
            if (parts.length == 5 && parts[4].equals("R")) {
                m_RoundParam = true;
            }
        }

        /** Returns true if s can be parsed by Double.parseDouble. */
        private boolean parsesAsDouble(String s) {
            try {
                Double.parseDouble(s);
                return true;
            } catch (NumberFormatException e) {
                return false;
            }
        }

        /**
         * Decodes the upper-bound marker.
         *
         * @return 1 for 'A', 2 for 'I', anything else for a numeric bound
         */
        private int upperBoundCode() {
            return (int) (m_Lower - m_Upper + 0.5);
        }

        /**
         * Renders the specification back into its textual form.
         *
         * @return e.g. {@code "N 1.0 A 10.0 R"}
         */
        public String toString() {
            StringBuilder text = new StringBuilder();
            text.append(m_ParamChar).append(" ").append(m_Lower).append(" ");
            switch (upperBoundCode()) {
                case 1:
                    text.append("A");
                    break;
                case 2:
                    text.append("I");
                    break;
                default:
                    text.append(m_Upper);
            }
            text.append(" ").append(m_Steps);
            if (m_RoundParam) {
                text.append(" R");
            }
            return text.toString();
        }

        /**
         * Returns the revision string.
         *
         * @return the revision
         */
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 8180 $");
        }
    }
}

