/*
 * -------------------------------------------------------------------------
 *      $Id: ChartStock.java,v 1.10 2004/05/26 19:20:42 estewart Exp $
 * -------------------------------------------------------------------------
 *      Copyright (c) 1999 Visual Numerics Inc. All Rights Reserved.
 *
 *      This software is confidential information which is proprietary to
 *      and a trade secret of Visual Numerics, Inc.  Use, duplication or
 *      disclosure is subject to the terms of an appropriate license
 *      agreement.
 *
 *      VISUAL NUMERICS MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE
 *      SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING
 *      BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY,
 *      FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. VISUAL
 *      NUMERICS SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE
 *      AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THIS SOFTWARE OR
 *      ITS DERIVATIVES.
 *--------------------------------------------------------------------------
 */

package com.imsl.demo.WallStreet;
import com.imsl.chart.*;
import com.imsl.math.Sfun;
import java.awt.Font;
import java.text.*;
import java.util.*;

/**
 * Frame that charts the price and volume history of one stock ticker.
 * <p>
 * The upper ("Price") axis shows a closing-price line, candlesticks or
 * high-low-close bars, optionally with one overlay series (parabolic SAR,
 * 20-point moving average or Kalman-filtered close). The lower axis shows
 * volume as bars, a line or a filled area. It copies its x window and ticks
 * from the price axis so that the two plots line up. The whole chart is
 * rebuilt every time the {@link Model} notifies its listeners.
 *
 * @author  brophy
 * @created January 24, 2002
 */
class ChartStock extends JFrameChart  {
    /** Milliseconds in one day (not referenced elsewhere in this class). */
    static private final long DAY = 3600*24*1000;
    /** Color of the grid lines on the price and volume axes. */
    static private final String GRID_COLOR = "lightGray";
    /**
     * Format for x-axis date labels. SimpleDateFormat is not thread-safe.
     * NOTE(review): assumes all chart drawing happens on a single thread.
     */
    static private final DateFormat DATE_FORMAT = new SimpleDateFormat("MM/dd/yy");

    /** Source of the ticker, style, overlay, scaling and volume settings. */
    private Model   model;
    /** Price axis pair and its x/y axes. */
    private AxisXY  axis;
    private Axis1D  axisX, axisY, axisVolumeX, axisVolumeY;
    /** Price series currently on the chart, or null. */
    private Data    data;
    /** Volume axis pair; painted only when a volume style is selected. */
    private AxisXY  axisVolume;
    /** Volume series currently on the chart, or null. */
    private Data    dataVolume;
    /** Overlay series currently on the chart, or null. */
    private Data    dataOver;
    /** Legend title for the overlay series being plotted. */
    private String  titleOver;

    /**
     * Builds the price and volume axes and registers for model changes.
     *
     * @param ticker not used by this constructor (NOTE(review): kept for
     *               caller compatibility)
     * @param model  settings and data source; {@link #drawChart()} runs on
     *               every change it reports
     */
    ChartStock(String ticker[], Model model) {
        super();
        setTitle("Stock Price Charting");
        Chart chart = getChart();
        chart.getChartTitle().setFont(new Font("Serif",Font.BOLD|Font.ITALIC,14));
        chart.getBackground().setFillColor("lightYellow");
        chart.setTextColor("darkBlue");

        // Price axes: date x axis (weekends skipped), gridded on both axes.
        axis = new AxisXY(chart);
        axisX = axis.getAxisX();
        axisY = axis.getAxisY();
        axisX.getAxisLabel().setTextFormat(DATE_FORMAT);
        TransformDate transform = new TransformDate();
        axisX.setCustomTransform(transform);
        axisX.setTransform(AxisXY.TRANSFORM_CUSTOM);
        axisX.setSkipWeekends(true);
        axisY.getAxisTitle().setTitle("Price");
        axisX.getGrid().setPaint(true);
        axisY.getGrid().setPaint(true);
        axisX.getGrid().setLineColor(GRID_COLOR);
        axisY.getGrid().setLineColor(GRID_COLOR);

        // Volume axes: the x labels are hidden and autoscaling is off because
        // plotVolume copies the window and ticks from the price x axis.
        axisVolume = new AxisXY(chart);
        axisVolumeX = axisVolume.getAxisX();
        axisVolumeY = axisVolume.getAxisY();
        TransformDate transformVolume = new TransformDate();
        axisVolumeX.setCustomTransform(transformVolume);
        axisVolumeX.setTransform(AxisXY.TRANSFORM_CUSTOM);
        axisVolumeX.setSkipWeekends(true);
        axisVolumeX.getAxisLabel().setTextFormat(DATE_FORMAT);
        axisVolumeX.getAxisLabel().setPaint(false);
        axisVolumeY.getAxisLabel().setTextFormat("0");
        axisVolumeX.getMinorTick().setPaint(false);
        axisVolumeX.getGrid().setPaint(true);
        axisVolumeX.getGrid().setLineColor(GRID_COLOR);
        axisVolumeY.getGrid().setLineColor(GRID_COLOR);
        axisVolumeX.setAutoscaleInput(Axis1D.AUTOSCALE_OFF);

        this.model = model;
        model.addListener(new Model.Listener() {
            public void modelChanged() {
                drawChart();
            }
        });
    }


    /**
     * Sets the chart title from the model's plot style, ticker and period,
     * e.g. "Candlesticks of XYZ (Weekly)". An unrecognised style yields an
     * empty title; an unrecognised period yields no period suffix.
     */
    private void setChartTitle() {
        String sPeriod = "";
        switch (model.getPeriod()) {
            case GregorianCalendar.MONTH:
                sPeriod = "(Monthly)";
                break;
            case GregorianCalendar.WEEK_OF_YEAR:
                sPeriod = "(Weekly)";
                break;
            case GregorianCalendar.DAY_OF_WEEK:
                sPeriod = "(Daily)";
                break;
        }

        String sFormat = "";
        switch (model.getStyle()) {
            case Model.STYLE_CLOSE:
                sFormat = "Closing Prices for {0} {1}";
                break;
            case Model.STYLE_CANDLESTICK:
                sFormat = "Candlesticks of {0} {1}";
                break;
            case Model.STYLE_HIGH_LOW_CLOSE:
                sFormat = "High-Low-Close-Open of {0} {1}";
                break;
        }

        MessageFormat format = new MessageFormat(sFormat);
        Object args[] = new Object[]{model.getTicker(), sPeriod};
        getChart().getChartTitle().setTitle(format.format(args));
    }


    /**
     * Rebuilds the whole chart from the model: title, price series, volume
     * series, optional overlay and legend, followed by a repaint.
     * <p>
     * This is the model-listener entry point, so any exception is caught and
     * printed rather than propagated back into the model's notification.
     */
    private void drawChart() {
        try {
            Database database = model.getDatabase();
            String ticker = model.getTicker();
            int style = model.getStyle();
            setChartTitle();
            Database.Series series = database.getSeries(ticker, model.getPeriod(), model.getInterval());
            init(series);
            switch (style) {
                case Model.STYLE_CLOSE:
                    plotClose(ticker, series);
                    break;
                case Model.STYLE_CANDLESTICK:
                    plotCandlestick(ticker, series);
                    break;
                case Model.STYLE_HIGH_LOW_CLOSE:
                    plotHighLowClose(ticker, series);
                    break;
            }

            int over = model.getOverlay();
            // An empty series has nothing to overlay.
            if (series.date.length == 0)  over = Model.OVERLAY_NONE;
            double[] overlay = null;
            switch (over) {
                case Model.OVERLAY_SAR:
                    overlay = computeSAR(series);
                    titleOver = "Parabolic SAR";
                    break;
                case Model.OVERLAY_MA20:
                    overlay = computeMA20(series);
                    titleOver = "20 pt Moving Avg";
                    break;
                case Model.OVERLAY_KAL:
                    overlay = computeKalman(series);
                    titleOver = "Kalman Filter";
                    break;
            }

            // The legend only has content when an overlay is plotted.
            Chart chart = getChart();
            if (overlay != null) {
                chart.getLegend().setPaint(true);
                plotOverlay(ticker, series, overlay);
            } else {
                chart.getLegend().setPaint(false);
            }
            chart.getLegend().setViewport(0.8,0.9,0.04,0.09);
            repaint();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Computes a parabolic stop-and-reverse (SAR) series from the bars.
     * <p>
     * The initial trend is long if the second close is at least the first.
     * The trend reverses when the close crosses the previous SAR. Within a
     * trend, the SAR moves toward the extreme point by an acceleration factor
     * that grows by 0.02, capped at 0.20, each time a new extreme is made.
     * The SAR is kept outside the range of the current and previous bar.
     *
     * @param series price bars; only close, high and low are read
     * @return one SAR value per bar (empty for an empty series). A one-bar
     *         series yields just its close.
     */
    private double[] computeSAR(Database.Series series) {
        // The factor starts at 0, not 0.02, because it is incremented before
        // use whenever a new extreme point is made. After a reversal the
        // extreme point is reset to the series-wide high/low, so the next bar
        // almost always sets a new extreme and the factor becomes 0.02.
        final double a0 = 0.00;
        final double delta_a = 0.02;
        final double max_a = 0.20;

        int n = series.close.length;
        double[] s = new double[n];
        if (n == 0) {
            return s;
        }

        // Guard n < 2: close[1] does not exist for a single bar.
        boolean lon = (n < 2) || (series.close[1] >= series.close[0]);  // long (true) or short (false)
        double ep = lon ? series.high[0] : series.low[0];               // extreme point
        double a = a0;                                                  // acceleration factor

        s[0] = series.close[0];
        // NOTE(review): these are running extremes over the whole series so
        // far, not the extreme of the trend just ended (Wilder's definition).
        double maxhi = series.high[0];
        double minlo = series.low[0];
        for (int i = 1; i < n; i++) {
            if (series.high[i] > maxhi) {
                maxhi = series.high[i];
            }
            if (series.low[i] < minlo) {
                minlo = series.low[i];
            }

            // test for reversal
            if (lon && (series.close[i] < s[i-1])) {
                lon = false;    // we're now short
                a = a0;         // reset acceleration factor
                ep = maxhi;     // new extreme point
                s[i] = ep;
                continue;
            }
            if (!lon && (series.close[i] > s[i-1])) {
                lon = true;     // we're now long
                a = a0;         // reset acceleration factor
                ep = minlo;     // new extreme point
                s[i] = ep;
                continue;
            }

            // a new extreme point increases the acceleration
            if (lon && (series.high[i] > ep)) {
                a = Math.min(a + delta_a, max_a);
                ep = series.high[i];
            }
            if (!lon && (series.low[i] < ep)) {
                a = Math.min(a + delta_a, max_a);
                ep = series.low[i];
            }

            // compute new SAR
            s[i] = s[i-1] + a*(ep-s[i-1]);

            // SAR cannot enter the trading range of the last 2 periods
            if (lon) {
                s[i] = Math.min(s[i], Math.min(series.low[i], series.low[i-1]));
            } else {
                s[i] = Math.max(s[i], Math.max(series.high[i], series.high[i-1]));
            }
        }
        return s;
    }

    /**
     * Computes a trailing 20-point simple moving average of the closes.
     * The first 19 values average only the points seen so far.
     *
     * @param series price bars; only close is read
     * @return one average per close (empty for an empty series)
     */
    private double[] computeMA20(Database.Series series) {
        final int period = 20;
        int num = 0;
        double sum = 0.0;
        double[] m = new double[series.close.length];
        for (int i=0; i<series.close.length; i++) {
            sum += series.close[i];
            if (i < period) {
                num++;
                m[i] = sum/num;
            } else {
                // slide the window: drop the point that is now 20 behind
                sum -= series.close[i-period];
                m[i] = sum/period;
            }
        }
        return m;
    }

    /**
     * Smooths the closes with a scalar Kalman filter.
     * <p>
     * The model has unit transition and observation matrices, observation
     * variance 1 and state noise 0.1. It starts from state 4 with variance 16.
     * Each close is first applied as a measurement update. A time update with
     * the transition matrix follows, and the resulting state is recorded;
     * presumably this is the one-step-ahead prediction (see the IMSL
     * KalmanFilter documentation).
     *
     * @param series price bars; only close is read
     * @return one filtered value per close (empty for an empty series)
     */
    private double[] computeKalman(Database.Series series) {
        int nobs = series.close.length;
        double[] ff = new double[nobs];
        int rank = 0;
        double logDeterminant = 0.0;
        double ss = 0.0;
        double[] b = {4};
        double[] covb = {16};
        double[][] q = {{0.1}};
        double[][] r = {{1}};
        double[][] t = {{1}};
        double[][] z = {{1}};

        for (int i = 0; i < nobs; i++) {
            // Measurement update with the i-th close.
            double yi[] = {series.close[i]};
            com.imsl.stat.KalmanFilter kalman =
                    new com.imsl.stat.KalmanFilter(b, covb, rank, ss, logDeterminant);
            kalman.update(yi, z, r);
            kalman.filter();
            b = kalman.getStateVector();
            covb = kalman.getCovB();
            rank = kalman.getRank();
            ss = kalman.getSumOfSquares();
            logDeterminant = kalman.getLogDeterminant();

            // Time update through the transition matrix with state noise q.
            kalman = new com.imsl.stat.KalmanFilter(b, covb, rank, ss, logDeterminant);
            kalman.setTransitionMatrix(t);
            kalman.setQ(q);
            kalman.filter();
            b = kalman.getStateVector();
            covb = kalman.getCovB();
            rank = kalman.getRank();
            ss = kalman.getSumOfSquares();
            logDeterminant = kalman.getLogDeterminant();
            ff[i] = b[0];
        }
        return ff;
    }

    /**
     * Adds the overlay as small black diamond markers on the price axis,
     * titled with {@link #titleOver} for the legend. The caller repaints.
     */
    private void plotOverlay(String ticker, Database.Series series, double[] overlay) {
        dataOver = new Data(axis, series.date, overlay);
        dataOver.setDataType(Data.DATA_TYPE_MARKER);
        dataOver.setMarkerType(Data.MARKER_TYPE_FILLED_DIAMOND);
        dataOver.setMarkerColor("black");
        dataOver.setMarkerSize(0.4);
        dataOver.setTitle(titleOver);
    }

    /** Plots the closing prices as a blue line, then the volume. */
    private void plotClose(String ticker, Database.Series series) {
        data = new Data(axis, series.date, series.close);
        data.setLineColor("blue");
        plotVolume(series.date, series.volume);
    }


    /**
     * Plots green/red candlesticks, then the volume. Marker size shrinks
     * as the number of bars grows beyond 30.
     */
    private void plotCandlestick(String ticker, Database.Series series) {
        Candlestick candles = new Candlestick(axis, series.date, series.high, series.low, series.close, series.open);
        candles.setMarkerSize(Math.min(1,30./series.date.length));
        candles.getUp().setFillColor("green");
        candles.getDown().setFillColor("red");
        data = candles;
        plotVolume(series.date, series.volume);
    }


    /**
     * Plots blue high-low-close bars, then the volume. Marker size shrinks
     * as the number of bars grows beyond 30.
     */
    private void plotHighLowClose(String ticker, Database.Series series) {
        data = new HighLowClose(axis, series.date, series.high, series.low, series.close, series.open);
        data.setMarkerSize(Math.min(1.0,30./series.date.length));
        data.setMarkerColor("blue");
        plotVolume(series.date, series.volume);
    }


    /**
     * Lays out the price/volume viewports and plots the volume series in the
     * style chosen by the model.
     * <p>
     * With no volume style, the price axis fills the frame and the volume
     * axis is hidden. Otherwise volume is scaled to millions (peak above
     * 3e6) or thousands. The volume x axis copies the price x window and
     * ticks, and the volume is drawn as bars, a line or a filled area.
     *
     * @param date   x values shared with the price series
     * @param volume raw volumes; not modified
     */
    private void plotVolume(double date[], double volume[]) {
        int volumeStyle = model.getVolumeStyle();
        if (volumeStyle == Model.VOLUME_NONE) {
            axis.setViewport(0.1, 0.95, 0.1, 0.95);
            axisVolume.setPaint(false);
            return;
        }
        axis.setViewport(0.1, 0.95, 0.1, 0.65);
        axisVolume.setViewport(0.1, 0.95, 0.725, 0.95);
        axisVolume.setPaint(true);

        double max = 0.0;
        for (int k = 0;  k < volume.length;  k++) max = Math.max(max,volume[k]);
        double scale;
        String label;
        if (max > 3.0e6) {
            scale = 1.0e-6;
            label = "Millions";
        } else {
            scale = 1.0e-3;
            label = "Thousands";
        }
        double v[] = new double[volume.length];
        for (int k = 0;  k < v.length;  k++) v[k] = volume[k]*scale;

        axisVolumeY.getAxisTitle().setTitle(label);

        // Copy scaling parameters from the price chart to the volume chart
        axis.setupMapping();
        axisVolumeX.setWindow(axisX.getWindow());
        axisVolumeX.setTicks(axisX.getTicks());
        axisVolume.setupMapping();

        if (v.length == 0) {
            return;     // axes are laid out, but there is nothing to draw
        }
        if (volumeStyle == Model.VOLUME_BARS) {
            dataVolume = new Bar(axisVolume, date, v);
            dataVolume.setBarType(Bar.BAR_TYPE_VERTICAL);
            double win[] = axisVolumeX.getWindow();
            double width = 0.25*(win[1]-win[0])/Math.max(10,v.length);
            // With many bars, draw them zero-width.
            if (v.length > 100)  width = 0;
            dataVolume.setBarWidth(width);
            dataVolume.setFillColor("blue");
            dataVolume.setFillOutlineType(Data.FILL_TYPE_NONE);
        } else if (volumeStyle == Model.VOLUME_LINE) {
            dataVolume = new Data(axisVolume, date, v);
            dataVolume.setDataType(Data.DATA_TYPE_LINE);
            dataVolume.setLineColor("blue");
        } else if (volumeStyle == Model.VOLUME_AREA) {
            dataVolume = new Data(axisVolume, date, v);
            dataVolume.setDataType(Data.DATA_TYPE_FILL);
            dataVolume.setFillColor("blue");
        }
    }


    /**
     * Removes every series left from the previous draw and configures the
     * price y axis for the model's scaling.
     * <p>
     * An empty series gets a fixed [0,1] window. Log scaling uses a window
     * rounded out to powers of ten with five linear ticks. Linear scaling
     * autoscales from the data.
     *
     * @param series the bars about to be plotted
     */
    private void init(Database.Series series) {
        // Detach stale nodes so they are never removed twice and never
        // linger when the new style plots nothing.
        if (data != null) {
            data.remove();
            data = null;
        }
        if (dataVolume != null) {
            dataVolume.remove();
            dataVolume = null;
        }
        if (dataOver != null) {
            dataOver.remove();
            dataOver = null;
        }

        int scaling = model.getScaling();
        int transform = Data.TRANSFORM_LINEAR;
        if (scaling == Model.SCALING_LOG)  transform = Data.TRANSFORM_LOG;
        axisY.setTransform(transform);
        if (series.close.length == 0) {
            axisY.setWindow(0, 1);
            axisY.setAutoscaleInput(AxisXY.AUTOSCALE_OFF);
        } else if (scaling == Model.SCALING_LOG) {
            double min = series.low[0];
            double max = series.high[0];
            for (int k = 1;  k < series.low.length;  k++) {
                min = Math.min(min,series.low[k]);
                max = Math.max(max,series.high[k]);
            }
            // NOTE(review): assumes prices are positive; log10 of a
            // non-positive low is not finite.
            int logMin = (int)Math.floor(Sfun.log10(min));
            int logMax = (int)Math.ceil(Sfun.log10(max));
            if (logMax <= logMin) {
                // e.g. every price equals the same power of ten: widen by one
                // decade so the window is not empty.
                logMax = logMin + 1;
            }
            min = Math.pow(10.0, logMin);
            max = Math.pow(10.0, logMax);

            // Five evenly spaced ticks: max/5, 2*max/5, ..., max.
            double ticks[] = new double[5];
            double step = max/ticks.length;
            for (int k = 0;  k < ticks.length;  k++) {
                ticks[k] = (k+1)*step;
            }
            axisY.setTicks(ticks);
            axisY.setWindow(min,max);
            axisY.setAutoscaleInput(AxisXY.AUTOSCALE_OFF);
        } else {
            // Clear any ticks left over from a previous log-scaled draw.
            axisY.setAttribute("Ticks", null);
            axisY.setAutoscaleInput(AxisXY.AUTOSCALE_DATA);
            axisY.setAutoscaleOutput(AxisXY.AUTOSCALE_WINDOW | AxisXY.AUTOSCALE_NUMBER);
        }
    }
}