/*
 * -------------------------------------------------------------------------
 *      $Id: ForecastNet.java,v 1.3 2005/03/14 14:57:10 estewart Exp $
 * -------------------------------------------------------------------------
 *      Copyright (c) 1999 Visual Numerics Inc. All Rights Reserved.
 *
 *      This software is confidential information which is proprietary to
 *      and a trade secret of Visual Numerics, Inc.  Use, duplication or
 *      disclosure is subject to the terms of an appropriate license
 *      agreement.
 *
 *      VISUAL NUMERICS MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE
 *      SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING
 *      BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY,
 *      FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. VISUAL
 *      NUMERICS SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE
 *      AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THIS SOFTWARE OR
 *      ITS DERIVATIVES.
 *--------------------------------------------------------------------------
 */

package com.imsl.demo.nnet;
import javax.swing.*;
import javax.swing.event.*;
import java.awt.*;
import java.awt.event.*;
import java.util.*;
import java.text.*;
import com.imsl.stat.*;
import com.imsl.chart.*;
import com.imsl.datamining.neural.*;
import com.imsl.demo.gallery.Describe;


/**
 * Swing demo comparing an ARMA(2,1) forecast with a feed-forward neural
 * network forecast on one of three built-in time series (oil prices,
 * champagne sales, sunspot counts).
 *
 * The last {@code nPredicts} observations of the selected series are held
 * back. Both models forecast that period from the remaining history so the
 * forecasts can be compared against the "actual outcome". The network is not
 * trained at run time: its weights are read from the class-path resource
 * {@code weights<seed>.dat} (see {@link #trainNetwork()}).
 */
public class ForecastNet extends JFrameChart {
    private Chart chart;
    private AxisXY axis;
    // Series currently on the chart; reset to null when removed in update().
    private Data data, forecastData, nnData, realData;
    private JRadioButtonMenuItem[] jRadioButtons;
    private JCheckBox jCheckData, jCheckArma, jCheckNnet;
    private JPanel jCheckPanel;
    private final int numData = 3;
    private final String[] dataName = {"Oil", "Sales", "Sunspots"};
    private ARMA arma;
    private SimpleDateFormat dateFormat = new SimpleDateFormat("M/yyyy");
    // y/x: full series and time stamps (ms); ysub/xsub: history without the
    // last nPredicts points.
    private double[] y, x, yRange, xsub, ysub;
    private int[] yearRange;
    private String yTitle;
    private int dataIncrement, seed;
    private FeedForwardNetwork network;
    // trainer, weight, gradient, stats and errorMsg are only used by the
    // offline training recipe documented in trainNetwork().
    private EpochTrainer trainer;
    private Activation hiddenLayerActivation = Activation.LOGISTIC;
    private Activation outputLayerActivation = Activation.LINEAR;
    private double[] weight, gradient, stats;
    int nOutputs = 1;
    int nPerceptrons = 8;
    int nLayers = 1;
    int nLags = 12;
    int nPredicts = 12;
    private String errorMsg;

    /**
     * Builds the demo window, loads the default data set (Sales) and draws
     * the initial chart.
     *
     * @param exitOnClose if {@code false}, the window listeners installed by
     *        {@link JFrameChart} (which exit the application when the window
     *        is closed) are removed
     */
    public ForecastNet(boolean exitOnClose) {
        setTitle("Neural Net Forecasting");
        if (!exitOnClose) {
            // Remove the WindowListener, installed by JFrameChart, that
            // exits the application when the window is closed.
            java.awt.event.WindowListener[] listeners = getWindowListeners();
            for (int k = 0; k < listeners.length; k++) {
                removeWindowListener(listeners[k]);
            }
        }

        // The description window is shown first; its size determines this
        // frame's size and position.
        Describe des = new Describe(this, "/com/imsl/demo/nnet/ForecastNet.html");
        des.show();
        Dimension ds = des.getSize();

        Dimension ss = getToolkit().getScreenSize();
        int h = Math.min(ss.width / 2, ss.height - ds.height - 32);
        int w = (int) (h / 0.8);
        setSize(w, h);
        setLocation(ss.width - ds.width, ds.height);

        chart = getChart();
        axis = new AxisXY(chart);
        axis.getAxisX().getAxisLabel().setTextFormat(dateFormat);
        axis.getAxisX().getAxisLabel().setTextAngle(90);
        axis.getAxisX().getAxisTitle().setTitle("Date");
        axis.getAxisY().setTextFormat(new java.text.DecimalFormat("###,###"));
        chart.getLegend().setPaint(true);

        jCheckPanel = createCheckPanel();
        getContentPane().add(jCheckPanel, java.awt.BorderLayout.NORTH);
        getJMenuBar().add(createDataMenu());
        jRadioButtons[1].setSelected(true);  // default data set: Sales

        getData();
        drawGraph();
        setResizable(false);
    }

    // Builds the row of check boxes that choose which optional series are
    // drawn; toggling any of them redraws the chart.
    private JPanel createCheckPanel() {
        jCheckArma = new JCheckBox("ARMA Forecast        ");
        jCheckNnet = new JCheckBox("Neural Net Forecast        ");
        jCheckData = new JCheckBox("Actual outcome        ");
        jCheckArma.setSelected(true);

        java.awt.event.ActionListener redraw = new java.awt.event.ActionListener() {
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                update();
            }
        };
        JCheckBox[] boxes = {jCheckArma, jCheckNnet, jCheckData};
        JPanel panel = new JPanel();
        for (int i = 0; i < boxes.length; i++) {
            panel.add(boxes[i]);
            boxes[i].addActionListener(redraw);
        }
        return panel;
    }

    // Builds the "Data" menu with one radio item per data set. Choosing an
    // item resets the check boxes, reloads the data and network, and redraws.
    private JMenu createDataMenu() {
        JMenu menuData = new JMenu();
        menuData.setMnemonic('D');
        menuData.setText("Data");

        java.awt.event.ActionListener reload = new java.awt.event.ActionListener() {
            public void actionPerformed(java.awt.event.ActionEvent evt) {
                resetChecks();
                getData();
                update();
            }
        };
        jRadioButtons = new JRadioButtonMenuItem[numData];
        ButtonGroup group = new ButtonGroup();
        for (int i = 0; i < numData; i++) {
            jRadioButtons[i] = new JRadioButtonMenuItem(dataName[i]);
            jRadioButtons[i].addActionListener(reload);
            group.add(jRadioButtons[i]);
            menuData.add(jRadioButtons[i]);
        }
        return menuData;
    }


    // Restores the default check-box state: only the ARMA forecast shown.
    private void resetChecks() {
        jCheckArma.setSelected(true);
        jCheckNnet.setSelected(false);
        jCheckData.setSelected(false);
    }

    /**
     * Loads the data set selected in the "Data" menu and builds its time
     * stamps. It then splits off the history used for fitting (all but the
     * last {@code nPredicts} points) and rebuilds the network for that data
     * set.
     */
    private void getData() {
        loadDataSet(selectedDataSet());

        x = new double[y.length];
        for (int i = 0; i < y.length; i++) {
            x[i] = observationTime(i);
        }

        int nHistory = y.length - nPredicts;
        ysub = new double[nHistory];
        xsub = new double[nHistory];
        System.arraycopy(y, 0, ysub, 0, nHistory);
        System.arraycopy(x, 0, xsub, 0, nHistory);

        trainNetwork();
    }

    // Index of the selected data-set radio button, or -1 if none is selected.
    private int selectedDataSet() {
        for (int i = 0; i < numData; i++) {
            if (jRadioButtons[i].isSelected()) {
                return i;
            }
        }
        return -1;
    }

    // Time stamp (ms) of observation i, counted in years or months (per
    // dataIncrement) from January 1 of yearRange[0]; 0 for any other increment.
    private double observationTime(int i) {
        switch (dataIncrement) {
            case GregorianCalendar.YEAR:
                return (new GregorianCalendar(yearRange[0] + i, 0, 1)).getTimeInMillis();
            case GregorianCalendar.MONTH:
                // month indices past 11 roll over into later years
                return (new GregorianCalendar(yearRange[0], i, 1)).getTimeInMillis();
            default:
                return 0;
        }
    }

    /*
     * Fills y, yearRange, yRange, dataIncrement, yTitle and seed for data set
     * `selected` (0 = Oil, 1 = Sales, 2 = Sunspots); any other index leaves
     * them unchanged. yearRange[0] is the year of the first observation and
     * yearRange[2] the end of the x-axis window; yearRange[1] is the nominal
     * first forecast year and is no longer used for plotting.
     */
    private void loadDataSet(int selected) {
        switch (selected) {
            case 2: // sunspot data
                y = new double[] {100.8, 81.6, 66.5, 34.8, 30.6, 7, 19.8, 92.5,
                    154.4, 125.9, 84.8, 68.1, 38.5, 22.8, 10.2, 24.1, 82.9,
                    132, 130.9, 118.1, 89.9, 66.6, 60, 46.9, 41, 21.3, 16,
                    6.4, 4.1, 6.8, 14.5, 34, 45, 43.1, 47.5, 42.2, 28.1, 10.1,
                    8.1, 2.5, 0, 1.4, 5, 12.2, 13.9, 35.4, 45.8, 41.1, 30.4,
                    23.9, 15.7, 6.6, 4, 1.8, 8.5, 16.6, 36.3, 49.7, 62.5, 67,
                    71, 47.8, 27.5, 8.5, 13.2, 56.9, 121.5, 138.3, 103.2,
                    85.8, 63.2, 36.8, 24.2, 10.7, 15, 40.1, 61.5, 98.5, 124.3,
                    95.9, 66.5, 64.5, 54.2, 39, 20.6, 6.7, 4.3, 22.8, 54.8,
                    93.8, 95.7, 77.2, 59.1, 44, 47, 30.5, 16.3, 7.3, 37.3,
                    73.9};
                yearRange = new int[] {1770, 1858, 1874};
                yRange = new double[] {-50, 200};
                dataIncrement = GregorianCalendar.YEAR;
                yTitle = "Number of Sun Spots per Year";
                seed = 1593;
                //best so far: 8765, 369, 1593=123457
                //not bad: 45678
                break;

            case 1: // champagne sales data
                y = new double[] {2815,  2672,  2755,  2721,  2946,  3036,  2282,  2212,  2922,  4301,  5764,  7132,
                    2541,  2475,  3031,  3266,  3776,  3230,  3028,  1759,  3595,  4474,  6838,  8357,
                    3113,  3006,  4047,  3523,  3937,  3986,  3260,  1573,  3528,  5211,  7614,  9254,
                    5375,  3088,  3718,  4514,  4520,  4539,  3663,  1643,  4739,  5428,  8314, 10651,
                    3633,  4292,  4154,  4121,  4647,  4753,  3965,  1723,  5048,  6922,  9858, 11331,
                    4016,  3957,  4510,  4276,  4968,  4677,  3523,  1821,  5222,  6873, 10803, 13916,
                    2639,  2899,  3370,  3740,  2927,  3986,  4217,  1738,  5221,  6424,  9842, 13076,
                    3934,  3162,  4286,  4676,  5010,  4874,  4633,  1659,  5951,  6981,  9851, 12670};
                yearRange = new int[] {1964, 1971, 1972};
                yRange = new double[] {0, 14000};
                dataIncrement = GregorianCalendar.MONTH;
                yTitle = "Cases of Champagne Sold per Month";
                seed = 6789;
                //best: 6789
                //not bad 23456,65432
                break;

            case 0: // west texas crude oil price
                y = new double[] {23,15.4,12.6,12.8,15.4,13.5,11.5,15,14.9,14.9,
                    15.1,16.1,18.6,17.7,18.3,18.6,19.4,20,21.3,20.2,19.5,19.8,19,
                    17.2,17.1,16.7,16.2,17.8,17.4,16.5,15.5,15.5,14.4,13.8,13.9,
                    16.2,17.9,17.8,19.4,21,20,20,19.6,18.5,19.5,20,19.8,21,22.6,
                    22.1,20.4,18.5,18.2,16.8,18.6,27.1,33.6,36,32.3,27.3,24.9,20.5,
                    19.8,20.8,21.2,20.2,21.4,21.6,21.8,23.2,22.4,19.5,18.8,19,18.9,
                    20.2,20.9,22.3,21.7,21.3,21.9,21.6,20.3,19.4,19,20,20.3,20.2,
                    19.9,19,17.8,18,17.5,18.1,16.6,14.5,15,14.7,14.6,16.3,17.8,19,
                    19.6,18.3,17.4,17.7,18.1,17.1,17.9,18.5,18.5,19.8,19.7,18.4,
                    17.3,18,18.2,17.4,17.9,19,18.8,19,21.3,23.5,21.2,20.4,21.3,21.9,
                    23.9,24.9,23.7,25.3,25.1,22.2,20.9,19.7,20.8,19.1,19.6,19.9,
                    19.7,21.2,20.1,18.3,16.7,16,15,15.4,14.8,13.6,14,13.3,14.9,14.3,
                    12.9,11.2,12.4,12,14.6,17.3,17.7,17.8,20,21.2,23.8,22.6,24.8,
                    26.1,27.2,29.3,29.8,25.7,28.7,31.8,29.7,31.2,33.8,33,34.4,28.4,
                    29.5,29.6,27.2,27.4,28.6,27.6,26.4,27.4,25.8,22.2,19.6,19.3,
                    19.6,20.7,24.4,26.2,27,25.5,26.9,28.3,29.6,28.8,26.2};
                yearRange = new int[] {1986, 2002, 2003};
                yRange = new double[] {0, 50};
                dataIncrement = GregorianCalendar.MONTH;
                yTitle = "West Texas Crude Oil, Dollars per Barrel";
                seed = 764;
                // 2468 over, 87654, 8765, 764
                break;
        }
    }


    /**
     * Draws the chart for the current data set. The historical series is
     * always drawn. Depending on the check boxes, the held-out actual values,
     * the ARMA forecast and the neural-network forecast are added. Each
     * optional series starts at the last historical point so that it joins
     * the history line.
     */
    private void drawGraph() {
        axis.getAxisX().setAutoscaleInput(AxisXY.AUTOSCALE_OFF);
        axis.getAxisX().setWindow((new GregorianCalendar(yearRange[0], 0, 1)).getTimeInMillis(),
            (new GregorianCalendar(yearRange[2], 0, 1)).getTimeInMillis());
        axis.getAxisY().setAutoscaleInput(AxisXY.AUTOSCALE_OFF);
        axis.getAxisY().setWindow(yRange);
        axis.getAxisY().getAxisTitle().setTitle(yTitle);

        // Historical (fitting) data.
        data = new Data(axis, xsub, ysub);
        data.setTitle("Historical data");
        data.setDataType(Data.DATA_TYPE_LINE | Data.DATA_TYPE_MARKER);
        data.setMarkerType(Data.MARKER_TYPE_HOLLOW_CIRCLE);
        data.setMarkerSize(0.50);
        data.setMarkerColor(Color.BLUE);
        data.setLineColor(Color.BLUE);

        // Forecast period: the last historical point followed by the
        // nPredicts held-out points. The time stamps are taken from x (rather
        // than rebuilt from yearRange[1]) so these series line up with the
        // historical one for every data set.
        int origin = ysub.length - 1;
        double[] xAhead = new double[nPredicts + 1];
        double[] yAhead = new double[nPredicts + 1];
        System.arraycopy(x, origin, xAhead, 0, xAhead.length);
        System.arraycopy(y, origin, yAhead, 0, yAhead.length);

        if (jCheckData.isSelected()) {
            realData = new Data(axis, xAhead, yAhead);
            realData.setTitle("Actual outcome");
            realData.setDataType(Data.DATA_TYPE_LINE);
            realData.setLineColor(Color.BLUE);
            realData.setLineDashPattern(Data.DASH_PATTERN_DOT);
        }

        if (jCheckArma.isSelected()) {
            double[] armaValues = armaForecast();
            if (armaValues != null) {
                forecastData = new Data(axis, xAhead, prepend(yAhead[0], armaValues));
                forecastData.setTitle("ARMA Forecast");
                forecastData.setDataType(Data.DATA_TYPE_LINE);
                forecastData.setLineColor(Color.RED);
            }
        }

        if (jCheckNnet.isSelected()) {
            nnData = new Data(axis, xAhead, prepend(yAhead[0], networkForecast()));
            nnData.setTitle("FFNN Forecast");
            nnData.setDataType(Data.DATA_TYPE_LINE);
            nnData.setLineColor(Color.MAGENTA);
        }
    }

    /**
     * Fits an ARMA(2,1) model to the history ({@code ysub}) and produces
     * {@code nPredicts} forecasts. It uses confidence 0.50 and backward
     * origin 4.
     *
     * @return the forecasts, or {@code null} if the IMSL computation failed
     *         (the message is printed to standard error)
     */
    private double[] armaForecast() {
        try {
            arma = new ARMA(2, 1, ysub);
            arma.setRelativeError(0.0);
            arma.setMaxIterations(0);
            arma.compute();
            arma.setConfidence(0.50);
            arma.setBackwardOrigin(4);
            double[][] table = arma.forecast(nPredicts);

            // NOTE(review): column 3 of the forecast table is plotted; with a
            // backward origin of 4 this is presumably one of the
            // backward-origin forecast columns -- confirm which origin.
            double[] values = new double[nPredicts];
            for (int i = 0; i < nPredicts; i++) {
                values[i] = table[i][3];
            }
            return values;
        } catch (com.imsl.IMSLException e) {
            System.err.println(e.getMessage());
            return null;
        }
    }

    /**
     * Iterated one-step-ahead network forecast of the held-out period. The
     * input window starts as the last {@code nLags} historical values. After
     * each step the window shifts left by one and the new forecast is
     * appended.
     *
     * @return {@code nPredicts} forecast values
     */
    private double[] networkForecast() {
        double[] window = new double[nLags];
        // Seed from the history only: the held-out values must not feed the
        // forecast of the very period they are compared against.
        System.arraycopy(ysub, ysub.length - nLags, window, 0, nLags);

        double[] values = new double[nPredicts];
        for (int i = 0; i < nPredicts; i++) {
            double next = network.forecast(window)[0];
            values[i] = next;
            System.arraycopy(window, 1, window, 0, nLags - 1);
            window[nLags - 1] = next;
        }
        return values;
    }

    // Returns a new array holding v followed by the elements of a.
    private double[] prepend (double v, double[] a) {
        double[] c = new double[a.length+1];
        c[0] = v;
        System.arraycopy(a, 0, c, 1, a.length);
        return c;
    }


    /**
     * Builds the feed-forward network and loads its pre-computed weights from
     * the resource {@code weights<seed>.dat}. The network has {@code nLags}
     * inputs, {@code nLayers} hidden layer(s) of {@code nPerceptrons}
     * perceptrons and {@code nOutputs} output(s). The wait cursor is shown
     * meanwhile and is always restored.
     */
    private void trainNetwork() {
        setCursor(new java.awt.Cursor(java.awt.Cursor.WAIT_CURSOR));
        try {
            errorMsg = "";

            network = new FeedForwardNetwork();
            network.getInputLayer().createInputs(nLags);
            for (int i = 0; i < nLayers; i++) {
                network.createHiddenLayer().createPerceptrons(nPerceptrons);
            }
            network.getOutputLayer().createPerceptrons(nOutputs);
            network.linkAll();

            // As in the original setup, the last nOutputs entries of
            // getPerceptrons() are treated as the output layer and all
            // earlier ones as hidden.
            Perceptron[] perceptrons = network.getPerceptrons();
            int firstOutput = perceptrons.length - nOutputs;
            for (int i = 0; i < perceptrons.length; i++) {
                perceptrons[i].setActivation(
                    i < firstOutput ? hiddenLayerActivation : outputLayerActivation);
            }

            // 113 = weight count for the default architecture (presumably
            // 12*8 + 8*1 links plus 8 + 1 biases); it must change together
            // with nLags/nPerceptrons/nLayers/nOutputs and the weight files.
            network.setWeights(read1D("weights" + seed, 113));
        } finally {
            setCursor(new java.awt.Cursor(java.awt.Cursor.DEFAULT_CURSOR));
        }

        /*
         * The weight files were produced once by training with the code
         * below (approx. 95, 8 and 40 seconds for the three data sets). The
         * result was then written with
         * com.imsl.util.IOHelp.write1D("weights"+seed, weight):
         *
         *   double[][] lagged = lagData(ysub, nLags);
         *   double[][] outs = new double[ysub.length-nLags][1];
         *   for (int i=0; i<ysub.length-nLags; i++) {
         *       outs[i][0] = ysub[i+nLags];
         *   }
         *   QuasiNewtonTrainer stage1trainer = new QuasiNewtonTrainer();
         *   stage1trainer.setMaximumTrainingIterations(500);
         *   QuasiNewtonTrainer stage2trainer = new QuasiNewtonTrainer();
         *   stage2trainer.setMaximumTrainingIterations(1000);
         *   trainer = new EpochTrainer(stage1trainer, stage2trainer);
         *   trainer.setEpochSize(ysub.length-nLags);
         *   trainer.setNumberOfEpochs(30);
         *   trainer.setRandom(new com.imsl.stat.Random(seed));
         *   trainer.setRandomSamples(new com.imsl.stat.Random(123),
         *                            new com.imsl.stat.Random(678));
         *   trainer.train(network, lagged, outs);
         *   // map trainer.getErrorStatus() 0..5 to errorMsg0..errorMsg5 and
         *   // show a "Training Error" dialog unless the result is errorMsg0
         *   stats = network.computeStatistics(lagged, outs);
         *   weight = network.getWeights();
         *   gradient = trainer.getErrorGradient();
         */
    }


    // Builds the lagged input matrix: row i holds a[i..i+n-1]. The
    // TimeSeriesFilter is not used because the data are one-dimensional.
    // Only referenced by the offline training recipe in trainNetwork().
    private double[][] lagData(double[] a, int n) {
        int len = a.length;
        double[][] b = new double[len-n][n];
        for (int i=0; i<len-n; i++) {
            for (int j=0; j<n; j++) {
                b[i][j] = a[i+j];
            }
        }
        return b;
    }


    // Removes every series from the chart, then redraws it from the current
    // data and check-box state.
    private void update() {
        data = detach(data);
        forecastData = detach(forecastData);
        nnData = detach(nnData);
        realData = detach(realData);
        drawGraph();
        repaint();
    }

    // Removes a series from the chart if present. Always returns null so the
    // caller can clear its reference and avoid removing the series twice.
    private static Data detach(Data series) {
        if (series != null) {
            series.remove();
        }
        return null;
    }


    /**
     * Starts the demo on the Swing event-dispatch thread.
     *
     * @param args pass {@code -noexit} as the first argument so that closing
     *             the window does not exit the application
     */
    public static void main(String[] args) {
        final boolean exitOnClose = !(args.length > 0 && args[0].equals("-noexit"));
        javax.swing.SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                new ForecastNet(exitOnClose).setVisible(true);
            }
        });
    }


    /**
     *  Read a double array, one value per line, from the class-path resource
     *  {@code fileName + ".dat"} (resolved relative to this class).
     *
     *  A line that does not parse is reported and leaves 0.0 in its slot, so
     *  later values keep their positions. Reading stops after {@code len}
     *  values; slots without a value stay 0.0. A missing resource or an I/O
     *  error is reported on standard error, and the values read so far are
     *  returned.
     *
     *  @param fileName the resource name without the ".dat" suffix
     *  @param len the number of values to read
     *  @return an array of length {@code len}
     */
    double[] read1D(String fileName, int len) {
        double[] d = new double[len];
        java.io.InputStream is = getClass().getResourceAsStream(fileName + ".dat");
        if (is == null) {
            System.err.println("FileIO: resource not found: " + fileName + ".dat");
            return d;
        }
        try {
            java.io.BufferedReader br =
                new java.io.BufferedReader(new java.io.InputStreamReader(is));
            String line;
            for (int i = 0; i < len && (line = br.readLine()) != null; i++) {
                try {
                    d[i] = Double.parseDouble(line.trim());
                } catch (NumberFormatException nfe) {
                    System.out.println("NumberFormatException: " + nfe.getMessage());
                }
            }
        } catch (java.io.IOException ioe) {
            System.err.println("FileIO: " + ioe.getMessage());
        } finally {
            try {
                is.close();
            } catch (java.io.IOException ignored) {
                // Nothing useful to do if closing a read-only resource fails.
            }
        }
        return d;
    }


    // Error Status Messages for the Least Squares Trainer (used by the
    // offline training recipe in trainNetwork()).
    private static final String errorMsg0 =
      "Least Squares Training Completed Successfully";
    private static final String errorMsg1 =
      "Scaled step tolerance was satisfied.  The current solution \n"+
      "may be an approximate local solution, or the algorithm is making\n"+
      "slow progress and is not near a solution, or the Step Tolerance\n"+
      "is too big";
    private static final String errorMsg2 =
      "Scaled actual and predicted reductions in the function are\n"+
      "less than or equal to the relative function convergence\n"+
      "tolerance RelativeTolerance";
    private static final String errorMsg3 =
      "Iterates appear to be converging to a noncritical point.\n"+
      "Incorrect gradient information, a discontinuous function,\n"+
      "or stopping tolerances being too tight may be the cause.";
    private static final String errorMsg4 =
      "Five consecutive steps with the maximum stepsize have\n"+
      "been taken.  Either the function is unbounded below, or has\n"+
      "a finite asymptote in some direction, or the maximum stepsize\n"+
      "is too small.";
    private static final String errorMsg5 =
      "Too many iterations required";
}