/*
* -------------------------------------------------------------------------
* $Id: ForecastNet.java,v 1.3 2005/03/14 14:57:10 estewart Exp $
* -------------------------------------------------------------------------
* Copyright (c) 1999 Visual Numerics Inc. All Rights Reserved.
*
* This software is confidential information which is proprietary to
* and a trade secret of Visual Numerics, Inc. Use, duplication or
* disclosure is subject to the terms of an appropriate license
* agreement.
*
* VISUAL NUMERICS MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE
* SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. VISUAL
* NUMERICS SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE
* AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THIS SOFTWARE OR
* ITS DERIVATIVES.
*--------------------------------------------------------------------------
*/
package com.imsl.demo.nnet;
import javax.swing.*;
import javax.swing.event.*;
import java.awt.*;
import java.awt.event.*;
import java.util.*;
import java.text.*;
import com.imsl.stat.*;
import com.imsl.chart.*;
import com.imsl.datamining.neural.*;
import com.imsl.demo.gallery.Describe;
public class ForecastNet extends JFrameChart {
// Chart obtained from the frame (getChart()) and the single XY axis that
// every series in this demo is attached to.
private Chart chart;
private AxisXY axis;
// Plotted series, in order: the historical data (xsub/ysub), the ARMA
// forecast, the neural-network forecast and the held-out actual outcome.
// The last three are only (re)created while their check box is selected.
private Data data, forecastData, nnData, realData;
// One radio menu item per selectable data set (see dataName).
private JRadioButtonMenuItem[] jRadioButtons;
// Check boxes that toggle the optional series on the chart.
private JCheckBox jCheckData, jCheckArma, jCheckNnet;
// Panel holding the three check boxes; placed at BorderLayout.NORTH.
private JPanel jCheckPanel;
// Number of selectable data sets and their menu labels. The index into
// dataName is the value switched on in getData().
private final int numData = 3;
private final String dataName[] = {"Oil", "Sales", "Sunspots"};
// ARMA model refitted on ysub each time the graph is drawn.
private ARMA arma;
// Label format of the X (date) axis.
private SimpleDateFormat dateFormat = new SimpleDateFormat("M/yyyy");
// y/x: the complete series and its time stamps (milliseconds, from
// GregorianCalendar.getTimeInMillis()). ysub/xsub: the same series with the
// last nPredicts points removed. yRange: window of the Y axis.
private double[] y, x, yRange, xsub, ysub;
// yearRange[0]: year of the first observation and left edge of the X window;
// yearRange[1]: year in which the forecast period starts;
// yearRange[2]: year used as the right edge of the X window.
private int[] yearRange;
// Title of the Y axis for the selected data set.
private String yTitle;
// dataIncrement: GregorianCalendar.YEAR or GregorianCalendar.MONTH, the
// spacing between observations. seed: suffix of the weight resource that
// trainNetwork() loads ("weights" + seed).
private int dataIncrement, seed;
// Network used for the neural-net forecast; rebuilt by trainNetwork().
private FeedForwardNetwork network;
// Only referenced by the commented-out training code in trainNetwork().
private EpochTrainer trainer;
// Activation functions applied to the hidden and output perceptrons.
private Activation hiddenLayerActivation = Activation.LOGISTIC;
private Activation outputLayerActivation = Activation.LINEAR;
// Only referenced by the commented-out training code in trainNetwork().
private double weight[], gradient[], stats[];
// Network layout: output perceptrons, perceptrons per hidden layer, number
// of hidden layers, and number of lagged observations used as inputs.
int nOutputs = 1;
int nPerceptrons = 8;
int nLayers = 1;
int nLags = 12;
// Number of trailing observations held out of ysub and forecast.
int nPredicts = 12;
// Reset to "" by trainNetwork(); otherwise only used by the commented-out
// training code there.
private String errorMsg;
/**
 * Builds the demo window: shows the description window, sizes and places
 * this frame relative to it, configures the chart axes, installs the series
 * check boxes and the "Data" menu, then loads and draws the default data
 * set (menu entry 1).
 *
 * @param exitOnClose when {@code false}, every WindowListener registered on
 *                    this frame is removed so that closing the window does
 *                    not exit the application
 */
public ForecastNet(boolean exitOnClose) {
    setTitle("Neural Net Forecasting");
    if (!exitOnClose) {
        // JFrameChart installs a WindowListener that exits the application
        // when the window is closed; detach all registered listeners.
        Object[] registered = getListeners(java.awt.event.WindowListener.class);
        for (int n = 0; n < registered.length; n++) {
            java.awt.event.WindowListener listener =
                    (java.awt.event.WindowListener) registered[n];
            removeWindowListener(listener);
        }
    }

    // The description window is shown first because its size determines
    // both the height and the position of this frame.
    Describe description = new Describe(this, "/com/imsl/demo/nnet/ForecastNet.html");
    description.show();
    Dimension descSize = description.getSize();
    Dimension screenSize = getToolkit().getScreenSize();
    int frameHeight = Math.min(screenSize.width/2,
            screenSize.height - descSize.height - 32);
    int frameWidth = (int)(frameHeight/0.8);
    setSize(frameWidth, frameHeight);
    setLocation(screenSize.width - descSize.width, descSize.height);

    // Chart defaults: date-formatted, rotated X labels and a visible legend.
    chart = getChart();
    axis = new AxisXY(chart);
    axis.getAxisX().getAxisLabel().setTextFormat(dateFormat);
    axis.getAxisX().getAxisLabel().setTextAngle(90);
    axis.getAxisX().getAxisTitle().setTitle("Date");
    axis.getAxisY().setTextFormat(new java.text.DecimalFormat("###,###"));
    chart.getLegend().setPaint(true);

    installSeriesCheckBoxes();
    installDataMenu();

    jRadioButtons[1].setSelected(true);
    getData();
    drawGraph();
    setResizable(false);
}

// Creates the three series check boxes (only "ARMA Forecast" starts
// selected), wires each of them to update(), and docks the panel at the top.
private void installSeriesCheckBoxes() {
    jCheckPanel = new JPanel();
    jCheckArma = new JCheckBox("ARMA Forecast ");
    jCheckNnet = new JCheckBox("Neural Net Forecast ");
    jCheckData = new JCheckBox("Actual outcome ");
    jCheckArma.setSelected(true);

    JCheckBox[] inDisplayOrder = {jCheckArma, jCheckNnet, jCheckData};
    for (int n = 0; n < inDisplayOrder.length; n++) {
        jCheckPanel.add(inDisplayOrder[n]);
    }

    java.awt.event.ActionListener redraw = new java.awt.event.ActionListener() {
        public void actionPerformed(java.awt.event.ActionEvent evt) {
            update();
        }
    };
    for (int n = 0; n < inDisplayOrder.length; n++) {
        inDisplayOrder[n].addActionListener(redraw);
    }
    getContentPane().add(jCheckPanel, java.awt.BorderLayout.NORTH);
}

// Adds a "Data" menu with one mutually exclusive radio item per data set.
// Choosing an item resets the check boxes, reloads the data and redraws.
private void installDataMenu() {
    JMenuBar menuBar = this.getJMenuBar();
    JMenu dataMenu = new JMenu();
    dataMenu.setMnemonic('D');
    dataMenu.setText("Data");
    menuBar.add(dataMenu);

    // The handler keeps no per-item state, so one instance serves all items.
    java.awt.event.ActionListener onDataSetChosen = new java.awt.event.ActionListener() {
        public void actionPerformed(java.awt.event.ActionEvent evt) {
            resetChecks();
            getData();
            update();
        }
    };

    jRadioButtons = new javax.swing.JRadioButtonMenuItem[numData];
    ButtonGroup exclusive = new ButtonGroup();
    for (int n = 0; n < numData; n++) {
        javax.swing.JRadioButtonMenuItem item = new javax.swing.JRadioButtonMenuItem();
        item.setText(dataName[n]);
        item.addActionListener(onDataSetChosen);
        exclusive.add(item);
        dataMenu.add(item);
        jRadioButtons[n] = item;
    }
}
/**
 * Restores the check boxes to their initial state: only the ARMA forecast
 * is selected. Called when a different data set is chosen.
 */
private void resetChecks() {
    jCheckData.setSelected(false);
    jCheckNnet.setSelected(false);
    jCheckArma.setSelected(true);
}
/**
 * Loads the data set whose radio menu item is selected and derives
 * everything that depends on it: the time stamps {@code x}, the shortened
 * series {@code xsub}/{@code ysub} (all but the last {@code nPredicts}
 * points) and, through {@link #trainNetwork()}, the network.
 */
private void getData() {
    // Index of the first selected menu item; stays -1 when none is selected.
    int choice = -1;
    for (int i = 0; choice < 0 && i < numData; i++) {
        if (jRadioButtons[i].isSelected()) {
            choice = i;
        }
    }

    if (choice == 2) {
        // sunspot data
        y = new double[] {100.8, 81.6, 66.5, 34.8, 30.6, 7, 19.8, 92.5,
        154.4, 125.9, 84.8, 68.1, 38.5, 22.8, 10.2, 24.1, 82.9,
        132, 130.9, 118.1, 89.9, 66.6, 60, 46.9, 41, 21.3, 16,
        6.4, 4.1, 6.8, 14.5, 34, 45, 43.1, 47.5, 42.2, 28.1, 10.1,
        8.1, 2.5, 0, 1.4, 5, 12.2, 13.9, 35.4, 45.8, 41.1, 30.4,
        23.9, 15.7, 6.6, 4, 1.8, 8.5, 16.6, 36.3, 49.7, 62.5, 67,
        71, 47.8, 27.5, 8.5, 13.2, 56.9, 121.5, 138.3, 103.2,
        85.8, 63.2, 36.8, 24.2, 10.7, 15, 40.1, 61.5, 98.5, 124.3,
        95.9, 66.5, 64.5, 54.2, 39, 20.6, 6.7, 4.3, 22.8, 54.8,
        93.8, 95.7, 77.2, 59.1, 44, 47, 30.5, 16.3, 7.3, 37.3,
        73.9};
        yearRange = new int[] {1770, 1858, 1874};
        yRange = new double[] {-50, 200};
        dataIncrement = GregorianCalendar.YEAR;
        yTitle = "Number of Sun Spots per Year";
        seed = 1593;
        // seeds tried -- best so far: 8765, 369, 1593=123457; not bad: 45678
    } else if (choice == 1) {
        // champagne sales data
        y = new double[] {2815, 2672, 2755, 2721, 2946, 3036, 2282, 2212, 2922, 4301, 5764, 7132,
        2541, 2475, 3031, 3266, 3776, 3230, 3028, 1759, 3595, 4474, 6838, 8357,
        3113, 3006, 4047, 3523, 3937, 3986, 3260, 1573, 3528, 5211, 7614, 9254,
        5375, 3088, 3718, 4514, 4520, 4539, 3663, 1643, 4739, 5428, 8314, 10651,
        3633, 4292, 4154, 4121, 4647, 4753, 3965, 1723, 5048, 6922, 9858, 11331,
        4016, 3957, 4510, 4276, 4968, 4677, 3523, 1821, 5222, 6873, 10803, 13916,
        2639, 2899, 3370, 3740, 2927, 3986, 4217, 1738, 5221, 6424, 9842, 13076,
        3934, 3162, 4286, 4676, 5010, 4874, 4633, 1659, 5951, 6981, 9851, 12670};
        yearRange = new int[] {1964, 1971, 1972};
        yRange = new double[] {0, 14000};
        dataIncrement = GregorianCalendar.MONTH;
        yTitle = "Cases of Champagne Sold per Month";
        seed = 6789;
        // seeds tried -- best: 6789; not bad: 23456, 65432
    } else if (choice == 0) {
        // west texas crude oil price
        y = new double[] {23,15.4,12.6,12.8,15.4,13.5,11.5,15,14.9,14.9,
        15.1,16.1,18.6,17.7,18.3,18.6,19.4,20,21.3,20.2,19.5,19.8,19,
        17.2,17.1,16.7,16.2,17.8,17.4,16.5,15.5,15.5,14.4,13.8,13.9,
        16.2,17.9,17.8,19.4,21,20,20,19.6,18.5,19.5,20,19.8,21,22.6,
        22.1,20.4,18.5,18.2,16.8,18.6,27.1,33.6,36,32.3,27.3,24.9,20.5,
        19.8,20.8,21.2,20.2,21.4,21.6,21.8,23.2,22.4,19.5,18.8,19,18.9,
        20.2,20.9,22.3,21.7,21.3,21.9,21.6,20.3,19.4,19,20,20.3,20.2,
        19.9,19,17.8,18,17.5,18.1,16.6,14.5,15,14.7,14.6,16.3,17.8,19,
        19.6,18.3,17.4,17.7,18.1,17.1,17.9,18.5,18.5,19.8,19.7,18.4,
        17.3,18,18.2,17.4,17.9,19,18.8,19,21.3,23.5,21.2,20.4,21.3,21.9,
        23.9,24.9,23.7,25.3,25.1,22.2,20.9,19.7,20.8,19.1,19.6,19.9,
        19.7,21.2,20.1,18.3,16.7,16,15,15.4,14.8,13.6,14,13.3,14.9,14.3,
        12.9,11.2,12.4,12,14.6,17.3,17.7,17.8,20,21.2,23.8,22.6,24.8,
        26.1,27.2,29.3,29.8,25.7,28.7,31.8,29.7,31.2,33.8,33,34.4,28.4,
        29.5,29.6,27.2,27.4,28.6,27.6,26.4,27.4,25.8,22.2,19.6,19.3,
        19.6,20.7,24.4,26.2,27,25.5,26.9,28.3,29.6,28.8,26.2};
        yearRange = new int[] {1986, 2002, 2003};
        yRange = new double[] {0, 50};
        dataIncrement = GregorianCalendar.MONTH;
        yTitle = "West Texas Crude Oil, Dollars per Barrel";
        seed = 764;
        // seeds tried: 2468 over, 87654, 8765, 764
    }

    // One time stamp per observation, counted from January 1 of
    // yearRange[0] in steps of a year or a month. Month indices beyond 11
    // are passed straight to GregorianCalendar.
    final int count = y.length;
    x = new double[count];
    if (dataIncrement == GregorianCalendar.YEAR) {
        for (int i = 0; i < count; i++) {
            x[i] = (new GregorianCalendar(yearRange[0]+i, 0, 1)).getTimeInMillis();
        }
    } else if (dataIncrement == GregorianCalendar.MONTH) {
        for (int i = 0; i < count; i++) {
            x[i] = (new GregorianCalendar(yearRange[0], i, 1)).getTimeInMillis();
        }
    }

    // Everything except the last nPredicts observations.
    final int kept = count - nPredicts;
    ysub = new double[kept];
    xsub = new double[kept];
    System.arraycopy(y, 0, ysub, 0, kept);
    System.arraycopy(x, 0, xsub, 0, kept);

    trainNetwork();
}
/**
 * Draws every series for the currently loaded data set.
 *
 * The axis windows are fixed from {@code yearRange}/{@code yRange}, the
 * historical series ({@code xsub}/{@code ysub}) is always drawn, and an
 * ARMA(2,1) model is fitted to {@code ysub}. Depending on the check boxes
 * the held-out actual outcome, the ARMA forecast and the neural-network
 * forecast are added; each of these starts at the last historical point so
 * the lines connect. If the ARMA code throws an IMSLException its message
 * is printed and none of the optional series is drawn.
 */
private void drawGraph() {
    axis.getAxisX().setAutoscaleInput(axis.AUTOSCALE_OFF);
    axis.getAxisX().setWindow((new GregorianCalendar(yearRange[0], 0, 1)).getTimeInMillis(),
            (new GregorianCalendar(yearRange[2], 0, 1)).getTimeInMillis());
    axis.getAxisY().setAutoscaleInput(axis.AUTOSCALE_OFF);
    axis.getAxisY().setWindow(yRange);
    axis.getAxisY().getAxisTitle().setTitle(yTitle);

    // Historical data: the series without its last nPredicts points.
    data = new Data(axis, xsub, ysub);
    data.setTitle("Historical data");
    data.setDataType(Data.DATA_TYPE_LINE | Data.DATA_TYPE_MARKER);
    data.setMarkerType(Data.MARKER_TYPE_HOLLOW_CIRCLE);
    data.setMarkerSize(0.50);
    data.setMarkerColor(Color.BLUE);
    data.setLineColor(Color.BLUE);

    try {
        // ARMA(2,1) fitted to the historical data.
        // NOTE(review): the original comment here said "backorigin = 3" but
        // the code sets 4 and reads forecast column 3 -- confirm against the
        // ARMA.forecast documentation which origin that column refers to.
        arma = new ARMA(2, 1, ysub);
        arma.setRelativeError(0.0);
        arma.setMaxIterations(0);
        arma.compute();
        arma.setConfidence(0.50);
        arma.setBackwardOrigin(4);
        double[][] forecast = arma.forecast(nPredicts);

        // ARMA forecast values and the dates of the forecast period, which
        // starts in January of yearRange[1].
        double[] tmp = new double[nPredicts];
        double[] x1 = new double[nPredicts];
        for (int i = 0; i < nPredicts; i++) {
            tmp[i] = forecast[i][3];
            switch (dataIncrement) {
                case GregorianCalendar.YEAR:
                    x1[i] = (new GregorianCalendar(yearRange[1]+i, 0, 1)).getTimeInMillis();
                    break;
                case GregorianCalendar.MONTH:
                    x1[i] = (new GregorianCalendar(yearRange[1], i, 1)).getTimeInMillis();
                    break;
            }
        }

        // Actual outcome: the last historical value followed by the
        // nPredicts held-out values, with matching dates (x2[0] is one
        // increment before the forecast period).
        double[] y2 = new double[nPredicts+1];
        double[] x2 = new double[nPredicts+1];
        System.arraycopy(y, ysub.length-1, y2, 0, y2.length);
        System.arraycopy(x1, 0, x2, 1, x1.length);
        if (dataIncrement == GregorianCalendar.YEAR) {
            x2[0] = (new GregorianCalendar(yearRange[1]-1, 0, 1)).getTimeInMillis();
        } else {
            x2[0] = (new GregorianCalendar(yearRange[1], -1, 1)).getTimeInMillis();
        }

        if (jCheckData.isSelected()) {
            realData = new Data(axis, x2, y2);
            realData.setTitle("Actual outcome");
            realData.setDataType(Data.DATA_TYPE_LINE);
            realData.setLineColor(Color.BLUE);
            realData.setLineDashPattern(Data.DASH_PATTERN_DOT);
        }

        if (jCheckArma.isSelected()) {
            forecastData = new Data(axis, prepend(x2[0], x1), prepend(y2[0], tmp));
            forecastData.setTitle("ARMA Forecast");
            forecastData.setDataType(Data.DATA_TYPE_LINE);
            forecastData.setLineColor(Color.RED);
        }

        if (jCheckNnet.isSelected()) {
            // Feed-forward network forecast. The network was set up when the
            // data set was selected. The input window starts as the last
            // nLags values of the historical series (ysub), not of the full
            // series y, whose tail is the held-out outcome being forecast.
            // Each prediction is then shifted into the window so that later
            // steps are driven by earlier forecasts.
            double[] ins = new double[nLags];
            double[] sol = new double[nPredicts];
            System.arraycopy(ysub, ysub.length-nLags, ins, 0, nLags);
            for (int i = 0; i < nPredicts; i++) {
                double[] nnforecast = network.forecast(ins);
                sol[i] = nnforecast[0];
                System.arraycopy(ins, 1, ins, 0, nLags-1);
                ins[nLags-1] = nnforecast[0];
            }
            nnData = new Data(axis, prepend(x2[0], x1), prepend(y2[0], sol));
            nnData.setTitle("FFNN Forecast");
            nnData.setDataType(Data.DATA_TYPE_LINE);
            nnData.setLineColor(Color.MAGENTA);
        }
    } catch (com.imsl.IMSLException e) {
        System.out.println(e.getMessage());
    }
}
/**
 * Returns a new array holding {@code v} followed by every element of
 * {@code a}; {@code a} itself is not modified.
 *
 * @param v value placed at index 0 of the result
 * @param a values copied to indices 1..a.length of the result
 * @return a new array of length {@code a.length + 1}
 */
private double[] prepend (double v, double[] a) {
    final int tail = a.length;
    double[] joined = new double[tail + 1];
    for (int i = tail; i > 0; i--) {
        joined[i] = a[i - 1];
    }
    joined[0] = v;
    return joined;
}
/**
 * Builds the feed-forward network for the current data set and installs
 * pre-computed weights read from the resource {@code "weights" + seed}.
 *
 * Despite its name this method does not train: the training code is kept
 * only as a reference comment at the end of the body. The wait cursor shown
 * during the call is always restored, even if loading or installing the
 * weights throws.
 */
private void trainNetwork() {
    this.setCursor(new java.awt.Cursor(java.awt.Cursor.WAIT_CURSOR));
    try {
        errorMsg = "";

        // 113 is the length of the stored weight vector. It equals
        // (nLags+1)*nPerceptrons + (nPerceptrons+1)*nOutputs for the default
        // 12-8-1 layout -- presumably one weight per link plus one bias per
        // perceptron; verify before changing the layout fields.
        double[] inits = read1D("weights"+seed, 113);

        // nLags inputs -> nLayers hidden layers of nPerceptrons -> nOutputs.
        network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nLags);
        for (int i = 0; i < nLayers; i++) {
            network.createHiddenLayer().createPerceptrons(nPerceptrons);
        }
        network.getOutputLayer().createPerceptrons(nOutputs);
        network.linkAll();

        // As in the original code, the last nOutputs entries of
        // getPerceptrons() are treated as the output perceptrons.
        Perceptron[] perceptrons = network.getPerceptrons();
        int firstOutput = perceptrons.length - nOutputs;
        for (int i = 0; i < firstOutput; i++) {
            perceptrons[i].setActivation(hiddenLayerActivation);
        }
        for (int i = firstOutput; i < perceptrons.length; i++) {
            perceptrons[i].setActivation(outputLayerActivation);
        }

        network.setWeights(inits);
    } finally {
        this.setCursor(new java.awt.Cursor(java.awt.Cursor.DEFAULT_CURSOR));
    }

    // Reference only: how the weight resources were produced. Each network
    // was trained once and its weights written to disk; training took
    // approx. 95 seconds, 8 seconds, 40 seconds for the three data sets.
    // Lagged values of the time series are the inputs, the value that
    // follows each window is the target.
    //
    //   double[][] lagged = lagData(ysub, nLags);
    //   double[][] outs = new double[ysub.length-nLags][1];
    //   for (int i=0; i<ysub.length-nLags; i++) {
    //       outs[i][0] = ysub[i+nLags];
    //   }
    //   QuasiNewtonTrainer stage1trainer = new QuasiNewtonTrainer();
    //   stage1trainer.setMaximumTrainingIterations(500);
    //   QuasiNewtonTrainer stage2trainer = new QuasiNewtonTrainer();
    //   stage2trainer.setMaximumTrainingIterations(1000);
    //   trainer = new EpochTrainer(stage1trainer, stage2trainer);
    //   trainer.setEpochSize(ysub.length-nLags);
    //   trainer.setNumberOfEpochs(30);
    //   trainer.setRandom(new com.imsl.stat.Random(seed));
    //   trainer.setRandomSamples(new com.imsl.stat.Random(123),
    //                            new com.imsl.stat.Random(678));
    //   trainer.train(network, lagged, outs);
    //   switch(trainer.getErrorStatus()){
    //       case 0: errorMsg = errorMsg0; break;
    //       case 1: errorMsg = errorMsg1; break;
    //       case 2: errorMsg = errorMsg2; break;
    //       case 3: errorMsg = errorMsg3; break;
    //       case 4: errorMsg = errorMsg4; break;
    //       case 5: errorMsg = errorMsg5; break;
    //       default:errorMsg = errorMsg0;
    //   }
    //   if (!errorMsg.equals(errorMsg0)) {
    //       javax.swing.JOptionPane.showMessageDialog(this,errorMsg,"Training Error",javax.swing.JOptionPane.ERROR_MESSAGE);
    //   }
    //   // stats of the trained network
    //   stats = network.computeStatistics(lagged, outs);
    //   weight = network.getWeights();
    //   gradient = trainer.getErrorGradient();
    //   //new com.imsl.math.PrintMatrix("weights").print(weight);
    //   //com.imsl.util.IOHelp.write1D("weights"+seed,weight);
}
/**
 * Builds the lag matrix of a series: row {@code i} of the result holds the
 * {@code n} consecutive values {@code a[i] .. a[i+n-1]}. (The original note
 * says TimeSeriesFilter would be used here if {@code y} were not 1D.)
 *
 * @param a the series
 * @param n the window length (number of lags)
 * @return a matrix with {@code a.length - n} rows and {@code n} columns
 */
private double[][] lagData(double[] a, int n) {
    final int rows = a.length - n;
    double[][] windows = new double[rows][n];
    for (int start = 0; start < rows; start++) {
        System.arraycopy(a, start, windows[start], 0, n);
    }
    return windows;
}
/**
 * Redraws the chart for the current interface state: removes every series
 * created by an earlier draw, rebuilds them, and repaints the frame.
 */
private void update() {
    Data[] previous = {data, forecastData, nnData, realData};
    for (int i = 0; i < previous.length; i++) {
        if (previous[i] != null) {
            previous[i].remove();
        }
    }
    drawGraph();
    repaint();
}
/**
 * Launches the demo. Passing {@code -noexit} as the first argument keeps
 * the application running after the window is closed.
 *
 * @param args command-line arguments
 */
public static void main(String args[]) {
    boolean keepRunning = args.length > 0 && args[0].equals("-noexit");
    ForecastNet demo = new ForecastNet(!keepRunning);
    demo.show();
}
/**
 * Reads a double array from a class-path resource containing one number
 * per line.
 *
 * The resource {@code fileName + ".dat"} is resolved relative to this
 * class. At most {@code len} lines are read. A line that cannot be parsed
 * is reported on standard output and leaves a zero in its slot. If the
 * resource is missing or an I/O error occurs, the problem is reported on
 * standard error and the values read so far are returned (remaining
 * entries are zero). The stream is always closed.
 *
 * @param fileName the resource name without the {@code .dat} extension
 * @param len the number of values to read, which is also the length of the
 *            returned array
 * @return an array of length {@code len}
 */
double[] read1D(String fileName, int len) {
    double[] d = new double[len];
    java.io.InputStream is = null;
    try {
        is = getClass().getResourceAsStream(fileName+".dat");
        if (is == null) {
            // getResourceAsStream signals "not found" with null; report it
            // through the same path as any other I/O failure.
            throw new java.io.FileNotFoundException(
                    "resource not found: " + fileName + ".dat");
        }
        java.io.BufferedReader br = new java.io.BufferedReader(
                new java.io.InputStreamReader(is, "UTF-8"));
        String line;
        for (int i = 0; i < len && (line = br.readLine()) != null; i++) {
            try {
                d[i] = Double.parseDouble(line.trim());
            } catch (NumberFormatException nfe) {
                // Keep going: the slot stays 0 and the next line fills the
                // next slot.
                System.out.println("NumberFormatException: " + nfe.getMessage());
            }
        }
    } catch (java.io.IOException ioe) {
        System.err.println("FileIO: " + ioe.getMessage());
    } finally {
        if (is != null) {
            try {
                is.close();
            } catch (java.io.IOException ignored) {
                // Nothing useful can be done if closing a resource fails.
            }
        }
    }
    return d;
}
// Error status messages for the least squares trainer. In the code visible
// here they are referenced only by the commented-out training code in
// trainNetwork(), where errorMsgN is selected when
// trainer.getErrorStatus() returns N (errorMsg0 also serves as the default).
// errorMsg0 is the success message: that code shows a dialog only when the
// selected message differs from it.
private static String errorMsg0 =
"Least Squares Training Completed Successfully";
// Status 1: step tolerance satisfied.
private static String errorMsg1 =
"Scaled step tolerance was satisfied. The current solution \n"+
"may be an approximate local solution, or the algorithm is making\n"+
"slow progress and is not near a solution, or the Step Tolerance\n"+
"is too big";
// Status 2: relative function convergence tolerance satisfied.
private static String errorMsg2 =
"Scaled actual and predicted reductions in the function are\n"+
"less than or equal to the relative function convergence\n"+
"tolerance RelativeTolerance";
// Status 3: apparent convergence to a noncritical point.
private static String errorMsg3 =
"Iterates appear to be converging to a noncritical point.\n"+
"Incorrect gradient information, a discontinuous function,\n"+
"or stopping tolerances being too tight may be the cause.";
// Status 4: five consecutive maximum-size steps.
private static String errorMsg4 =
"Five consecutive steps with the maximum stepsize have\n"+
"been taken. Either the function is unbounded below, or has\n"+
"a finite asymptote in some direction, or the maximum stepsize\n"+
"is too small.";
// Status 5: iteration limit reached.
private static String errorMsg5 =
"Too many iterations required";
}