This example trains a 2-layer network using three binary inputs (X1, X2, X3) and one three-level classification variable (Y), where

Y = 1 if X1 = 1
Y = 2 if X2 = 1
Y = 3 if X3 = 1
import com.imsl.datamining.neural.*;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.io.*;
import java.util.logging.*;
//*****************************************************************************
// Two-Layer FFN with 3 binary inputs (X1, X2, X3) and one three-level
// classification variable (Y)
// Y = 1 if X1 = 1
// Y = 2 if X2 = 1
// Y = 3 if X3 = 1
// (training_ex6)
//*****************************************************************************
public class MultiClassificationEx2 implements Serializable {
    private static int nObs = 6;         // number of training patterns
    private static int nInputs = 3;      // 3 inputs, all binary
    private static int nOutputs = 3;     // 3 output classes (Y = 1, 2, or 3)
    private static boolean trace = true; // turns the training log on/off
    private static double xData[][] = {
        {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}
    };
    private static int yData[] = {1, 1, 2, 2, 3, 3};
    // Initial network weights: 3x3 input-to-hidden links, 3x3 hidden-to-output
    // links, and 6 perceptron biases (24 values), fixed so the training
    // results are reproducible
    private static double weights[] = {
         1.29099444873580580000, -0.64549722436790280000, -0.64549722436790291000,
         0.00000000000000000000,  1.11803398874989490000, -1.11803398874989470000,
         0.57735026918962584000,  0.57735026918962584000,  0.57735026918962584000,
         0.33333333333333331000,  0.33333333333333331000,  0.33333333333333331000,
         0.33333333333333331000,  0.33333333333333331000,  0.33333333333333331000,
         0.33333333333333331000,  0.33333333333333331000,  0.33333333333333331000,
        -0.00000000000000005851, -0.00000000000000005851, -0.57735026918962573000,
         0.00000000000000000000,  0.00000000000000000000,  0.00000000000000000000};
    public static void main(String[] args) throws Exception {
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nInputs);
        network.createHiddenLayer().createPerceptrons(3, Activation.LINEAR, 0.0);
        //network.createHiddenLayer().createPerceptrons(4, Activation.TANH, 0.0);
        network.getOutputLayer().createPerceptrons(nOutputs, Activation.SOFTMAX, 0.0);
        network.linkAll();
        network.setWeights(weights);
        MultiClassification classification = new MultiClassification(network);
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);
        trainer.setFalseConvergenceTolerance(1.0e-20);
        trainer.setGradientTolerance(1.0e-20);
        trainer.setRelativeTolerance(1.0e-20);
        trainer.setStepTolerance(1.0e-20);
        // If tracing is requested, set up the training logger
        if (trace) {
            Handler handler = new FileHandler("ClassificationNetworkEx2.log");
            Logger logger = Logger.getLogger("com.imsl.datamining.neural");
            logger.setLevel(Level.FINEST);
            logger.addHandler(handler);
            handler.setFormatter(QuasiNewtonTrainer.getFormatter());
        }
        // Train the network
        classification.train(trainer, xData, yData);
        // Display network errors
        double stats[] = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error: " + (float) stats[0]);
        System.out.println("--> Classification Error: " + (float) stats[1]);
        System.out.println("***********************************************");
        System.out.println();
        // Print the fitted weights alongside the error gradient
        double weight[] = network.getWeights();
        double gradient[] = trainer.getErrorGradient();
        double wg[][] = new double[weight.length][2];
        for (int i = 0; i < weight.length; i++) {
            wg[i][0] = weight[i];
            wg[i][1] = gradient[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf, wg);
        // Build a report: inputs, target, class probabilities, predicted class
        double report[][] = new double[nObs][nInputs + nOutputs + 2];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                report[i][j] = xData[i][j];
            }
            report[i][nInputs] = yData[i];
            double p[] = classification.probabilities(xData[i]);
            for (int j = 0; j < nOutputs; j++) {
                report[i][nInputs + 1 + j] = p[j];
            }
            report[i][nInputs + nOutputs + 1] = classification.predictedClass(xData[i]);
        }
        pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{"X1", "X2", "X3", "Y", "P(C1)", "P(C2)",
                "P(C3)", "Predicted"});
        new PrintMatrix("Forecast").print(pmf, report);
        System.out.println("Cross-Entropy Error Value = " + trainer.getErrorValue());
        // **********************************************************************
        // DISPLAY CLASSIFICATION STATISTICS
        // **********************************************************************
        double statsClass[] = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error: " + (float) statsClass[0]);
        System.out.println("--> Classification Error: " + (float) statsClass[1]);
        System.out.println("***********************************************");
        System.out.println();
    }
}
Output

***********************************************
--> Cross-Entropy Error: 0.0
--> Classification Error: 0.0
***********************************************
      Weights   Gradients
 0   3.401208   -0.000000
 1  -4.126657    0.000000
 2  -2.201606   -0.000000
 3  -2.009527    0.000000
 4   3.173323   -0.000000
 5  -4.200377   -0.000000
 6   0.028736   -0.000000
 7   2.657051    0.000000
 8   4.868134   -0.000000
 9   3.711295   -0.000000
10  -2.723536   -0.000000
11   0.012241    0.000000
12  -4.996359    0.000000
13   4.296983    0.000000
14   1.699376   -0.000000
15  -1.993114    0.000000
16  -4.048833    0.000000
17   7.041948   -0.000000
18  -0.447927   -0.000000
19   0.653830    0.000000
20  -0.925019   -0.000000
21  -0.078963    0.000000
22   0.247835    0.000000
23  -0.168872   -0.000000
                     Forecast
   X1  X2  X3  Y  P(C1)  P(C2)  P(C3)  Predicted
0   1   0   0  1      1      0      0          1
1   1   0   0  1      1      0      0          1
2   0   1   0  2      0      1      0          2
3   0   1   0  2      0      1      0          2
4   0   0   1  3      0      0      1          3
5   0   0   1  3      0      0      1          3
Cross-Entropy Error Value = 0.0
***********************************************
--> Cross-Entropy Error: 0.0
--> Classification Error: 0.0
***********************************************
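
Once train has returned, the same MultiClassification object can score patterns that were not in the training set. A minimal scoring sketch, reusing the probabilities and predictedClass calls already shown in the report loop; for this trained network the pattern {0, 1, 0} falls in class 2, as the Forecast table shows:

    double pattern[] = {0, 1, 0}; // a new observation with X2 = 1
    double p[] = classification.probabilities(pattern);
    int predicted = classification.predictedClass(pattern);
    System.out.println("P(C1)=" + p[0] + " P(C2)=" + p[1]
            + " P(C3)=" + p[2] + " -> class " + predicted);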
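
The example class declares implements Serializable, which suggests one natural follow-up: saving the trained network with standard Java object serialization so it can be reloaded later without retraining. A minimal sketch, assuming FeedForwardNetwork is itself serializable and using an illustrative file name:

    // Hedged sketch: persist the trained network via standard Java
    // serialization; "MultiClassificationEx2.ser" is a hypothetical file name
    try (ObjectOutputStream oos = new ObjectOutputStream(
            new FileOutputStream("MultiClassificationEx2.ser"))) {
        oos.writeObject(network);
    }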