This example trains a 3-layer network using Fisher's Iris data with four continuous input attributes and three output classifications. This is perhaps the best known database to be found in the pattern recognition literature. Fisher's paper is a classic in the field. The data set contains 3 classes of 50 instances each, where each class refers to a type of iris plant.
The structure of the network consists of four input nodes and three layers, with four perceptrons in the first hidden layer, three perceptrons in the second hidden layer and three in the output layer.
The four continuous input attributes represent sepal length, sepal width, petal length, and petal width.
The output attribute represents the class of the iris plant (1 = Iris Setosa, 2 = Iris Versicolour, 3 = Iris Virginica) and is encoded using binary encoding.
There are a total of 47 weights in this network, including the bias weights (4·4 + 4 for the first hidden layer, 4·3 + 3 for the second hidden layer, and 3·3 + 3 for the output layer). All hidden layers use the logistic activation function. Since the target output is multi-classification, the softmax activation function is used in the output layer and the MultiClassification
error function class is used by the trainer. The error class MultiClassification
combines the cross-entropy error calculations and the softmax function.
import com.imsl.datamining.neural.*;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.io.*;
import java.util.logging.*;
//*****************************************************************************
// Three Layer Feed-Forward Network with 4 inputs, all
// continuous, and 3 classification categories.
//
// new classification training_ex5.c
//
// This is perhaps the best known database to be found in the pattern
// recognition literature. Fisher's paper is a classic in the field.
// The data set contains 3 classes of 50 instances each,
// where each class refers to a type of iris plant. One class is
// linearly separable from the other 2; the latter are NOT linearly
// separable from each other.
//
// Predicted attribute: class of iris plant.
// 1=Iris Setosa, 2=Iris Versicolour, and 3=Iris Virginica
//
// Input Attributes (4 Continuous Attributes)
// X1: Sepal length, X2: Sepal width, X3: Petal length, and X4: Petal width
//*****************************************************************************
public class MultiClassificationEx1 implements Serializable {

    /** Number of training patterns (rows of IRIS_DATA). */
    private static final int N_OBS = 150;
    /** Number of continuous input attributes: sepal length/width, petal length/width. */
    private static final int N_INPUTS = 4;
    /** Number of output classes: 1=Iris Setosa, 2=Iris Versicolour, 3=Iris Virginica. */
    private static final int N_OUTPUTS = 3;
    /** Column of IRIS_DATA holding the class label (it follows the input columns). */
    private static final int CLASS_COLUMN = N_INPUTS;
    /** Turns on/off the training log. */
    private static final boolean TRACE = true;
    /** File the training log is written to when TRACE is enabled. */
    private static final String LOG_FILE = "ClassificationNetworkTraining.log";

    /**
     * Strong reference to the IMSL neural-network logger. java.util.logging may garbage
     * collect a Logger that is only referenced from a local variable, which would silently
     * discard the level and handler configured for training.
     */
    private static final Logger TRAINING_LOGGER =
            Logger.getLogger("com.imsl.datamining.neural");

    /**
     * The raw data matrix: 150 rows and 5 columns. The first 4 columns are the continuous
     * input attributes and the 5th column is the classification category (1-3). These data
     * contain no categorical input attributes.
     */
    private static final double[][] IRIS_DATA = {
        {5.1,3.5,1.4,0.2,1},{4.9,3.0,1.4,0.2,1},{4.7,3.2,1.3,0.2,1},{4.6,3.1,1.5,0.2,1},
        {5.0,3.6,1.4,0.2,1},{5.4,3.9,1.7,0.4,1},{4.6,3.4,1.4,0.3,1},{5.0,3.4,1.5,0.2,1},
        {4.4,2.9,1.4,0.2,1},{4.9,3.1,1.5,0.1,1},{5.4,3.7,1.5,0.2,1},{4.8,3.4,1.6,0.2,1},
        {4.8,3.0,1.4,0.1,1},{4.3,3.0,1.1,0.1,1},{5.8,4.0,1.2,0.2,1},{5.7,4.4,1.5,0.4,1},
        {5.4,3.9,1.3,0.4,1},{5.1,3.5,1.4,0.3,1},{5.7,3.8,1.7,0.3,1},{5.1,3.8,1.5,0.3,1},
        {5.4,3.4,1.7,0.2,1},{5.1,3.7,1.5,0.4,1},{4.6,3.6,1.0,0.2,1},{5.1,3.3,1.7,0.5,1},
        {4.8,3.4,1.9,0.2,1},{5.0,3.0,1.6,0.2,1},{5.0,3.4,1.6,0.4,1},{5.2,3.5,1.5,0.2,1},
        {5.2,3.4,1.4,0.2,1},{4.7,3.2,1.6,0.2,1},{4.8,3.1,1.6,0.2,1},{5.4,3.4,1.5,0.4,1},
        {5.2,4.1,1.5,0.1,1},{5.5,4.2,1.4,0.2,1},{4.9,3.1,1.5,0.1,1},{5.0,3.2,1.2,0.2,1},
        {5.5,3.5,1.3,0.2,1},{4.9,3.1,1.5,0.1,1},{4.4,3.0,1.3,0.2,1},{5.1,3.4,1.5,0.2,1},
        {5.0,3.5,1.3,0.3,1},{4.5,2.3,1.3,0.3,1},{4.4,3.2,1.3,0.2,1},{5.0,3.5,1.6,0.6,1},
        {5.1,3.8,1.9,0.4,1},{4.8,3.0,1.4,0.3,1},{5.1,3.8,1.6,0.2,1},{4.6,3.2,1.4,0.2,1},
        {5.3,3.7,1.5,0.2,1},{5.0,3.3,1.4,0.2,1},
        {7.0,3.2,4.7,1.4,2},{6.4,3.2,4.5,1.5,2},{6.9,3.1,4.9,1.5,2},{5.5,2.3,4.0,1.3,2},
        {6.5,2.8,4.6,1.5,2},{5.7,2.8,4.5,1.3,2},{6.3,3.3,4.7,1.6,2},{4.9,2.4,3.3,1.0,2},
        {6.6,2.9,4.6,1.3,2},{5.2,2.7,3.9,1.4,2},{5.0,2.0,3.5,1.0,2},{5.9,3.0,4.2,1.5,2},
        {6.0,2.2,4.0,1.0,2},{6.1,2.9,4.7,1.4,2},{5.6,2.9,3.6,1.3,2},{6.7,3.1,4.4,1.4,2},
        {5.6,3.0,4.5,1.5,2},{5.8,2.7,4.1,1.0,2},{6.2,2.2,4.5,1.5,2},{5.6,2.5,3.9,1.1,2},
        {5.9,3.2,4.8,1.8,2},{6.1,2.8,4.0,1.3,2},{6.3,2.5,4.9,1.5,2},{6.1,2.8,4.7,1.2,2},
        {6.4,2.9,4.3,1.3,2},{6.6,3.0,4.4,1.4,2},{6.8,2.8,4.8,1.4,2},{6.7,3.0,5.0,1.7,2},
        {6.0,2.9,4.5,1.5,2},{5.7,2.6,3.5,1.0,2},{5.5,2.4,3.8,1.1,2},{5.5,2.4,3.7,1.0,2},
        {5.8,2.7,3.9,1.2,2},{6.0,2.7,5.1,1.6,2},{5.4,3.0,4.5,1.5,2},{6.0,3.4,4.5,1.6,2},
        {6.7,3.1,4.7,1.5,2},{6.3,2.3,4.4,1.3,2},{5.6,3.0,4.1,1.3,2},{5.5,2.5,4.0,1.3,2},
        {5.5,2.6,4.4,1.2,2},{6.1,3.0,4.6,1.4,2},{5.8,2.6,4.0,1.2,2},{5.0,2.3,3.3,1.0,2},
        {5.6,2.7,4.2,1.3,2},{5.7,3.0,4.2,1.2,2},{5.7,2.9,4.2,1.3,2},{6.2,2.9,4.3,1.3,2},
        {5.1,2.5,3.0,1.1,2},{5.7,2.8,4.1,1.3,2},
        {6.3,3.3,6.0,2.5,3},{5.8,2.7,5.1,1.9,3},{7.1,3.0,5.9,2.1,3},{6.3,2.9,5.6,1.8,3},
        {6.5,3.0,5.8,2.2,3},{7.6,3.0,6.6,2.1,3},{4.9,2.5,4.5,1.7,3},{7.3,2.9,6.3,1.8,3},
        {6.7,2.5,5.8,1.8,3},{7.2,3.6,6.1,2.5,3},{6.5,3.2,5.1,2.0,3},{6.4,2.7,5.3,1.9,3},
        {6.8,3.0,5.5,2.1,3},{5.7,2.5,5.0,2.0,3},{5.8,2.8,5.1,2.4,3},{6.4,3.2,5.3,2.3,3},
        {6.5,3.0,5.5,1.8,3},{7.7,3.8,6.7,2.2,3},{7.7,2.6,6.9,2.3,3},{6.0,2.2,5.0,1.5,3},
        {6.9,3.2,5.7,2.3,3},{5.6,2.8,4.9,2.0,3},{7.7,2.8,6.7,2.0,3},{6.3,2.7,4.9,1.8,3},
        {6.7,3.3,5.7,2.1,3},{7.2,3.2,6.0,1.8,3},{6.2,2.8,4.8,1.8,3},{6.1,3.0,4.9,1.8,3},
        {6.4,2.8,5.6,2.1,3},{7.2,3.0,5.8,1.6,3},{7.4,2.8,6.1,1.9,3},{7.9,3.8,6.4,2.0,3},
        {6.4,2.8,5.6,2.2,3},{6.3,2.8,5.1,1.5,3},{6.1,2.6,5.6,1.4,3},{7.7,3.0,6.1,2.3,3},
        {6.3,3.4,5.6,2.4,3},{6.4,3.1,5.5,1.8,3},{6.0,3.0,4.8,1.8,3},{6.9,3.1,5.4,2.1,3},
        {6.7,3.1,5.6,2.4,3},{6.9,3.1,5.1,2.3,3},{5.8,2.7,5.1,1.9,3},{6.8,3.2,5.9,2.3,3},
        {6.7,3.3,5.7,2.5,3},{6.7,3.0,5.2,2.3,3},{6.3,2.5,5.0,1.9,3},{6.5,3.0,5.2,2.0,3},
        {6.2,3.4,5.4,2.3,3},{5.9,3.0,5.1,1.8,3}
    };

    /**
     * Trains the 4-4-3-3 classification network on the iris data and prints the error
     * statistics, the trained weights with their gradients, a per-pattern forecast, and
     * the training time.
     *
     * @param args unused
     * @throws Exception if the training log cannot be opened or the IMSL routines fail
     */
    public static void main(String[] args) throws Exception {
        // Split the raw matrix into the input attributes and the integer class labels.
        double[][] xData = new double[N_OBS][N_INPUTS];
        int[] yData = new int[N_OBS];
        for (int i = 0; i < N_OBS; i++) {
            System.arraycopy(IRIS_DATA[i], 0, xData[i], 0, N_INPUTS);
            yData[i] = (int) IRIS_DATA[i][CLASS_COLUMN];
        }

        FeedForwardNetwork network = createNetwork();
        MultiClassification classification = new MultiClassification(network);

        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);

        Handler handler = TRACE ? openTrainingLog() : null;
        try {
            // Time only the training itself; nanoTime is monotonic, unlike the wall clock.
            long t0 = System.nanoTime();
            classification.train(trainer, xData, yData);
            long elapsedMillis = (System.nanoTime() - t0) / 1000000L;

            // The statistics are computed once and reported both before and after the
            // detailed tables.
            double[] stats = classification.computeStatistics(xData, yData);
            printStatistics(stats, "--> Cross-entropy error: ", "--> Classification error rate: ");
            printWeightsAndGradients(network.getWeights(), trainer.getErrorGradient());
            printForecast(classification, xData);
            printStatistics(stats, "--> Cross-Entropy Error: ", "--> Classification Error: ");

            System.out.println("****************Time: " + elapsedMillis / 1000.0);
            System.out.println("Cross-Entropy Error Value = " + trainer.getErrorValue());
        } finally {
            // Detach and close the log file so it is flushed and not leaked.
            if (handler != null) {
                TRAINING_LOGGER.removeHandler(handler);
                handler.close();
            }
        }
    }

    /** Builds the 4-input network with 4 and 3 logistic hidden perceptrons and a softmax output. */
    private static FeedForwardNetwork createNetwork() {
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(N_INPUTS);
        network.createHiddenLayer().createPerceptrons(4, Activation.LOGISTIC, 0.0);
        network.createHiddenLayer().createPerceptrons(3, Activation.LOGISTIC, 0.0);
        network.getOutputLayer().createPerceptrons(N_OUTPUTS, Activation.SOFTMAX, 0.0);
        network.linkAll();
        return network;
    }

    /**
     * Opens the training log file and attaches it to the IMSL neural-network logger at
     * FINEST level.
     *
     * @return the attached handler; the caller must remove and close it
     * @throws IOException if the log file cannot be opened
     */
    private static Handler openTrainingLog() throws IOException {
        Handler handler = new FileHandler(LOG_FILE);
        handler.setFormatter(QuasiNewtonTrainer.getFormatter());
        TRAINING_LOGGER.setLevel(Level.FINEST);
        TRAINING_LOGGER.addHandler(handler);
        return handler;
    }

    /** Prints the cross-entropy error (stats[0]) and classification error (stats[1]). */
    private static void printStatistics(double[] stats, String entropyLabel, String errorLabel) {
        System.out.println("***********************************************");
        System.out.println(entropyLabel + (float) stats[0]);
        System.out.println(errorLabel + (float) stats[1]);
        System.out.println("***********************************************");
        System.out.println("");
    }

    /** Prints each network weight next to its error gradient. */
    private static void printWeightsAndGradients(double[] weights, double[] gradients) {
        double[][] table = new double[weights.length][2];
        for (int i = 0; i < weights.length; i++) {
            table[i][0] = weights[i];
            table[i][1] = gradients[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[] {"Weights", "Gradients"});
        new PrintMatrix().print(pmf, table);
    }

    /** Prints every training pattern with its expected and predicted class. */
    private static void printForecast(MultiClassification classification, double[][] xData) {
        double[][] report = new double[N_OBS][N_INPUTS + 2];
        for (int i = 0; i < N_OBS; i++) {
            System.arraycopy(xData[i], 0, report[i], 0, N_INPUTS);
            report[i][N_INPUTS] = IRIS_DATA[i][CLASS_COLUMN];
            report[i][N_INPUTS + 1] = classification.predictedClass(xData[i]);
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[] {
            "Sepal Length",
            "Sepal Width",
            "Petal Length",
            "Petal Width",
            "Expected",
            "Predicted"});
        new PrintMatrix("Forecast").print(pmf, report);
    }
}
***********************************************
--> Cross-entropy error: 4.640623
--> Classification error rate: 0.006666667
***********************************************
Weights Gradients
0 -51.777881 -0.021660
1 605.119380 0.000000
2 -284.226877 0.000000
3 327.038883 0.000000
4 -41.160485 -0.009887
5 -867.891312 0.000000
6 -1210.846071 0.000000
7 -994.103717 0.000000
8 73.932788 -0.016740
9 -346.829319 0.000000
10 704.482597 0.000000
11 -497.908892 0.000000
12 51.636506 -0.006301
13 1943.984336 0.000000
14 1516.711136 0.000000
15 1935.687178 0.000000
16 -3.143561 -2.271656
17 -443.852301 -7.201949
18 242.475544 -0.000024
19 23.461487 -2.272793
20 189.287779 -7.201954
21 260.386655 -0.096456
22 564.420647 -2.272793
23 607.227248 -7.201954
24 -62.368750 -0.096456
25 163.370794 -2.272793
26 216.054929 -7.201954
27 296.537883 -0.096456
28 -15686.506783 0.000000
29 3478.164215 0.004606
30 12209.342568 -0.004606
31 -15443.797985 0.000000
32 4719.334347 0.002674
33 10725.463639 -0.002674
34 -15303.926099 0.000000
35 3602.472102 0.004863
36 11702.453998 -0.004863
37 -19.854440 -0.003322
38 965.005400 0.000000
39 874.394173 0.000000
40 898.666721 0.000000
41 -745.305267 -2.272793
42 -568.545362 -7.201954
43 -494.170957 -0.096456
44 36175.248628 0.000000
45 -8292.572938 0.004882
46 -27882.675691 -0.004882
Forecast
Sepal Length Sepal Width Petal Length Petal Width Expected Predicted
0 5.1 3.5 1.4 0.2 1 1
1 4.9 3 1.4 0.2 1 1
2 4.7 3.2 1.3 0.2 1 1
3 4.6 3.1 1.5 0.2 1 1
4 5 3.6 1.4 0.2 1 1
5 5.4 3.9 1.7 0.4 1 1
6 4.6 3.4 1.4 0.3 1 1
7 5 3.4 1.5 0.2 1 1
8 4.4 2.9 1.4 0.2 1 1
9 4.9 3.1 1.5 0.1 1 1
10 5.4 3.7 1.5 0.2 1 1
11 4.8 3.4 1.6 0.2 1 1
12 4.8 3 1.4 0.1 1 1
13 4.3 3 1.1 0.1 1 1
14 5.8 4 1.2 0.2 1 1
15 5.7 4.4 1.5 0.4 1 1
16 5.4 3.9 1.3 0.4 1 1
17 5.1 3.5 1.4 0.3 1 1
18 5.7 3.8 1.7 0.3 1 1
19 5.1 3.8 1.5 0.3 1 1
20 5.4 3.4 1.7 0.2 1 1
21 5.1 3.7 1.5 0.4 1 1
22 4.6 3.6 1 0.2 1 1
23 5.1 3.3 1.7 0.5 1 1
24 4.8 3.4 1.9 0.2 1 1
25 5 3 1.6 0.2 1 1
26 5 3.4 1.6 0.4 1 1
27 5.2 3.5 1.5 0.2 1 1
28 5.2 3.4 1.4 0.2 1 1
29 4.7 3.2 1.6 0.2 1 1
30 4.8 3.1 1.6 0.2 1 1
31 5.4 3.4 1.5 0.4 1 1
32 5.2 4.1 1.5 0.1 1 1
33 5.5 4.2 1.4 0.2 1 1
34 4.9 3.1 1.5 0.1 1 1
35 5 3.2 1.2 0.2 1 1
36 5.5 3.5 1.3 0.2 1 1
37 4.9 3.1 1.5 0.1 1 1
38 4.4 3 1.3 0.2 1 1
39 5.1 3.4 1.5 0.2 1 1
40 5 3.5 1.3 0.3 1 1
41 4.5 2.3 1.3 0.3 1 1
42 4.4 3.2 1.3 0.2 1 1
43 5 3.5 1.6 0.6 1 1
44 5.1 3.8 1.9 0.4 1 1
45 4.8 3 1.4 0.3 1 1
46 5.1 3.8 1.6 0.2 1 1
47 4.6 3.2 1.4 0.2 1 1
48 5.3 3.7 1.5 0.2 1 1
49 5 3.3 1.4 0.2 1 1
50 7 3.2 4.7 1.4 2 2
51 6.4 3.2 4.5 1.5 2 2
52 6.9 3.1 4.9 1.5 2 2
53 5.5 2.3 4 1.3 2 2
54 6.5 2.8 4.6 1.5 2 2
55 5.7 2.8 4.5 1.3 2 2
56 6.3 3.3 4.7 1.6 2 2
57 4.9 2.4 3.3 1 2 2
58 6.6 2.9 4.6 1.3 2 2
59 5.2 2.7 3.9 1.4 2 2
60 5 2 3.5 1 2 2
61 5.9 3 4.2 1.5 2 2
62 6 2.2 4 1 2 2
63 6.1 2.9 4.7 1.4 2 2
64 5.6 2.9 3.6 1.3 2 2
65 6.7 3.1 4.4 1.4 2 2
66 5.6 3 4.5 1.5 2 2
67 5.8 2.7 4.1 1 2 2
68 6.2 2.2 4.5 1.5 2 2
69 5.6 2.5 3.9 1.1 2 2
70 5.9 3.2 4.8 1.8 2 2
71 6.1 2.8 4 1.3 2 2
72 6.3 2.5 4.9 1.5 2 2
73 6.1 2.8 4.7 1.2 2 2
74 6.4 2.9 4.3 1.3 2 2
75 6.6 3 4.4 1.4 2 2
76 6.8 2.8 4.8 1.4 2 2
77 6.7 3 5 1.7 2 2
78 6 2.9 4.5 1.5 2 2
79 5.7 2.6 3.5 1 2 2
80 5.5 2.4 3.8 1.1 2 2
81 5.5 2.4 3.7 1 2 2
82 5.8 2.7 3.9 1.2 2 2
83 6 2.7 5.1 1.6 2 3
84 5.4 3 4.5 1.5 2 2
85 6 3.4 4.5 1.6 2 2
86 6.7 3.1 4.7 1.5 2 2
87 6.3 2.3 4.4 1.3 2 2
88 5.6 3 4.1 1.3 2 2
89 5.5 2.5 4 1.3 2 2
90 5.5 2.6 4.4 1.2 2 2
91 6.1 3 4.6 1.4 2 2
92 5.8 2.6 4 1.2 2 2
93 5 2.3 3.3 1 2 2
94 5.6 2.7 4.2 1.3 2 2
95 5.7 3 4.2 1.2 2 2
96 5.7 2.9 4.2 1.3 2 2
97 6.2 2.9 4.3 1.3 2 2
98 5.1 2.5 3 1.1 2 2
99 5.7 2.8 4.1 1.3 2 2
100 6.3 3.3 6 2.5 3 3
101 5.8 2.7 5.1 1.9 3 3
102 7.1 3 5.9 2.1 3 3
103 6.3 2.9 5.6 1.8 3 3
104 6.5 3 5.8 2.2 3 3
105 7.6 3 6.6 2.1 3 3
106 4.9 2.5 4.5 1.7 3 3
107 7.3 2.9 6.3 1.8 3 3
108 6.7 2.5 5.8 1.8 3 3
109 7.2 3.6 6.1 2.5 3 3
110 6.5 3.2 5.1 2 3 3
111 6.4 2.7 5.3 1.9 3 3
112 6.8 3 5.5 2.1 3 3
113 5.7 2.5 5 2 3 3
114 5.8 2.8 5.1 2.4 3 3
115 6.4 3.2 5.3 2.3 3 3
116 6.5 3 5.5 1.8 3 3
117 7.7 3.8 6.7 2.2 3 3
118 7.7 2.6 6.9 2.3 3 3
119 6 2.2 5 1.5 3 3
120 6.9 3.2 5.7 2.3 3 3
121 5.6 2.8 4.9 2 3 3
122 7.7 2.8 6.7 2 3 3
123 6.3 2.7 4.9 1.8 3 3
124 6.7 3.3 5.7 2.1 3 3
125 7.2 3.2 6 1.8 3 3
126 6.2 2.8 4.8 1.8 3 3
127 6.1 3 4.9 1.8 3 3
128 6.4 2.8 5.6 2.1 3 3
129 7.2 3 5.8 1.6 3 3
130 7.4 2.8 6.1 1.9 3 3
131 7.9 3.8 6.4 2 3 3
132 6.4 2.8 5.6 2.2 3 3
133 6.3 2.8 5.1 1.5 3 3
134 6.1 2.6 5.6 1.4 3 3
135 7.7 3 6.1 2.3 3 3
136 6.3 3.4 5.6 2.4 3 3
137 6.4 3.1 5.5 1.8 3 3
138 6 3 4.8 1.8 3 3
139 6.9 3.1 5.4 2.1 3 3
140 6.7 3.1 5.6 2.4 3 3
141 6.9 3.1 5.1 2.3 3 3
142 5.8 2.7 5.1 1.9 3 3
143 6.8 3.2 5.9 2.3 3 3
144 6.7 3.3 5.7 2.5 3 3
145 6.7 3 5.2 2.3 3 3
146 6.3 2.5 5 1.9 3 3
147 6.5 3 5.2 2 3 3
148 6.2 3.4 5.4 2.3 3 3
149 5.9 3 5.1 1.8 3 3
***********************************************
--> Cross-Entropy Error: 4.640623
--> Classification Error: 0.006666667
***********************************************
****************Time: 17.055
Cross-Entropy Error Value = 4.6406232788035595
Link to Java source.