This example trains a 3-layer network using 48 training patterns from four nominal input attributes. The first two nominal attributes have two categories each. The third and fourth nominal attributes have three and four categories, respectively. All four attributes are binary encoded, which results in eleven binary network input columns (2 + 2 + 3 + 4 = 11). The output class is 1 if the first two nominal attributes sum to 2 (that is, both are equal to 1), and 0 otherwise.
The structure of the network consists of eleven input nodes and three layers, with three perceptrons in the first hidden layer, two perceptrons in the second hidden layer, and one perceptron in the output layer.
There are a total of 47 weights in this network, including the six bias weights (11×3 + 3×2 + 2×1 = 41 link weights, plus 3 + 2 + 1 = 6 bias weights). The linear activation function is used for both hidden layers. Since the target output is a binary classification, the logistic activation function is used in the output layer. Training is conducted using the quasi-Newton trainer with the binary cross-entropy error function provided by the BinaryClassification class.
import com.imsl.datamining.neural.*;
import java.io.*;
import java.util.logging.*;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.util.Random;
//*****************************************************************************
// Feed-forward network with two hidden layers (3 and 2 perceptrons) and
// 11 inputs: 4 nominal attributes with 2, 2, 3 and 4 categories, binary
// encoded into 2 + 2 + 3 + 4 = 11 columns, and 1 output target (class).
//
// new classification training_ex1.c
//*****************************************************************************
public class BinaryClassificationEx1 implements Serializable
{
    // Network Settings
    private static final int nObs = 48;          // number of training patterns
    private static final int nInputs = 11;       // network input columns
    private static final int nCategorical = 11;  // binary-encoded columns of the 4 nominal attributes
    private static final int nOutputs = 1;       // single logistic output for the 2-class target
    private static final int nPerceptrons1 = 3;  // perceptrons in 1st hidden layer
    private static final int nPerceptrons2 = 2;  // perceptrons in 2nd hidden layer
    private static final boolean trace = true;   // Turns on/off training log
    private static final Activation hiddenLayerActivation = Activation.LINEAR;
    private static final Activation outputLayerActivation = Activation.LOGISTIC;

    /*
     * Number of categories of x1, x2, x3 and x4, in that order. Each attribute
     * is binary encoded into this many columns, so the counts must sum to
     * nCategorical.
     */
    private static final int[] categoryCounts = {2, 2, 3, 4};

    /* 2 classifications */
    private static final int[] x1 = {
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
    /* 2 classifications */
    private static final int[] x2 = {
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2};
    /* 3 classifications */
    private static final int[] x3 = {
        1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1,
        2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2,
        3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3 };
    /* 4 classifications */
    private static final int[] x4 = {
        1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
        1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
        1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };

    // **********************************************************************
    // MAIN
    // **********************************************************************

    /**
     * Encodes the four nominal attributes, builds and trains the network with
     * a quasi-Newton trainer, then prints the training statistics, the final
     * weights and gradients, a per-pattern forecast table and the elapsed time.
     *
     * @param args unused
     * @throws Exception propagated from network construction or training
     */
    public static void main(String[] args) throws Exception
    {
        String trainLogName = "BinaryClassificationExample.log";

        // ******************************************************************
        // Binary encode 4 categorical variables.
        // Var x1 contains 2 classes
        // Var x2 contains 2 classes
        // Var x3 contains 3 classes
        // Var x4 contains 4 classes
        // ******************************************************************
        int[][] rawAttributes = {x1, x2, x3, x4};

        int encodedColumns = 0;
        for (int count : categoryCounts) {
            encodedColumns += count;
        }
        // Guard against the settings above drifting apart from the data layout.
        if (encodedColumns != nCategorical || nCategorical > nInputs) {
            throw new IllegalStateException("Encoded column count " + encodedColumns
                + " does not match nCategorical=" + nCategorical
                + " (nInputs=" + nInputs + ")");
        }

        int[][][] encoded = new int[rawAttributes.length][][];
        for (int k = 0; k < rawAttributes.length; k++) {
            UnsupervisedNominalFilter filter =
                new UnsupervisedNominalFilter(categoryCounts[k]);
            encoded[k] = filter.encode(rawAttributes[k]);
        }

        /* Concatenate the binary encoded attributes column-wise. */
        double[][] xData = new double[nObs][nInputs]; // Input Attributes for Trainer
        int[] yData = new int[nObs];                   // Output Attributes for Trainer
        for (int row = 0; row < nObs; row++) {
            int col = 0;
            for (int k = 0; k < encoded.length; k++) {
                for (int c = 0; c < categoryCounts[k]; c++) {
                    xData[row][col++] = encoded[k][row][c];
                }
            }
            // x1 and x2 take values 1 or 2, so a sum of 2 means both are 1.
            yData[row] = (x1[row] + x2[row] == 2) ? 1 : 0;
        }

        // **********************************************************************
        // CREATE FEEDFORWARD NETWORK
        // **********************************************************************
        long t0 = System.currentTimeMillis();
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nInputs);
        network.createHiddenLayer().createPerceptrons(nPerceptrons1);
        network.createHiddenLayer().createPerceptrons(nPerceptrons2);
        network.getOutputLayer().createPerceptrons(nOutputs);
        BinaryClassification classification = new BinaryClassification(network);
        network.linkAll();
        // Fixed seed so the initial weights are reproducible from run to run.
        Random r = new Random(123457L);
        network.setRandomWeights(xData, r);

        // Every perceptron except the last gets the hidden-layer activation and
        // the last one gets the output activation. NOTE: this assumes
        // getPerceptrons() lists the single output perceptron last.
        Perceptron[] perceptrons = network.getPerceptrons();
        int last = perceptrons.length - 1;
        for (int p = 0; p < last; p++) {
            perceptrons[p].setActivation(hiddenLayerActivation);
        }
        perceptrons[last].setActivation(outputLayerActivation);

        // **********************************************************************
        // TRAIN NETWORK USING QUASI-NEWTON TRAINER
        // **********************************************************************
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);
        trainer.setMaximumStepsize(3.0);
        // Tolerances of 1e-20 effectively disable these convergence tests,
        // presumably leaving the iteration limit in control.
        trainer.setGradientTolerance(1.0e-20);
        trainer.setFalseConvergenceTolerance(1.0e-20);
        trainer.setStepTolerance(1.0e-20);
        trainer.setRelativeTolerance(1.0e-20);

        Logger logger = Logger.getLogger("com.imsl.datamining.neural");
        Handler logHandler = null;
        if (trace) {
            try {
                logHandler = new FileHandler(trainLogName);
                logHandler.setFormatter(QuasiNewtonTrainer.getFormatter());
                logger.setLevel(Level.FINEST);
                logger.addHandler(logHandler);
                System.out.println("--> Training Log Created in "+
                    trainLogName);
            } catch (Exception e) {
                // The training log is best-effort: report why, then train without it.
                System.out.println("--> Cannot Create Training Log.");
                System.err.println("    Reason: " + e);
                if (logHandler != null) {
                    logger.removeHandler(logHandler);
                    logHandler.close();
                    logHandler = null;
                }
            }
        }
        try {
            classification.train(trainer, xData, yData);
        } finally {
            // Detach and close the handler so the log file is released once
            // training is done, even if training throws.
            if (logHandler != null) {
                logger.removeHandler(logHandler);
                logHandler.close();
            }
        }

        // **********************************************************************
        // DISPLAY TRAINING STATISTICS
        // **********************************************************************
        double[] stats = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-entropy error: "+(float)stats[0]);
        System.out.println("--> Classification error rate: "+(float)stats[1]);
        System.out.println("***********************************************");
        System.out.println("");

        // **********************************************************************
        // OBTAIN AND DISPLAY NETWORK WEIGHTS AND GRADIENTS
        // **********************************************************************
        double[] weight = network.getWeights();
        double[] gradient = trainer.getErrorGradient();
        double[][] wg = new double[weight.length][2];
        for (int w = 0; w < weight.length; w++) {
            wg[w][0] = weight[w];
            wg[w][1] = gradient[w];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf,wg);

        // ****************************
        // forecast the network
        // ****************************
        double[][] report = new double[nObs][6];
        for (int row = 0; row < nObs; row++) {
            report[row][0] = x1[row];
            report[row][1] = x2[row];
            report[row][2] = x3[row];
            report[row][3] = x4[row];
            report[row][4] = yData[row];
            report[row][5] = classification.predictedClass(xData[row]);
        }
        pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{
            "X1", "X2", "X3", "X4",
            "Expected", "Predicted"});
        new PrintMatrix("Forecast").print(pmf, report);

        // **********************************************************************
        // DISPLAY CLASSIFICATION STATISTICS
        // **********************************************************************
        double[] statsClass = classification.computeStatistics(xData, yData);
        // Display Network Errors
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error: "+(float)statsClass[0]);
        System.out.println("--> Classification Error: "+(float)statsClass[1]);
        System.out.println("***********************************************");
        System.out.println("");

        long t1 = System.currentTimeMillis();
        double time = (t1 - t0) / 1000.0; // elapsed seconds
        System.out.println("****************Time: "+time);
        System.out.println("trainer.getErrorValue = "+trainer.getErrorValue());
    }
}
--> Training Log Created in BinaryClassificationExample.log
***********************************************
--> Cross-entropy error: 1.8296475E-13
--> Classification error rate: 0.0
***********************************************
Weights Gradients
0 2.575599 -0.000000
1 1.770546 -0.000000
2 1.675687 -0.000000
3 -5.859796 0.000000
4 -1.794721 0.000000
5 -4.925026 0.000000
6 3.654187 0.000000
7 2.089872 0.000000
8 2.485173 0.000000
9 -5.238608 0.000000
10 -1.396975 0.000000
11 -4.730949 0.000000
12 0.143083 0.000000
13 0.777367 0.000000
14 0.316769 0.000000
15 -3.270781 -0.000000
16 0.283153 -0.000000
17 -0.162338 -0.000000
18 1.153316 0.000000
19 0.782549 0.000000
20 -0.387279 0.000000
21 -2.010958 -0.000000
22 0.273662 -0.000000
23 -0.670019 -0.000000
24 2.096144 0.000000
25 -0.264374 0.000000
26 0.351305 0.000000
27 1.190361 0.000000
28 -0.053966 0.000000
29 0.555192 0.000000
30 -2.001125 -0.000000
31 0.735950 -0.000000
32 -0.829534 -0.000000
33 -4.824521 0.000000
34 -4.824521 0.000000
35 -0.652606 0.000000
36 -0.652606 0.000000
37 -2.921224 0.000000
38 -2.921224 0.000000
39 -1.621591 0.000000
40 -1.621591 0.000000
41 -1.967947 0.000000
42 1.534864 0.000000
43 0.907830 0.000000
44 1.594078 -0.000000
45 1.594078 -0.000000
46 -0.169361 0.000000
Forecast
X1 X2 X3 X4 Expected Predicted
0 1 1 1 1 1 1
1 1 1 1 2 1 1
2 1 1 1 3 1 1
3 1 1 1 4 1 1
4 1 1 2 1 1 1
5 1 1 2 2 1 1
6 1 1 2 3 1 1
7 1 1 2 4 1 1
8 1 1 3 1 1 1
9 1 1 3 2 1 1
10 1 1 3 3 1 1
11 1 1 3 4 1 1
12 1 2 1 1 0 0
13 1 2 1 2 0 0
14 1 2 1 3 0 0
15 1 2 1 4 0 0
16 1 2 2 1 0 0
17 1 2 2 2 0 0
18 1 2 2 3 0 0
19 1 2 2 4 0 0
20 1 2 3 1 0 0
21 1 2 3 2 0 0
22 1 2 3 3 0 0
23 1 2 3 4 0 0
24 2 1 1 1 0 0
25 2 1 1 2 0 0
26 2 1 1 3 0 0
27 2 1 1 4 0 0
28 2 1 2 1 0 0
29 2 1 2 2 0 0
30 2 1 2 3 0 0
31 2 1 2 4 0 0
32 2 1 3 1 0 0
33 2 1 3 2 0 0
34 2 1 3 3 0 0
35 2 1 3 4 0 0
36 2 2 1 1 0 0
37 2 2 1 2 0 0
38 2 2 1 3 0 0
39 2 2 1 4 0 0
40 2 2 2 1 0 0
41 2 2 2 2 0 0
42 2 2 2 3 0 0
43 2 2 2 4 0 0
44 2 2 3 1 0 0
45 2 2 3 2 0 0
46 2 2 3 3 0 0
47 2 2 3 4 0 0
***********************************************
--> Cross-Entropy Error: 1.8296475E-13
--> Classification Error: 0.0
***********************************************
****************Time: 0.822
trainer.getErrorValue = 1.8296475445823478E-13
Link to Java source.