Example: Kalman Filter

KalmanFilter is used to compute the filtered estimates and one-step-ahead estimates for a scalar problem discussed by Harvey (1981, pages 116-117). The observation equation and state equation are given by

y_k = b_k + e_k

b_{k+1} = b_k + w_{k+1}

k = 1, 2, 3, 4

where the e_k are independently and identically distributed normal with mean 0 and variance \sigma ^2, the w_k are independently and identically distributed normal with mean 0 and variance 4 \sigma ^2, and b_1 is distributed normal with mean 4 and variance 16 \sigma ^2. The variances are supplied to KalmanFilter in units of \sigma ^2, which is why the code uses 16 for the initial covariance of b, 4 for Q, and 1 for R.

Two KalmanFilter objects are needed for each time point: one to compute the filtered estimate and one to compute the one-step-ahead estimate. The first object calls update but not setTransitionMatrix or setQ, so the prediction equations are skipped. The second object calls setTransitionMatrix and setQ but not update, so the update equations are skipped.
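The update/prediction split can be sketched in plain Java without the IMSL classes. This is a minimal illustration, not the KalmanFilter implementation: it hard-codes the textbook scalar recursions for z = t = 1, with r = 1 and q = 4 in units of \sigma ^2, and the class ScalarKalmanSketch and its methods are hypothetical names. Run on the example data, its rows agree with the Output table to three decimals (it prints fixed three-decimal values, e.g. 17.000 rather than 17).

```java
import java.util.Locale;

// Hypothetical sketch of the scalar recursions; not the IMSL KalmanFilter API.
public class ScalarKalmanSketch {
    // Running quantities; all variances are in units of sigma^2.
    double b, covb, ss, lnDet, v, covv;
    int rank;

    ScalarKalmanSketch(double b0, double covb0) {
        b = b0;
        covb = covb0;
    }

    /** Update equations for y = z*b + e with z = 1 and Var(e) = r. */
    void update(double y, double r) {
        v = y - b;                  // one-step-ahead prediction error
        covv = covb + r;            // cov(v) = z*covb*z' + r
        double gain = covb / covv;  // Kalman gain
        b += gain * v;              // filtered estimate
        covb -= gain * covb;        // filtered variance
        rank++;                     // cov(v) > 0 here, so each observation adds 1
        ss += v * v / covv;         // running sum of v^2 / cov(v)
        lnDet += Math.log(covv);    // running sum of ln cov(v)
    }

    /** Prediction equations for b_{k+1} = t*b_k + w with t = 1 and Var(w) = q. */
    void predict(double q) {
        covb += q;                  // t*covb*t' + q; b is unchanged because t = 1
    }

    String row(int k, int j) {
        return String.format(Locale.US, "%d/%d\t%.3f\t%.3f\t%d\t%.3f\t%.3f\t%.3f\t%.3f",
                k, j, b, covb, rank, ss, lnDet, v, covv);
    }

    public static void main(String[] args) {
        double[] ydata = {4.4, 4.0, 3.5, 4.6};
        // b_1 has mean 4 and variance 16 (units of sigma^2)
        ScalarKalmanSketch kf = new ScalarKalmanSketch(4.0, 16.0);
        System.out.println("k/j\tb\tcov(b)\trank\tss\tln(det)\tv\tcov(v)");
        for (int i = 0; i < ydata.length; i++) {
            kf.update(ydata[i], 1.0);   // r = 1
            System.out.println(kf.row(i, i));
            kf.predict(4.0);            // q = 4
            System.out.println(kf.row(i + 1, i));
        }
    }
}
```

Keeping update and predict as separate methods mirrors the two-object pattern in the example: each step changes only the quantities that its own equations touch, so the prediction step leaves b, rank, ss, ln(det), v, and cov(v) as they were.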

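This example also computes the one-step-ahead prediction errors v_k. Harvey (1981, page 117) contains a misprint for v_4, giving it as 1.197; KalmanFilter computes the correct value, v_4 = 1.003. In the output, a row labeled k/j gives the estimate of the state at time point k using the observations through time point j, with time points numbered from 0 (Harvey's k = 1 corresponds to 0 here). Harvey's v_4 therefore appears in the rows labeled 3/3 and 4/3.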
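As a hand check of the corrected value, v_4 is the fourth observation minus its one-step-ahead prediction. Writing b_{4|3} and P_{4|3} for the one-step-ahead estimate of b_4 and its variance (in units of \sigma ^2), taken from the row labeled 3/2 of the output (the code numbers time points from 0), and using R = 1:

```latex
\begin{aligned}
v_4 &= y_4 - b_{4|3} = 4.6 - 3.597 = 1.003,\\
\operatorname{cov}(v_4) &= P_{4|3} + 1 = 4.829 + 1 = 5.829 .
\end{aligned}
```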

import com.imsl.stat.*;
import java.text.MessageFormat;

public class KalmanFilterEx1 {
    static private final MessageFormat mf =
        new MessageFormat("{0}/{1}\t{2}\t{3}\t{4}\t{5}\t{6}\t{7}\t{8}");
    
    public static void main(String[] args) {
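        int nobs = 4;
        // Running quantities carried from one KalmanFilter object to the next
        int rank = 0;
        double logDeterminant = 0.0;
        double ss = 0.0;
        // Initial state: b_1 has mean 4 and variance 16 (in units of sigma^2)
        double[] b = {4};
        double[] covb = {16};
        // State noise variance (Q), observation noise variance (R),
        // transition matrix (T), and observation matrix (Z), all 1 by 1
        double[][] q = {{4}};
        double[][] r = {{1}};
        double[][] t = {{1}};
        double[][] z = {{1}};
        double[] ydata = {4.4, 4.0, 3.5, 4.6};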
        
        Object[] argFormat =
            {"k", "j", "b", "cov(b)", "rank", "ss", "ln(det)", "v", "cov(v)"};
        System.out.println(mf.format(argFormat));
        
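        for (int i = 0; i < nobs; i++) {
            double[] y = {ydata[i]};

            // Filtered estimate: only update() is called, so no transition
            // matrix or Q is set and the prediction equations are skipped.
            KalmanFilter kalman =
                new KalmanFilter(b, covb, rank, ss, logDeterminant);
            kalman.update(y, z, r);
            kalman.filter();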
            b = kalman.getStateVector();
            covb = kalman.getCovB();
            rank = kalman.getRank();
            ss = kalman.getSumOfSquares();
            logDeterminant = kalman.getLogDeterminant();
            double[] v = kalman.getPredictionError();
            double[][] covv = kalman.getCovV();
            // Autoboxing replaces the deprecated Integer/Double constructors
            argFormat[0] = i;
            argFormat[1] = i;
            argFormat[2] = b[0];
            argFormat[3] = covb[0];
            argFormat[4] = rank;
            argFormat[5] = ss;
            argFormat[6] = logDeterminant;
            argFormat[7] = v[0];
            argFormat[8] = covv[0][0];
            System.out.println(mf.format(argFormat));
            
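            // One-step-ahead estimate: update() is not called, so only the
            // prediction equations are applied.
            kalman = new KalmanFilter(b, covb, rank, ss, logDeterminant);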
            kalman.setTransitionMatrix(t);
            kalman.setQ(q);
            kalman.filter();
            b = kalman.getStateVector();
            covb = kalman.getCovB();
            rank = kalman.getRank();
            ss = kalman.getSumOfSquares();
            logDeterminant = kalman.getLogDeterminant();
            // The prediction step does not recompute v and cov(v),
            // so the values from the update step are printed again.
            argFormat[0] = i + 1;
            argFormat[1] = i;
            argFormat[2] = b[0];
            argFormat[3] = covb[0];
            argFormat[4] = rank;
            argFormat[5] = ss;
            argFormat[6] = logDeterminant;
            argFormat[7] = v[0];
            argFormat[8] = covv[0][0];
            System.out.println(mf.format(argFormat));
        }
    }
}

Output

k/j	b	cov(b)	rank	ss	ln(det)	v	cov(v)
0/0	4.376	0.941	1	0.009	2.833	0.4	17
1/0	4.376	4.941	1	0.009	2.833	0.4	17
1/1	4.063	0.832	2	0.033	4.615	-0.376	5.941
2/1	4.063	4.832	2	0.033	4.615	-0.376	5.941
2/2	3.597	0.829	3	0.088	6.378	-0.563	5.832
3/2	3.597	4.829	3	0.088	6.378	-0.563	5.832
3/3	4.428	0.828	4	0.26	8.141	1.003	5.829
4/3	4.428	4.828	4	0.26	8.141	1.003	5.829
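For this scalar problem, the ss and ln(det) columns are consistent with running sums of v^2/cov(v) and ln cov(v) over the observations processed so far, while rank grows by one per observation. The following plain-Java sketch recomputes both columns from the rounded v and cov(v) values copied from the filtered rows of the table; the class LikelihoodCheck and its accumulate method are hypothetical names, not part of the IMSL library. Its printed values agree with the ss and ln(det) columns to three decimals.

```java
import java.util.Locale;

// Hypothetical check, not part of the IMSL API.
public class LikelihoodCheck {

    /** Running sums of v^2 / cov(v) and ln cov(v), one row {ss, lnDet} per observation. */
    static double[][] accumulate(double[] v, double[] covv) {
        double[][] out = new double[v.length][2];
        double ss = 0.0;
        double lnDet = 0.0;
        for (int k = 0; k < v.length; k++) {
            ss += v[k] * v[k] / covv[k];
            lnDet += Math.log(covv[k]);
            out[k][0] = ss;
            out[k][1] = lnDet;
        }
        return out;
    }

    public static void main(String[] args) {
        // v and cov(v) from the rows labeled 0/0, 1/1, 2/2, 3/3 (rounded to 3 decimals)
        double[] v = {0.4, -0.376, -0.563, 1.003};
        double[] covv = {17, 5.941, 5.832, 5.829};
        double[][] running = accumulate(v, covv);
        for (int k = 0; k < v.length; k++) {
            System.out.println(String.format(Locale.US, "%d/%d\tss=%.3f\tln(det)=%.3f",
                    k, k, running[k][0], running[k][1]));
        }
    }
}
```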