/*
 * -------------------------------------------------------------------------
 *      $Id: Regression.java,v 1.2 2004/05/26 18:29:14 estewart Exp $
 * -------------------------------------------------------------------------
 *      Copyright (c) 1999 Visual Numerics Inc. All Rights Reserved.
 *
 *      This software is confidential information which is proprietary to
 *      and a trade secret of Visual Numerics, Inc.  Use, duplication or
 *      disclosure is subject to the terms of an appropriate license
 *      agreement.
 *
 *      VISUAL NUMERICS MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE
 *      SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING
 *      BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY,
 *      FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. VISUAL
 *      NUMERICS SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE
 *      AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THIS SOFTWARE OR
 *      ITS DERIVATIVES.
 *--------------------------------------------------------------------------
 */

package com.imsl.demo.Regression;

import javax.swing.*;
import javax.swing.event.*;
import java.awt.event.*;
import java.util.StringTokenizer;
import java.util.Vector;
import java.text.DecimalFormat;
import com.imsl.math.*;
import com.imsl.stat.LinearRegression;
import com.imsl.chart.*;
import com.imsl.demo.gallery.Describe;


/**
 *
 * @author  brophy
 * @created April 24, 2001
 */
public class Regression extends JFrameChart implements
ActionListener, ChangeListener, MouseListener {
    private JTextArea displayText;
    private JTextField numField;
    private JCheckBox checkBox;
    private double intercept, x1, mse;
    private boolean hasIntercept;
    private Chart chart;
    private AxisXY axis;
    private Data point;
    private Vector dataX, dataY;
    final private int nmethod=5;
    private Data line[] = new Data[nmethod+1];
    private JCheckBoxMenuItem jMenuItem[] = new JCheckBoxMenuItem[nmethod];
    final private String title[] = {"CsAkima", "CsInterpolate",
                                     "CsPeriodic", "CsShape", "CsSmooth"};


    public Regression(boolean exitOnClose) {
        if (!exitOnClose) {
            // remove the WindowListener,  installed by JFrameChart, that
            // exits the application when the window is closed.
            Object l[] = getListeners(java.awt.event.WindowListener.class);
            for (int k = 0;  k < l.length;  k++) {
                removeWindowListener((java.awt.event.WindowListener)l[k]);
            }
        }
        Describe des = new Describe(this, "/com/imsl/demo/Regression/Regression.html");
        des.show();
        java.awt.Dimension ds = des.getSize();

        java.awt.Dimension ss = getToolkit().getScreenSize();
        int w = Math.min(ss.width/2, ss.height-ds.height-32);
        setSize(w, w);
        setLocation(ss.width-ds.width, ds.height);

        // Set default values.
        hasIntercept = true;
        dataX = new Vector();
        dataY = new Vector();
        mse = 0.0;

        chart = getChart();
        chart.getLegend().setPaint(true);
        axis = new AxisXY(chart);

        // Set axes ranges to [0.0, 50.0]
        axis.getAxisX().setAutoscaleInput(AxisXY.AUTOSCALE_OFF);
        axis.getAxisX().setWindow(0.0, 50.0);
        axis.getAxisY().setAutoscaleInput(AxisXY.AUTOSCALE_OFF);
        axis.getAxisY().setWindow(0.0, 50.0);

        // Add MouseListeren to the chart.
        getPanel().addMouseListener(this);

        drawGraph();

        java.awt.Container cp = getContentPane();

        // Set up the panel to display information.
        JPanel msgPanel = new JPanel();
        displayText = new JTextArea(getDescription(), 2, 30);
        displayText.setEditable(false);
        msgPanel.add(displayText);
        cp.add(msgPanel, java.awt.BorderLayout.NORTH);

        // Create control panel to include buttons, text field and radio button.
        JPanel buttonPanel = new JPanel();

        // Set up buttons and ActionListener.
        JButton enter = new JButton("Enter Points");
        enter.addActionListener(this);

        JButton generate = new JButton("Generate");
        generate.addActionListener(this);

        JButton restart = new JButton("Restart");
        restart.addActionListener(this);

        // Set up label, text field and radio button.
        JLabel label = new JLabel("# of Points: ");
        numField = new JTextField("6", 4);
        checkBox = new JCheckBox("intercept", hasIntercept);
        checkBox.addChangeListener(this);

        // Add components to the chart.
        buttonPanel.add(enter);
        buttonPanel.add(generate);
        buttonPanel.add(label);
        buttonPanel.add(numField);
        buttonPanel.add(checkBox);
        buttonPanel.add(restart);
        cp.add(buttonPanel, java.awt.BorderLayout.SOUTH);

        JMenuBar jMenuBar = getJMenuBar();
        JMenu jMenuLines = new JMenu("Spline");
        jMenuLines.setMnemonic('S');
        jMenuBar.add(jMenuLines);
        for (int k=0; k<nmethod; k++) {
            final int index = k;
            jMenuItem[k] = new JCheckBoxMenuItem(title[index]);
            jMenuLines.add(jMenuItem[k]);
            jMenuItem[k].setState(false);
            jMenuItem[k].addActionListener(this);
        }
    }


    // Draw the graph.
    private void drawGraph() {
        if ((dataX.size() != 0) && (dataY.size() != 0)) {
            double[] x = new double[dataX.size()];
            double[] y = new double[dataY.size()];

            for (int i = 0; i < x.length; i++) {
                x[i] = ((Double) dataX.get(i)).doubleValue();
                y[i] = ((Double) dataY.get(i)).doubleValue();
            }

            // Draw the points.
            point = new Data(axis, x, y);
            point.setDataType(Data.DATA_TYPE_MARKER);
            point.setMarkerType(Data.MARKER_TYPE_OCTAGON_X);
            point.setMarkerSize(0.8);
            point.setMarkerColor(java.awt.Color.blue);

            // Draw the regression line.
            getRegCoefficient(x, y);

            ChartFunction fcn = new ChartFunction() {
                public double f(double x) {
                    return x*x1 + intercept;
                }
            };
            line[0] = new Data(axis, fcn, 0.0, 50.0);
            line[0].setLineColor(java.awt.Color.red);
            line[0].setTitle("Linear Fit");
            line[0].setLineWidth(2);

            // Draw the spline using CsAkima
            if (jMenuItem[0].getState() && (dataX.size() > 3) && (dataY.size() > 3)) {
                final CsAkima cs0 = new CsAkima(x, y);

                ChartFunction fcn0 = new ChartFunction() {
                    public double f(double x) {
                        return cs0.value(x);
                    }
                };
                line[1] = new Data(axis, fcn0, 0.0, 50.0);
                line[1].setLineColor(java.awt.Color.magenta);
                line[1].setTitle("CsAkima");
            }

            // Draw the spline using CsInterpolate
            if (jMenuItem[1].getState() && (dataX.size() > 2) && (dataY.size() > 2)) {
                final CsInterpolate cs1 = new CsInterpolate(x, y);

                ChartFunction fcn1 = new ChartFunction() {
                    public double f(double x) {
                        return cs1.value(x);
                    }
                };
                line[2] = new Data(axis, fcn1, 0.0, 50.0);
                line[2].setLineColor(java.awt.Color.green);
                line[2].setTitle("CsInterpolate");
            }

            // Draw the spline using CsPeriodic
            if (jMenuItem[2].getState()  && (dataX.size() > 3) && (dataY.size() > 3)) {
                final CsPeriodic cs2 = new CsPeriodic(x, y);
                ChartFunction fcn2 = new ChartFunction() {
                    public double f(double x) {
                        return cs2.value(x);
                    }
                };
                line[3] = new Data(axis, fcn2, 0.0, 50.0);
                line[3].setLineColor(java.awt.Color.orange);
                line[3].setTitle("CsPeriodic");
            }

            // Draw the spline using CsShape
            if (jMenuItem[3].getState()  && (dataX.size() > 2) && (dataY.size() > 2)) {
                try {
                    final CsShape cs3 = new CsShape(x, y);
                    ChartFunction fcn3 = new ChartFunction() {
                        public double f(double x) {
                            return cs3.value(x);
                        }
                    };
                    line[4] = new Data(axis, fcn3, 0.0, 50.0);
                    line[4].setLineColor(java.awt.Color.blue);
                    line[4].setTitle("CsShape");
                } catch (CsShape.TooManyIterationsException e) {
                    System.out.println("Too many iterations");
                } catch (SingularMatrixException e) {
                    System.out.println("Singular matrix");
                }
            }

            // Draw the spline using CsSmooth
            if (jMenuItem[4].getState()  && (dataX.size() > 2) && (dataY.size() > 2)) {
                final CsSmooth cs4 = new CsSmooth(x, y);
                ChartFunction fcn4 = new ChartFunction() {
                    public double f(double x) {
                        return cs4.value(x);
                    }
                };
                line[5] = new Data(axis, fcn4, 0.0, 50.0);
                line[5].setLineColor(java.awt.Color.cyan);
                line[5].setTitle("CsSmooth");
            }
        } else {
            intercept = 0.0;
            x1 = 0.0;
        }
    }

    // Get the coefficients for the regression line.
    // y = a + bx
    private void getRegCoefficient(double[] x, double[] y) {
        LinearRegression reg = new LinearRegression(1, hasIntercept);
        double[] ss = new double[1];

        for (int i = 0; i < y.length; i++) {
            ss[0] = x[i];
            reg.update(ss, y[i]);
        }
        double[] coe = reg.getCoefficients();

        mse = reg.getANOVA().getErrorMeanSquare();

        // Set the values depending on whether intercept is chosen.
        if (hasIntercept) {
            intercept = coe[0];
            x1 = coe[1];
        } else {
            intercept = 0.0;
            x1 = coe[0];
        }
    }



    //  Get the description line.
    private String getDescription() {
        DecimalFormat formatStat = new DecimalFormat("0.000");
        StringBuffer sb = new StringBuffer("# of Points:  ");

        if (dataX == null) {
            sb.append(0);
        } else {
            sb.append(dataX.size());
        }

        if (dataX.size() < 3) {
            sb.append("   Error Mean Square:  ??");
        } else {
            sb.append("   Error Mean Square:  " + formatStat.format(mse));
        }

        sb.append("\nRegression Line:  y = (" + formatStat.format(x1) +
        ")x + (" + formatStat.format(intercept) + ")");
        return sb.toString();
    }



    // Redraw the chart.
    private void update() {
        for (int i = 0; i<nmethod+1; i++) {
            if (line[i] != null) line[i].remove();
        }
        if (point != null) point.remove();
        drawGraph();
        repaint();
        displayText.setText(getDescription());
    }



    // Implement ActionListener
    public void actionPerformed(ActionEvent e) {
        if (e.getActionCommand().substring(0,2).equals("Cs")) {
            update();
        } else if (e.getActionCommand().equals("Enter Points")) {
            // Enter points using a text area.
            JTextArea textArea = new JTextArea(10, 5);
            JScrollPane scroll = new JScrollPane(textArea);

            int option = 0;
            try {
                option = JOptionPane.showConfirmDialog(this, scroll,
                    "Enter Points", JOptionPane.OK_CANCEL_OPTION);
                if (option == JOptionPane.CANCEL_OPTION) return;
            } catch (Exception ex) {
            }

            StringTokenizer lineToken =
                new StringTokenizer(textArea.getText(), "\n");
            boolean dup = false;
            while (lineToken.hasMoreElements()) {
                StringTokenizer token = new StringTokenizer(lineToken.nextToken(), "(){}[] ,\t\n");
                // If the line does not contain 2 numbers, then ignore.
                if (token.countTokens() == 2) {
                    try {
                        double x = Double.parseDouble(token.nextToken());
                        double y = Double.parseDouble(token.nextToken());

                        // Ignore any points that are not in the range.
                        if ((x >= 0.0) && (x <= 50.0) &&
                        (y >= 0.0) && (y <= 50.0)) {
                            // Only add if the x value is distinct or else
                            // we'll throw a nasty java.lang.IllegalArgumentException
                            for (int i = 0; i < dataX.size(); i++) {
                                if (((Double) dataX.get(i)).doubleValue() == x) {
                                    dup = true;
                                }
                            }
                            if (!dup) {
                                dataX.add(new Double(x));
                                dataY.add(new Double(y));
                            }
                        }
                    } catch (Exception ex) {
                    }
                }
            }
            if (dup) {
                JOptionPane.showMessageDialog(this, "X values must be distinct.\n" +
                        "Duplicates have not been added.",
                        "Data Error", JOptionPane.ERROR_MESSAGE);
            }
            update();
        } else if(e.getActionCommand().equals("Generate")) {
            int i = 0;
            try {
                i = Integer.parseInt(numField.getText());
            } catch (Exception exception) {
                i = 6;
                numField.setText("6");
            } finally {
                if (i <= 0) {
                    i = 6;
                    numField.setText("6");
                }
                for (int j = 0; j < i; j++) {
                    dataX.add(new Double(50.0 * Math.random()));
                    dataY.add(new Double(50.0 * Math.random()));
                }
                update();
            }
        } else {
            // Restart button is clicked.  Reset everything to default.
            numField.setText("6");
            checkBox.setSelected(true);
            hasIntercept = true;
            intercept = x1 = 0.0;
            mse = 0.0;
            dataX = new Vector();
            dataY = new Vector();
            update();
        }
    }



    // Implement ChangeListener
    public void stateChanged(ChangeEvent e) {
        hasIntercept = !hasIntercept;
        update();
    }



    // Implement MouseListener
    public void mouseClicked(MouseEvent e) {
        double user[] = {0,0};
        axis.mapDeviceToUser(e.getX(), e.getY(), user);

        // Do not add the point if it is not in the range.
        if ((user[0] < 0.0) || (user[0] > 50.0) ||
        (user[1] < 0.0) || (user[1] > 50.0)) return;


        // Button is to add points.  Otherwise, remove points.
        if (e.getModifiers() == MouseEvent.BUTTON1_MASK) {
            // But only add if the x value is distinct or else
            // Splines throw a java.lang.IllegalArgumentException
            boolean dup = false;
            for (int i = 0; i < dataX.size(); i++) {
                if (((Double) dataX.get(i)).doubleValue() == (double) user[0]) {
                    dup = true;
                }
            }
            if (dup) {
                JOptionPane.showMessageDialog(this, "X values must be distinct.",
                        "Data Error", JOptionPane.ERROR_MESSAGE);
            } else {
                dataX.add(new Double(user[0]));
                dataY.add(new Double(user[1]));
                update();
            }
        } else {
            int idx = -1;
            double min = 0.7;

            for (int i = 0; i < dataX.size(); i++) {
                double x = ((Double) dataX.get(i)).doubleValue();
                double y = ((Double) dataY.get(i)).doubleValue();
                double dist = Math.sqrt((user[0] - x)*(user[0] - x) +
                (user[1] - y)*(user[1] - y));

                if (dist < min) {
                    min = dist;
                    idx = i;
                }
            }

            if (idx != -1) {
                dataX.remove(idx);
                dataY.remove(idx);
                update();
            }
        }
    }

    public void mousePressed(MouseEvent e) {
    }

    public void mouseReleased(MouseEvent e) {
    }

    public void mouseEntered(MouseEvent e) {
    }

    public void mouseExited(MouseEvent e) {
    }



    public static void main(String args[]) {
        boolean exitOnClose = true;
        if (args.length > 0  && args[0].equals("-noexit"))  exitOnClose = false;
        new Regression(exitOnClose).show();
    }
}