/* * ------------------------------------------------------------------------- * $Id: ForecastNet.java,v 1.3 2005/03/14 14:57:10 estewart Exp $ * ------------------------------------------------------------------------- * Copyright (c) 1999 Visual Numerics Inc. All Rights Reserved. * * This software is confidential information which is proprietary to * and a trade secret of Visual Numerics, Inc. Use, duplication or * disclosure is subject to the terms of an appropriate license * agreement. * * VISUAL NUMERICS MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE * SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING * BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. VISUAL * NUMERICS SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE * AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THIS SOFTWARE OR * ITS DERIVATIVES. *-------------------------------------------------------------------------- */ package com.imsl.demo.nnet; import javax.swing.*; import javax.swing.event.*; import java.awt.*; import java.awt.event.*; import java.util.*; import java.text.*; import com.imsl.stat.*; import com.imsl.chart.*; import com.imsl.datamining.neural.*; import com.imsl.demo.gallery.Describe; public class ForecastNet extends JFrameChart { private Chart chart; private AxisXY axis; private Data data, forecastData, nnData, realData; private JRadioButtonMenuItem[] jRadioButtons; private JCheckBox jCheckData, jCheckArma, jCheckNnet; private JPanel jCheckPanel; private final int numData = 3; private final String dataName[] = {"Oil", "Sales", "Sunspots"}; private ARMA arma; private SimpleDateFormat dateFormat = new SimpleDateFormat("M/yyyy"); private double[] y, x, yRange, xsub, ysub; private int[] yearRange; private String yTitle; private int dataIncrement, seed; private FeedForwardNetwork network; private EpochTrainer trainer; private Activation hiddenLayerActivation = Activation.LOGISTIC; private Activation outputLayerActivation = Activation.LINEAR; private double weight[], gradient[], stats[]; int nOutputs = 1; int nPerceptrons = 8; int nLayers = 1; int nLags = 12; int nPredicts = 12; private String errorMsg; public ForecastNet(boolean exitOnClose) { setTitle("Neural Net Forecasting"); if (!exitOnClose) { // remove the WindowListener, installed by JFrameChart, that // exits the application when the window is closed. Object l[] = getListeners(java.awt.event.WindowListener.class); for (int k = 0; k < l.length; k++) { removeWindowListener((java.awt.event.WindowListener)l[k]); } } Describe des = new Describe(this, "/com/imsl/demo/nnet/ForecastNet.html"); des.show(); Dimension ds = des.getSize(); Dimension ss = getToolkit().getScreenSize(); //int h = Math.min(ss.width/2, ss.height-64); int h = Math.min(ss.width/2, ss.height-ds.height-32); int w = (int)(h/0.8); setSize(w, h); //setLocation(32, 32); setLocation(ss.width-ds.width, ds.height); // Set default values. chart = getChart(); axis = new AxisXY(chart); axis.getAxisX().getAxisLabel().setTextFormat(dateFormat); axis.getAxisX().getAxisLabel().setTextAngle(90); axis.getAxisX().getAxisTitle().setTitle("Date"); axis.getAxisY().setTextFormat(new java.text.DecimalFormat("###,###")); chart.getLegend().setPaint(true); jCheckPanel = new JPanel(); jCheckArma = new JCheckBox("ARMA Forecast "); jCheckNnet = new JCheckBox("Neural Net Forecast "); jCheckData = new JCheckBox("Actual outcome "); jCheckArma.setSelected(true); jCheckPanel.add(jCheckArma); jCheckPanel.add(jCheckNnet); jCheckPanel.add(jCheckData); java.awt.event.ActionListener cal = new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { update(); } }; jCheckArma.addActionListener(cal); jCheckNnet.addActionListener(cal); jCheckData.addActionListener(cal); getContentPane().add(jCheckPanel, java.awt.BorderLayout.NORTH); JMenuBar jMenu = this.getJMenuBar(); JMenu menuData = new JMenu(); menuData.setMnemonic('D'); menuData.setText("Data"); jMenu.add(menuData); jRadioButtons = new javax.swing.JRadioButtonMenuItem[numData]; ButtonGroup group = new ButtonGroup(); for (int i=0; i<numData; i++) { jRadioButtons[i] = new javax.swing.JRadioButtonMenuItem(); jRadioButtons[i].setText(dataName[i]); jRadioButtons[i].addActionListener(new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent evt) { resetChecks(); getData(); update(); } }); group.add(jRadioButtons[i]); menuData.add(jRadioButtons[i]); } jRadioButtons[1].setSelected(true); getData(); drawGraph(); setResizable(false); } private void resetChecks() { jCheckArma.setSelected(true); jCheckNnet.setSelected(false); jCheckData.setSelected(false); } // load up the selected data set private void getData() { int selected = -1; for (int i=0; i<numData; i++) { if (jRadioButtons[i].isSelected()) { selected = i; break; } } switch (selected) { case 2: // sunspot data y = new double[] {100.8, 81.6, 66.5, 34.8, 30.6, 7, 19.8, 92.5, 154.4, 125.9, 84.8, 68.1, 38.5, 22.8, 10.2, 24.1, 82.9, 132, 130.9, 118.1, 89.9, 66.6, 60, 46.9, 41, 21.3, 16, 6.4, 4.1, 6.8, 14.5, 34, 45, 43.1, 47.5, 42.2, 28.1, 10.1, 8.1, 2.5, 0, 1.4, 5, 12.2, 13.9, 35.4, 45.8, 41.1, 30.4, 23.9, 15.7, 6.6, 4, 1.8, 8.5, 16.6, 36.3, 49.7, 62.5, 67, 71, 47.8, 27.5, 8.5, 13.2, 56.9, 121.5, 138.3, 103.2, 85.8, 63.2, 36.8, 24.2, 10.7, 15, 40.1, 61.5, 98.5, 124.3, 95.9, 66.5, 64.5, 54.2, 39, 20.6, 6.7, 4.3, 22.8, 54.8, 93.8, 95.7, 77.2, 59.1, 44, 47, 30.5, 16.3, 7.3, 37.3, 73.9}; yearRange = new int[] {1770, 1858, 1874}; yRange = new double[] {-50, 200}; dataIncrement = GregorianCalendar.YEAR; yTitle = "Number of Sun Spots per Year"; seed = 1593; //best so far: 8765, 369, 1593=123457 //not bad: 45678 break; case 1: // champagne sales data y = new double[] {2815, 2672, 2755, 2721, 2946, 3036, 2282, 2212, 2922, 4301, 5764, 7132, 2541, 2475, 3031, 3266, 3776, 3230, 3028, 1759, 3595, 4474, 6838, 8357, 3113, 3006, 4047, 3523, 3937, 3986, 3260, 1573, 3528, 5211, 7614, 9254, 5375, 3088, 3718, 4514, 4520, 4539, 3663, 1643, 4739, 5428, 8314, 10651, 3633, 4292, 4154, 4121, 4647, 4753, 3965, 1723, 5048, 6922, 9858, 11331, 4016, 3957, 4510, 4276, 4968, 4677, 3523, 1821, 5222, 6873, 10803, 13916, 2639, 2899, 3370, 3740, 2927, 3986, 4217, 1738, 5221, 6424, 9842, 13076, 3934, 3162, 4286, 4676, 5010, 4874, 4633, 1659, 5951, 6981, 9851, 12670}; yearRange = new int[] {1964, 1971, 1972}; yRange = new double[] {0, 14000}; dataIncrement = GregorianCalendar.MONTH; yTitle = "Cases of Champagne Sold per Month"; seed = 6789; //best: 6789 //not bad 23456,65432 break; case 0: // west texas crude oil price y = new double[] {23,15.4,12.6,12.8,15.4,13.5,11.5,15,14.9,14.9, 15.1,16.1,18.6,17.7,18.3,18.6,19.4,20,21.3,20.2,19.5,19.8,19, 17.2,17.1,16.7,16.2,17.8,17.4,16.5,15.5,15.5,14.4,13.8,13.9, 16.2,17.9,17.8,19.4,21,20,20,19.6,18.5,19.5,20,19.8,21,22.6, 22.1,20.4,18.5,18.2,16.8,18.6,27.1,33.6,36,32.3,27.3,24.9,20.5, 19.8,20.8,21.2,20.2,21.4,21.6,21.8,23.2,22.4,19.5,18.8,19,18.9, 20.2,20.9,22.3,21.7,21.3,21.9,21.6,20.3,19.4,19,20,20.3,20.2, 19.9,19,17.8,18,17.5,18.1,16.6,14.5,15,14.7,14.6,16.3,17.8,19, 19.6,18.3,17.4,17.7,18.1,17.1,17.9,18.5,18.5,19.8,19.7,18.4, 17.3,18,18.2,17.4,17.9,19,18.8,19,21.3,23.5,21.2,20.4,21.3,21.9, 23.9,24.9,23.7,25.3,25.1,22.2,20.9,19.7,20.8,19.1,19.6,19.9, 19.7,21.2,20.1,18.3,16.7,16,15,15.4,14.8,13.6,14,13.3,14.9,14.3, 12.9,11.2,12.4,12,14.6,17.3,17.7,17.8,20,21.2,23.8,22.6,24.8, 26.1,27.2,29.3,29.8,25.7,28.7,31.8,29.7,31.2,33.8,33,34.4,28.4, 29.5,29.6,27.2,27.4,28.6,27.6,26.4,27.4,25.8,22.2,19.6,19.3, 19.6,20.7,24.4,26.2,27,25.5,26.9,28.3,29.6,28.8,26.2}; yearRange = new int[] {1986, 2002, 2003}; yRange = new double[] {0, 50}; dataIncrement = GregorianCalendar.MONTH; yTitle = "West Texas Crude Oil, Dollars per Barrel"; seed = 764; // 2468 over, 87654, 8765, 764 break; } x = new double[y.length]; for (int i = 0; i < y.length; i++) { switch (dataIncrement) { case GregorianCalendar.YEAR: x[i] = (new GregorianCalendar(yearRange[0]+i, 0, 1)).getTimeInMillis(); break; case GregorianCalendar.MONTH: x[i] = (new GregorianCalendar(yearRange[0], i, 1)).getTimeInMillis(); break; } } ysub = new double[y.length-nPredicts]; xsub = new double[x.length-nPredicts]; for (int i=0; i<ysub.length; i++) { ysub[i] = y[i]; xsub[i] = x[i]; } trainNetwork(); } // Draw the graph. private void drawGraph() { axis.getAxisX().setAutoscaleInput(axis.AUTOSCALE_OFF); axis.getAxisX().setWindow((new GregorianCalendar(yearRange[0], 0, 1)).getTimeInMillis(), (new GregorianCalendar(yearRange[2], 0, 1)).getTimeInMillis()); axis.getAxisY().setAutoscaleInput(axis.AUTOSCALE_OFF); axis.getAxisY().setWindow(yRange); axis.getAxisY().getAxisTitle().setTitle(yTitle); // Draw the actual data. data = new Data(axis, xsub, ysub); data.setTitle("Historical data"); data.setDataType(Data.DATA_TYPE_LINE | Data.DATA_TYPE_MARKER); data.setMarkerType(Data.MARKER_TYPE_HOLLOW_CIRCLE); data.setMarkerSize(0.50); data.setMarkerColor(Color.BLUE); data.setLineColor(Color.BLUE); // Use ARMA class (ARMA(2,1)) to get the forecast using backorigin = 3. // Number of predicts and confidence level are from the interface. try { arma = new ARMA(2, 1, ysub); arma.setRelativeError(0.0); arma.setMaxIterations(0); arma.compute(); arma.setConfidence(0.50); arma.setBackwardOrigin(4); double[][] forecast = arma.forecast(nPredicts); double[] tmp = new double[nPredicts]; double[] x1 = new double[nPredicts]; // Draw the forecast data. for (int i = 0; i < nPredicts; i++) { tmp[i] = forecast[i][3]; switch (dataIncrement) { case GregorianCalendar.YEAR: x1[i] = (new GregorianCalendar(yearRange[1]+i, 0, 1)).getTimeInMillis(); break; case GregorianCalendar.MONTH: x1[i] = (new GregorianCalendar(yearRange[1], i, 1)).getTimeInMillis(); break; } } double[] y2 = new double[nPredicts+1]; double[] x2 = new double[nPredicts+1]; for (int i=0; i<y2.length; i++) { y2[i] = y[i+ysub.length-1]; } System.arraycopy(x1, 0, x2, 1, x1.length); if (dataIncrement == GregorianCalendar.YEAR) { x2[0] = (new GregorianCalendar(yearRange[1]-1, 0, 1)).getTimeInMillis(); } else { x2[0] = (new GregorianCalendar(yearRange[1], -1, 1)).getTimeInMillis(); } if (jCheckData.isSelected()) { realData = new Data(axis, x2, y2); realData.setTitle("Actual outcome"); realData.setDataType(Data.DATA_TYPE_LINE); realData.setLineColor(Color.BLUE); //realData.setLineWidth(1.0); realData.setLineDashPattern(Data.DASH_PATTERN_DOT); } if (jCheckArma.isSelected()) { //forecastData = new Data(axis, x1, tmp); forecastData = new Data(axis, prepend(x2[0], x1), prepend(y2[0], tmp)); forecastData.setTitle("ARMA Forecast"); forecastData.setDataType(Data.DATA_TYPE_LINE); forecastData.setLineColor(Color.RED); //forecastData.setLineWidth(2.0); } // Now do a FFNNET forecast // network was trained when new data were selected // Now the thing is trained an we need new data to push through // start with the full data in lagged[0] and then add results // to the end as they're forecast int offset = 0; double[] ins = new double[nLags]; double[] sol = new double[nPredicts+offset]; System.arraycopy(y, y.length-nLags, ins, 0, nLags); for (int i=0; i<nPredicts+offset; i++) { double[] nnforecast = network.forecast(ins); sol[i] = nnforecast[0]; for (int j=0; j<nLags-1; j++) { ins[j] = ins[j+1]; } ins[nLags-1] = nnforecast[0]; } if (jCheckNnet.isSelected()) { //nnData = new Data(axis, x1, sol); nnData = new Data(axis, prepend(x2[0], x1), prepend(y2[0], sol)); nnData.setTitle("FFNN Forecast"); nnData.setDataType(Data.DATA_TYPE_LINE); nnData.setLineColor(Color.MAGENTA); //nnData.setLineWidth(2.0); } } catch (com.imsl.IMSLException e) { System.out.println(e.getMessage()); } } private double[] prepend (double v, double[] a) { double[] c = new double[a.length+1]; c[0] = v; System.arraycopy(a, 0, c, 1, a.length); return c; } private void trainNetwork() { //System.out.print("Training..."); //long start = System.currentTimeMillis(); this.setCursor(new java.awt.Cursor(java.awt.Cursor.WAIT_CURSOR)); double[] inits = read1D("weights"+seed, 113); // lagged values of the timeseries are inputs // actual values are the output errorMsg = ""; double[][] lagged = lagData(ysub, nLags); double[][] outs = new double[ysub.length-nLags][1]; for (int i=0; i<ysub.length-nLags; i++) { outs[i][0] = ysub[i+nLags]; } network = new FeedForwardNetwork(); network.getInputLayer().createInputs(nLags); for (int i=0; i<nLayers; i++) { network.createHiddenLayer().createPerceptrons(nPerceptrons); } network.getOutputLayer().createPerceptrons(1); network.linkAll(); Perceptron perceptrons[] = network.getPerceptrons(); // Set all perceptrons activation function for (int i=0; i < perceptrons.length-nOutputs; i++) { perceptrons[i].setActivation(hiddenLayerActivation); } perceptrons[perceptrons.length-nOutputs].setActivation(outputLayerActivation); network.setWeights(inits); /* QuasiNewtonTrainer stage1trainer = new QuasiNewtonTrainer(); stage1trainer.setMaximumTrainingIterations(500); QuasiNewtonTrainer stage2trainer = new QuasiNewtonTrainer(); stage2trainer.setMaximumTrainingIterations(1000); trainer = new EpochTrainer(stage1trainer, stage2trainer); trainer.setEpochSize(ysub.length-nLags); trainer.setNumberOfEpochs(30); trainer.setRandom(new com.imsl.stat.Random(seed)); trainer.setRandomSamples(new com.imsl.stat.Random(123), new com.imsl.stat.Random(678)); trainer.train(network, lagged, outs); switch(trainer.getErrorStatus()){ case 0: errorMsg = errorMsg0; break; case 1: errorMsg = errorMsg1; break; case 2: errorMsg = errorMsg2; break; case 3: errorMsg = errorMsg3; break; case 4: errorMsg = errorMsg4; break; case 5: errorMsg = errorMsg5; break; default:errorMsg = errorMsg0; } if (!errorMsg.equals(errorMsg0)) { javax.swing.JOptionPane.showMessageDialog(this,errorMsg,"Training Error",javax.swing.JOptionPane.ERROR_MESSAGE); } // stats of the trained network stats = network.computeStatistics(lagged, outs); weight = network.getWeights(); gradient = trainer.getErrorGradient(); */ //no need to benchmark when reading a file :) //float dur = (System.currentTimeMillis()-start)/1000.0f; //System.out.println(" Done: "+dur+" seconds."); this.setCursor(new java.awt.Cursor(java.awt.Cursor.DEFAULT_CURSOR)); // network was trained once and weights written to disk // this takes approx. 95 seconds, 8 seconds, 40 seconds for each data set //new com.imsl.math.PrintMatrix("weights").print(weight); //com.imsl.util.IOHelp.write1D("weights"+seed,weight); } // lag the data. would use the TimeSeriesFilter, but y is 1D private double[][] lagData(double[] a, int n) { int len = a.length; double[][] b = new double[len-n][n]; for (int i=0; i<len-n; i++) { for (int j=0; j<n; j++) { b[i][j] = a[i+j]; } } return b; } // Get the information from the interface and redraw the chart. private void update() { if (data != null) data.remove(); if (forecastData != null) forecastData.remove(); if (nnData != null) nnData.remove(); if (realData != null) realData.remove(); drawGraph(); repaint(); } public static void main(String args[]) { boolean exitOnClose = true; if (args.length > 0 && args[0].equals("-noexit")) exitOnClose = false; new ForecastNet(exitOnClose).show(); } /** * Read a double array from a file. * * @param fileName the name of the file to read * @param len the length of the file * @return the double array */ double[] read1D(String fileName, int len) { double[] d = new double[len]; boolean done = false; int i = 0; String line; try { java.io.InputStream is = getClass().getResourceAsStream(fileName+".dat"); java.io.BufferedReader br = new java.io.BufferedReader(new java.io.InputStreamReader(is)); //java.io.BufferedReader br = new java.io.BufferedReader(new java.io.FileReader(fileName)); while (!done) { try { line = br.readLine(); if (line == null) { done = true; br.close(); } else { d[i++] = (Double.valueOf(line.trim())).doubleValue(); } } catch (NumberFormatException nfe) { System.out.println("NumberFormatException: " + nfe.getMessage()); } catch (java.io.EOFException eof) { done = true; br.close(); } catch (ArrayIndexOutOfBoundsException obe) { done = true; br.close(); } } } catch (java.io.IOException ioe) { System.err.println("FileIO: " + ioe.getMessage()); } return d; } // Error Status Messages for the Least Squares Trainer private static String errorMsg0 = "Least Squares Training Completed Successfully"; private static String errorMsg1 = "Scaled step tolerance was satisfied. The current solution \n"+ "may be an approximate local solution, or the algorithm is making\n"+ "slow progress and is not near a solution, or the Step Tolerance\n"+ "is too big"; private static String errorMsg2 = "Scaled actual and predicted reductions in the function are\n"+ "less than or equal to the relative function convergence\n"+ "tolerance RelativeTolerance"; private static String errorMsg3 = "Iterates appear to be converging to a noncritical point.\n"+ "Incorrect gradient information, a discontinuous function,\n"+ "or stopping tolerances being too tight may be the cause."; private static String errorMsg4 = "Five consecutive steps with the maximum stepsize have\n"+ "been taken. Either the function is unbounded below, or has\n"+ "a finite asymptote in some direction, or the maximum stepsize\n"+ "is too small."; private static String errorMsg5 = "Too many iterations required"; }