JMSL™ Numerical Library 4.0

com.imsl.datamining.neural
Class BinaryClassification

java.lang.Object
  extended by com.imsl.datamining.neural.BinaryClassification
All Implemented Interfaces:
Serializable

public class BinaryClassification
extends Object
implements Serializable

Classifies patterns into two classes.

Uses a FeedForwardNetwork to solve binary classification problems. In these problems, the target output for the network is the probability that the pattern falls into the first of two classes: it is usually 1 for patterns in the first class, C_1, and 0 for patterns in the second class, C_2. The estimated probabilities are then used to assign patterns to one of the two classes. Typical applications include determining whether a credit applicant is a good or bad credit risk, and determining whether a person should or should not receive a particular treatment based on their physical, clinical and laboratory information.

This class specifies that network training minimizes the binary cross-entropy error, and that the network output is the probability that the pattern belongs to the first class, P(C_1). This probability is calculated by applying the logistic activation function to the potential of the single output perceptron. The probability for the second class is P(C_2) = 1 - P(C_1).
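Written out as formulas, with z the potential of the output perceptron, P_i(C_1) the network output for pattern i, and t_i the target output for pattern i (1 for C_1, 0 for C_2), the quantities described above are as follows. The cross-entropy line is the standard textbook definition summed over patterns; whether the library sums or averages over patterns is an assumption here, not something this page states.

```latex
P(C_1) = \frac{1}{1 + e^{-z}}, \qquad P(C_2) = 1 - P(C_1)

E = -\sum_{i} \Bigl[ t_i \ln P_i(C_1) + (1 - t_i) \ln\bigl(1 - P_i(C_1)\bigr) \Bigr]
```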

See Also:
Example 1, Example 2, Serialized Form

Constructor Summary
BinaryClassification(Network network)
          Creates a binary classifier.
 
Method Summary
 double[] computeStatistics(double[][] xData, int[] yData)
          Computes the classification error statistics for the supplied network patterns and their associated classifications.
 QuasiNewtonTrainer.Error getError()
          Returns the error function for use by QuasiNewtonTrainer for training a binary classification network.
 Network getNetwork()
          Returns the network being used for classification.
 int predictedClass(double[] x)
          Calculates the classification probabilities for the input pattern x, and returns either 0 or 1 identifying the class with the highest probability.
 double[] probabilities(double[] x)
          Returns classification probabilities for the input pattern x.
 void train(Trainer trainer, double[][] xData, int[] yData)
          Trains the classification neural network using the supplied trainer and patterns.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

BinaryClassification

public BinaryClassification(Network network)
Creates a binary classifier.

Parameters:
network - is the neural network used for classification. Its output perceptron should use the logistic activation function.
Method Detail

computeStatistics

public double[] computeStatistics(double[][] xData,
                                  int[] yData)
Computes the classification error statistics for the supplied network patterns and their associated classifications.

The first element returned is the binary cross-entropy error; the second is the classification error rate. The classification error rate is calculated by comparing the estimated classification probabilities to the target classifications: if the estimated probability for a pattern's target class is less than 0.5, the pattern is tallied as a classification error.

Parameters:
xData - A double matrix specifying the input patterns, one pattern per row. The number of columns in xData must equal the number of nodes in the input layer.
yData - An int array containing the target classification for each row of xData. These values must be 0 or 1.
Returns:
A two-element double array containing the binary cross-entropy error and the classification error rate.
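The sketch below mirrors the two statistics on precomputed network outputs P(C_1) instead of a Network object. computeStatistics here is a hypothetical stand-alone helper, not the library method. It assumes that label 1 denotes C_1 (target output 1) and label 0 denotes C_2, that the cross-entropy is summed over patterns, and that the error rate is the fraction of misclassified patterns; this page does not state any of these.

```java
public class StatisticsSketch {

    /**
     * Returns {binary cross-entropy, classification error rate} for network
     * outputs pC1[i] = P(C_1) and 0/1 labels.
     * ASSUMPTIONS: label 1 denotes C_1, label 0 denotes C_2, the cross-entropy
     * is summed over patterns, and the error rate is a fraction of patterns.
     */
    static double[] computeStatistics(double[] pC1, int[] labels) {
        if (pC1.length != labels.length) {
            throw new IllegalArgumentException("one label is required per pattern");
        }
        double crossEntropy = 0.0;
        int errors = 0;
        for (int i = 0; i < pC1.length; i++) {
            // Probability the network assigns to the pattern's target class.
            double pTarget = (labels[i] == 1) ? pC1[i] : 1.0 - pC1[i];
            // With 0/1 targets, -(t ln y + (1 - t) ln(1 - y)) reduces to -ln(pTarget).
            crossEntropy -= Math.log(pTarget);
            if (pTarget < 0.5) {
                errors++; // tallied as a classification error
            }
        }
        return new double[] {crossEntropy, (double) errors / pC1.length};
    }

    public static void main(String[] args) {
        double[] pC1 = {0.9, 0.2, 0.6, 0.3};
        int[] labels = {1, 0, 0, 1};
        double[] stats = computeStatistics(pC1, labels);
        System.out.printf("cross-entropy = %.4f, error rate = %.2f%n", stats[0], stats[1]);
    }
}
```

In the sample data, the third and fourth patterns put less than 0.5 on their target class, so half of the patterns are tallied as errors.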

getError

public QuasiNewtonTrainer.Error getError()
Returns the error function for use by QuasiNewtonTrainer for training a binary classification network.

Returns:
an implementation of the binary cross-entropy error function.
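A quasi-Newton trainer needs both the error and its gradient. The sketch below is a stand-alone illustration with hypothetical names; it is not the QuasiNewtonTrainer.Error interface, whose methods this page does not list. It computes the per-pattern binary cross-entropy as a function of the output potential z and target t, and its derivative with respect to z, which for a logistic output simplifies to P(C_1) - t. A central finite difference checks that simplification.

```java
public class CrossEntropyGradientSketch {

    /** Logistic activation of the output perceptron. */
    static double logistic(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Per-pattern binary cross-entropy for output potential z and target t (0 or 1). */
    static double error(double z, double t) {
        double y = logistic(z);
        return -(t * Math.log(y) + (1.0 - t) * Math.log(1.0 - y));
    }

    /** dE/dz: for a logistic output the chain rule collapses to y - t. */
    static double errorGradient(double z, double t) {
        return logistic(z) - t;
    }

    public static void main(String[] args) {
        double z = 0.7, t = 1.0, h = 1e-6;
        // Central difference approximation of dE/dz for comparison.
        double numeric = (error(z + h, t) - error(z - h, t)) / (2 * h);
        System.out.printf("analytic %.8f, numeric %.8f%n", errorGradient(z, t), numeric);
    }
}
```

The chain rule gives dE/dy = -t/y + (1 - t)/(1 - y) and dy/dz = y(1 - y); their product is y - t, which is why this error function pairs naturally with a logistic output perceptron.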

getNetwork

public Network getNetwork()
Returns the network being used for classification.

Returns:
the network set by the constructor.

predictedClass

public int predictedClass(double[] x)
Calculates the classification probabilities for the input pattern x, and returns either 0 or 1 identifying the class with the highest probability.

This method classifies a pattern into one of the two target classes based on its input values. The predicted class is the one with the larger probability, that is, the class whose probability exceeds 0.5.

Parameters:
x - the double array containing the network input pattern to classify. The length of x should equal the number of inputs in the network.
Returns:
The classification predicted by the trained network for x. This will be either 0 or 1.
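The decision rule can be sketched as a function of the two probabilities {P(C_1), P(C_2)}. This page does not say which of the labels 0 and 1 denotes C_1, nor how a tie at exactly 0.5 is resolved, so both are assumptions in this hypothetical helper: label 1 is taken to denote C_1, and a tie falls to label 0.

```java
public class PredictedClassSketch {

    /**
     * Returns the label of the more probable class given {P(C_1), P(C_2)}.
     * ASSUMPTION: label 1 denotes C_1 and label 0 denotes C_2.
     * ASSUMPTION: a tie at exactly 0.5 returns label 0.
     */
    static int predictedClass(double[] probabilities) {
        // Because P(C_2) = 1 - P(C_1), comparing P(C_1) with 0.5 is the same
        // as comparing it with P(C_2).
        return probabilities[0] > 0.5 ? 1 : 0;
    }

    public static void main(String[] args) {
        System.out.println(predictedClass(new double[] {0.8, 0.2})); // prints 1
        System.out.println(predictedClass(new double[] {0.3, 0.7})); // prints 0
    }
}
```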

probabilities

public double[] probabilities(double[] x)
Returns classification probabilities for the input pattern x.

Calculates the two probabilities for the pattern supplied: P(C_1) and P(C_2). The probability that the pattern belongs to the first class, P(C_1), is estimated using the logistic function of the output perceptron's potential. The probability for the second class is calculated as P(C_2) = 1 - P(C_1). The predicted classification is the class with the larger probability, that is, the class whose probability exceeds 0.5.

Parameters:
x - a double array containing the network input pattern to classify. The length of x must equal the number of nodes in the input layer.
Returns:
the probability of x being in class C_1, followed by the probability of x being in class C_2.
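A minimal sketch of this calculation follows. It takes the output perceptron's potential as given rather than computing it from x through the network, and the class and method names are hypothetical.

```java
public class ProbabilitiesSketch {

    /** Logistic activation applied to the output perceptron's potential. */
    static double logistic(double potential) {
        return 1.0 / (1.0 + Math.exp(-potential));
    }

    /** Returns {P(C_1), P(C_2)} for a given output potential. */
    static double[] probabilities(double potential) {
        double p1 = logistic(potential);
        return new double[] {p1, 1.0 - p1};
    }

    public static void main(String[] args) {
        double[] p = probabilities(1.2);
        System.out.printf("P(C_1) = %.4f, P(C_2) = %.4f%n", p[0], p[1]);
    }
}
```

A positive potential gives P(C_1) > 0.5, a negative one gives P(C_1) < 0.5, and a potential of zero splits the probability evenly.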

train

public void train(Trainer trainer,
                  double[][] xData,
                  int[] yData)
Trains the classification neural network using the supplied trainer and patterns.

Parameters:
trainer - A Trainer object, which is used to train the network. If trainer is, or contains, a QuasiNewtonTrainer, that trainer's error function should be set to the one returned by this class's getError method.
xData - A double matrix containing the input training patterns. The number of columns in xData must equal the number of nodes in the input layer. Each row of xData contains a training pattern.
yData - An int array containing the output classification values. These values must be 0 or 1.
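To illustrate what training minimizes, the sketch below fits the simplest possible network, a single logistic perceptron with no hidden layer, by plain batch gradient descent on the summed binary cross-entropy. This is a swapped-in technique chosen for brevity: it is not how QuasiNewtonTrainer or any other JMSL Trainer works, and all names are hypothetical. It assumes each 0/1 value in yData is used directly as the target output t, so label 1 denotes C_1.

```java
import java.util.Arrays;

public class TrainingSketch {

    static double logistic(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Potential of a single perceptron: bias + w . x */
    static double potential(double[] w, double bias, double[] x) {
        double z = bias;
        for (int j = 0; j < x.length; j++) {
            z += w[j] * x[j];
        }
        return z;
    }

    /**
     * Batch gradient descent on the summed binary cross-entropy, starting from
     * zero weights. Returns the weights followed by the bias in the last slot.
     */
    static double[] train(double[][] xData, int[] yData, double learningRate, int iterations) {
        int nInputs = xData[0].length;
        double[] w = new double[nInputs];
        double bias = 0.0;
        for (int it = 0; it < iterations; it++) {
            double[] gradW = new double[nInputs];
            double gradB = 0.0;
            for (int i = 0; i < xData.length; i++) {
                // dE/dz = P(C_1) - t for a logistic output with cross-entropy error.
                double delta = logistic(potential(w, bias, xData[i])) - yData[i];
                for (int j = 0; j < nInputs; j++) {
                    gradW[j] += delta * xData[i][j];
                }
                gradB += delta;
            }
            for (int j = 0; j < nInputs; j++) {
                w[j] -= learningRate * gradW[j];
            }
            bias -= learningRate * gradB;
        }
        double[] params = Arrays.copyOf(w, nInputs + 1);
        params[nInputs] = bias;
        return params;
    }

    public static void main(String[] args) {
        double[][] xData = {{0.0, 0.1}, {0.2, 0.3}, {0.9, 0.8}, {1.0, 0.7}};
        int[] yData = {0, 0, 1, 1};
        double[] params = train(xData, yData, 0.5, 5000);
        double[] w = Arrays.copyOf(params, xData[0].length);
        double bias = params[xData[0].length];
        for (double[] x : xData) {
            System.out.printf("P(C_1) = %.3f%n", logistic(potential(w, bias, x)));
        }
    }
}
```

Plain gradient descent needs a hand-picked step size that is small enough to be stable; a quasi-Newton method instead builds up curvature information from successive gradients, which is one reason a trainer like QuasiNewtonTrainer is paired with this error function.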


Copyright 1970-2006 Visual Numerics, Inc.
Built June 1 2006.