Example 1: Binary Classification

This example trains a 3-layer network using 48 training patterns from four nominal input attributes. The first two nominal attributes each have two classifications; the third and fourth have three and four classifications, respectively, and the 48 patterns enumerate every combination of the attribute values (2 × 2 × 3 × 4 = 48). All four attributes are encoded using binary encoding, which produces eleven binary network input columns. The output class is 1 if the first two nominal attributes sum to 2 (that is, both take their first classification), and 0 otherwise.
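
As a minimal sketch of the encoding step, using only the UnsupervisedNominalFilter calls that appear in the example below (the class name is for illustration, and the commented row values assume standard one-of-k binary encoding):

import com.imsl.datamining.neural.UnsupervisedNominalFilter;

public class EncodingSketch {
    public static void main(String[] args) {
        // A nominal attribute with 3 classifications, coded 1, 2, 3
        int[] x = {1, 2, 3, 1};
        UnsupervisedNominalFilter filter = new UnsupervisedNominalFilter(3);
        int[][] z = filter.encode(x); // one binary column per classification
        // Expected rows: {1,0,0}, {0,1,0}, {0,0,1}, {1,0,0}
        for (int[] row : z) {
            System.out.println(java.util.Arrays.toString(row));
        }
    }
}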

The structure of the network consists of eleven input nodes and three layers, with three perceptrons in the first hidden layer, two perceptrons in the second hidden layer, and one perceptron in the output layer.

There are a total of 47 weights in this network, including the six bias weights (one per perceptron). The linear activation function is used in both hidden layers. Since the target output is a binary classification, the logistic activation function is used in the output layer. Training is conducted using the quasi-Newton trainer with the cross-entropy error function provided by the BinaryClassification class.
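
As a quick sanity check on that weight count, here is a standalone sketch (the layer sizes are taken from the example below; the class name is just for illustration):

public class WeightCountSketch {
    public static void main(String[] args) {
        // Fully connected feed-forward layout: 11 inputs -> 3 -> 2 -> 1
        int nIn = 11, nH1 = 3, nH2 = 2, nOut = 1;
        int linkWeights = nIn * nH1 + nH1 * nH2 + nH2 * nOut; // 33 + 6 + 2 = 41
        int biasWeights = nH1 + nH2 + nOut;                   // one bias per perceptron = 6
        System.out.println(linkWeights + biasWeights);        // prints 47
    }
}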

import com.imsl.datamining.neural.*;
import java.io.*;
import java.util.logging.*;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.util.Random;

//*****************************************************************************
// Three-Layer Feed-Forward Network with 11 inputs: 4 nominal attributes with
// 2, 2, 3, and 4 categories, encoded using binary encoding, and 1 output
// target (class).
//*****************************************************************************

public class BinaryClassificationEx1 implements Serializable 
{

// Network Settings
    private static int nObs          = 48; // number of training patterns
    private static int nInputs       =  11; // four nominal with 2,2,3,4 categories
    private static int nCategorical  =  11; // binary encoded columns from 4 nominal attributes
    private static int nOutputs      =  1; // one binary output (nClasses=2)
    private static int nPerceptrons1 =  3; // perceptrons in 1st hidden layer
    private static int nPerceptrons2 =  2; // perceptrons in 2nd hidden layer
    private static boolean trace     = true; // Turns on/off training log
    
    private static Activation hiddenLayerActivation = Activation.LINEAR;
    private static Activation outputLayerActivation = Activation.LOGISTIC;

        /* 2 classifications */
    private static int[] x1 = {   
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
        1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};

                     /* 2 classifications */
    private static int[] x2 = {  
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};

                       /* 3 classifications */
    private static int[] x3 = { 
        1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1,
        2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2,
        3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3 };

                       /* 4 classifications */
    private static int[] x4 = {   
       1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
       1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
       1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };
   
 // **********************************************************************
 // MAIN
 // **********************************************************************
    public static void main(String[] args) throws Exception 
    {
       double xData[][]; // Input attributes for trainer
       int    yData[];   // Output class labels for trainer
       int i, j;         // array indices
       String trainLogName = "BinaryClassificationExample.log";
       
       
       // ******************************************************************
       // Binary encode 4 categorical variables.   
       //         Var x1 contains 2 classes
       //         Var x2 contains 2 classes
       //         Var x3 contains 3 classes
       //         Var x4 contains 4 classes
       // *******************************************************************
        int[][] z1;
        int[][] z2;
        int[][] z3;
        int[][] z4;
        UnsupervisedNominalFilter filter = new UnsupervisedNominalFilter(2);
        z1 = filter.encode(x1);
        z2 = filter.encode(x2);
        filter = new UnsupervisedNominalFilter(3);
        z3 = filter.encode(x3);
        filter = new UnsupervisedNominalFilter(4);
        z4 = filter.encode(x4);
        
        /* Concatenate binary encoded z's */
        xData = new double[nObs][nInputs];
        yData = new int[nObs];
        for (i=0; i<(nObs); i++) 
        {
            for (j=0; j <nCategorical; j++) {
                xData[i][j] = 0;
                if (j < 2) xData[i][j] = (double) z1[i][j];
                if (j > 1 && j < 4)  xData[i][j] = (double) z2[i][j-2];
                if (j > 3 && j < 7)  xData[i][j] = (double) z3[i][j-4];
                if (j > 6)  xData[i][j] = (double)z4[i][j-7];
            }
            yData[i] = ((x1[i] +x2[i] == 2) ? 1 : 0);
        }
        
    // **********************************************************************
    // CREATE FEEDFORWARD NETWORK
    // **********************************************************************
       long t0 = System.currentTimeMillis();

       FeedForwardNetwork network = new FeedForwardNetwork();
       network.getInputLayer().createInputs(nInputs);
       network.createHiddenLayer().createPerceptrons(nPerceptrons1);
       network.createHiddenLayer().createPerceptrons(nPerceptrons2);
       network.getOutputLayer().createPerceptrons(nOutputs);
       
       BinaryClassification classification = new BinaryClassification(network);
       
       network.linkAll();
       Random r = new Random(123457L);
       network.setRandomWeights(xData, r);
       Perceptron perceptrons[] = network.getPerceptrons();
       for (i=0; i < perceptrons.length-1; i++) {
          perceptrons[i].setActivation(hiddenLayerActivation);
       }
       perceptrons[perceptrons.length-1].setActivation(outputLayerActivation);
       

    // **********************************************************************
    // TRAIN NETWORK USING QUASI-NEWTON TRAINER
    // **********************************************************************
       QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
       trainer.setError(classification.getError());
       trainer.setMaximumTrainingIterations(1000);
       trainer.setMaximumStepsize(3.0);
       trainer.setGradientTolerance(1.0e-20);
       trainer.setFalseConvergenceTolerance(1.0e-20);
       trainer.setStepTolerance(1.0e-20);
       trainer.setRelativeTolerance(1.0e-20);
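       // NOTE: tolerances near machine zero make the convergence tests
       // extremely strict, so training typically runs to the full
       // iteration limit set above.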
       if (trace) {
            try {
                Handler handler = new FileHandler(trainLogName);
                Logger logger = Logger.getLogger("com.imsl.datamining.neural");
                logger.setLevel(Level.FINEST);
                logger.addHandler(handler);
                handler.setFormatter(QuasiNewtonTrainer.getFormatter());
                System.out.println("--> Training Log Created in "+
                                   trainLogName);
            } catch (Exception e) {
                System.out.println("--> Cannot Create Training Log.");
            }
        }
       classification.train(trainer, xData, yData);

    // **********************************************************************
    // DISPLAY TRAINING STATISTICS
    // **********************************************************************
       double stats[] = classification.computeStatistics(xData, yData);
       System.out.println("***********************************************");
       System.out.println("--> Cross-entropy error:        "+(float)stats[0]);
       System.out.println("--> Classification error rate:  "+(float)stats[1]);
       System.out.println("***********************************************");
       System.out.println("");
       // **********************************************************************
       // OBTAIN AND DISPLAY NETWORK WEIGHTS AND GRADIENTS
       // **********************************************************************
        double weight[]   = network.getWeights();
        double gradient[] = trainer.getErrorGradient();
        double wg[][] = new double[weight.length][2];
        for(i = 0;  i < weight.length;  i++) 
        {
            wg[i][0] = weight[i];
            wg[i][1] = gradient[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf,wg);

        //  ****************************
        //     forecast the network
        //  ****************************
        double report[][] = new double[nObs][6];  
        for ( i = 0;  i < nObs;  i++) 
        {
            report[i][0] = x1[i];
            report[i][1] = x2[i];
            report[i][2] = x3[i];
            report[i][3] = x4[i];
            report[i][4] = yData[i];
            report[i][5] = classification.predictedClass(xData[i]);
        }
        pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{
            "X1", "X2", "X3", "X4", 
            "Expected", "Predicted"});
        new PrintMatrix("Forecast").print(pmf, report);


    // **********************************************************************
    // DISPLAY CLASSIFICATION STATISTICS
    // **********************************************************************
       double statsClass[] = classification.computeStatistics(xData, yData);
       // Display Network Errors
       System.out.println("***********************************************");
       System.out.println("--> Cross-Entropy Error:      "+(float)statsClass[0]);
       System.out.println("--> Classification Error:     "+(float)statsClass[1]);
       System.out.println("***********************************************");
       System.out.println("");
        
        
        long t1 = System.currentTimeMillis();
        double time = (t1 - t0) / 1000.0; // elapsed time in seconds
        System.out.println("****************Time:  "+time);
        System.out.println("trainer.getErrorValue = "+trainer.getErrorValue());
    }
}

Output

--> Training Log Created in BinaryClassificationExample.log
***********************************************
--> Cross-entropy error:        1.8296475E-13
--> Classification error rate:  0.0
***********************************************

     Weights   Gradients  
 0   2.575599  -0.000000  
 1   1.770546  -0.000000  
 2   1.675687  -0.000000  
 3  -5.859796   0.000000  
 4  -1.794721   0.000000  
 5  -4.925026   0.000000  
 6   3.654187   0.000000  
 7   2.089872   0.000000  
 8   2.485173   0.000000  
 9  -5.238608   0.000000  
10  -1.396975   0.000000  
11  -4.730949   0.000000  
12   0.143083   0.000000  
13   0.777367   0.000000  
14   0.316769   0.000000  
15  -3.270781  -0.000000  
16   0.283153  -0.000000  
17  -0.162338  -0.000000  
18   1.153316   0.000000  
19   0.782549   0.000000  
20  -0.387279   0.000000  
21  -2.010958  -0.000000  
22   0.273662  -0.000000  
23  -0.670019  -0.000000  
24   2.096144   0.000000  
25  -0.264374   0.000000  
26   0.351305   0.000000  
27   1.190361   0.000000  
28  -0.053966   0.000000  
29   0.555192   0.000000  
30  -2.001125  -0.000000  
31   0.735950  -0.000000  
32  -0.829534  -0.000000  
33  -4.824521   0.000000  
34  -4.824521   0.000000  
35  -0.652606   0.000000  
36  -0.652606   0.000000  
37  -2.921224   0.000000  
38  -2.921224   0.000000  
39  -1.621591   0.000000  
40  -1.621591   0.000000  
41  -1.967947   0.000000  
42   1.534864   0.000000  
43   0.907830   0.000000  
44   1.594078  -0.000000  
45   1.594078  -0.000000  
46  -0.169361   0.000000  

                Forecast
    X1  X2  X3  X4  Expected  Predicted  
 0  1   1   1   1      1          1      
 1  1   1   1   2      1          1      
 2  1   1   1   3      1          1      
 3  1   1   1   4      1          1      
 4  1   1   2   1      1          1      
 5  1   1   2   2      1          1      
 6  1   1   2   3      1          1      
 7  1   1   2   4      1          1      
 8  1   1   3   1      1          1      
 9  1   1   3   2      1          1      
10  1   1   3   3      1          1      
11  1   1   3   4      1          1      
12  1   2   1   1      0          0      
13  1   2   1   2      0          0      
14  1   2   1   3      0          0      
15  1   2   1   4      0          0      
16  1   2   2   1      0          0      
17  1   2   2   2      0          0      
18  1   2   2   3      0          0      
19  1   2   2   4      0          0      
20  1   2   3   1      0          0      
21  1   2   3   2      0          0      
22  1   2   3   3      0          0      
23  1   2   3   4      0          0      
24  2   1   1   1      0          0      
25  2   1   1   2      0          0      
26  2   1   1   3      0          0      
27  2   1   1   4      0          0      
28  2   1   2   1      0          0      
29  2   1   2   2      0          0      
30  2   1   2   3      0          0      
31  2   1   2   4      0          0      
32  2   1   3   1      0          0      
33  2   1   3   2      0          0      
34  2   1   3   3      0          0      
35  2   1   3   4      0          0      
36  2   2   1   1      0          0      
37  2   2   1   2      0          0      
38  2   2   1   3      0          0      
39  2   2   1   4      0          0      
40  2   2   2   1      0          0      
41  2   2   2   2      0          0      
42  2   2   2   3      0          0      
43  2   2   2   4      0          0      
44  2   2   3   1      0          0      
45  2   2   3   2      0          0      
46  2   2   3   3      0          0      
47  2   2   3   4      0          0      

***********************************************
--> Cross-Entropy Error:      1.8296475E-13
--> Classification Error:     0.0
***********************************************

****************Time:  0.822
trainer.getErrorValue = 1.8296475445823478E-13