import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.Math;
import java.util.Locale;
import java.util.Random;
/**
 * Two-class neural-network demo: a 2-input, 8-hidden, 2-output sigmoid network trained by
 * per-sample back-propagation with momentum to decide whether a point drawn from
 * [-2, 2) x [-2, 2) lies inside the unit circle centred at the origin (class A) or outside it
 * (class B).
 *
 * <p>{@link #main} writes {@code error.csv}, {@code training_set.csv}, {@code classA_set.csv}
 * and {@code classB_set.csv} to the working directory and prints the test misclassification
 * rate.
 */
public class TwoClass
{
    //Network architecture
    static final int maxInputNodes = 2;
    static final int maxOutputNodes = 2;
    static final int maxHiddenNodes = 8;
    //Training / test set sizes
    static final int maxPointsForClassA = 1000;
    static final int maxPointsForClassB = 1000;
    static final int maxTrainingSetPoints = maxPointsForClassA + maxPointsForClassB;
    static final int maxEpochIndex = 1000;
    static final int maxTestPoints = 100;
    //Class labels returned by feedForward
    static final int classA = 1;
    static final int classB = 2;
    //Learning parameters
    static final double stepSize = 0.3;
    static final double momentum = 0.7;

    static double error; //Sum squared error accumulated over the current epoch

    static double[][] Wih;    //Connection strength matrix between input and hidden neurons
    static double[][] Who;    //Connection strength matrix between hidden and output neurons
    static double[][] OldWih; //Wih weights from the previous iteration (momentum history)
    static double[][] OldWho; //Who weights from the previous iteration (momentum history)
    static double[][] dWih;   //Gradient step for the input-to-hidden weights
    static double[][] dWho;   //Gradient step for the hidden-to-output weights
    static double[][] ci;     //Training set - coordinates of the input points
    static double[][] at;     //Training set - targeted activation values of the output nodes
    static double[] ai;       //Activation values of the input nodes
    static double[] ah;       //Activation values of the hidden nodes
    static double[] ao;       //Activation values of the output nodes
    static double[] bh;       //Bias values of the hidden nodes
    static double[] bo;       //Bias values of the output nodes
    static double[] ea;       //Error between the targeted and actual output activations
    static double[] dout;     //Error terms (deltas) of the output nodes, used to update Who
    static double[] dh;       //Error terms (deltas) of the hidden nodes, used to update Wih
    //Output files; assigned while main runs and closed when it returns
    static BufferedWriter Error_File, Training_Set_File, ClassA_Set_File, ClassB_Set_File;

    /**
     * Trains the network on a freshly generated training set, then evaluates it on
     * {@code maxTestPoints} random points and prints the misclassification percentage.
     *
     * @throws UncheckedIOException if any of the CSV output files cannot be written
     */
    public static void main(String[] args)
    {
        Random rand = new Random();

        //---------------------------------------------------------------------
        //PART1 - training the network
        //---------------------------------------------------------------------
        initializeNetwork(rand);
        generateTrainingSet(rand);

        try (BufferedWriter errorFile = new BufferedWriter(new FileWriter("error.csv"));
             BufferedWriter trainingSetFile = new BufferedWriter(new FileWriter("training_set.csv"));
             BufferedWriter classAFile = new BufferedWriter(new FileWriter("classA_set.csv"));
             BufferedWriter classBFile = new BufferedWriter(new FileWriter("classB_set.csv")))
        {
            Error_File = errorFile;
            Training_Set_File = trainingSetFile;
            ClassA_Set_File = classAFile;
            ClassB_Set_File = classBFile;

            writeTrainingSet(trainingSetFile);
            train(errorFile);

            //-----------------------------------------------------------------
            //PART2 - testing the trained network
            //-----------------------------------------------------------------
            int incorrectClassification = testNetwork(rand, classAFile, classBFile);

            System.out.println("End of Simulation");
            System.out.format("%2.2f percent of the test points were incorrectly classified%n",
                    100.0 * incorrectClassification / maxTestPoints);
        }
        catch (IOException e)
        {
            throw new UncheckedIOException("Failed to write simulation output files", e);
        }
    }

    /** Allocates all network arrays and initializes weights and biases uniformly in [-0.1, 0.1). */
    private static void initializeNetwork(Random rand)
    {
        Wih = new double[maxInputNodes][maxHiddenNodes];
        Who = new double[maxHiddenNodes][maxOutputNodes];
        OldWih = new double[maxInputNodes][maxHiddenNodes];
        OldWho = new double[maxHiddenNodes][maxOutputNodes];
        dWih = new double[maxInputNodes][maxHiddenNodes];
        dWho = new double[maxHiddenNodes][maxOutputNodes];
        ai = new double[maxInputNodes];
        ah = new double[maxHiddenNodes];
        bh = new double[maxHiddenNodes];
        ao = new double[maxOutputNodes];
        bo = new double[maxOutputNodes];
        ea = new double[maxOutputNodes];
        dout = new double[maxOutputNodes];
        dh = new double[maxHiddenNodes];

        for (int i = 0; i < maxInputNodes; i++)
            for (int h = 0; h < maxHiddenNodes; h++)
                Wih[i][h] = smallRandomValue(rand);
        for (int h = 0; h < maxHiddenNodes; h++)
            for (int o = 0; o < maxOutputNodes; o++)
                Who[h][o] = smallRandomValue(rand);

        // Seed the momentum history with the initial weights so the first update carries no
        // momentum term; a zero history would add momentum * W on the very first step.
        for (int i = 0; i < maxInputNodes; i++)
            System.arraycopy(Wih[i], 0, OldWih[i], 0, maxHiddenNodes);
        for (int h = 0; h < maxHiddenNodes; h++)
            System.arraycopy(Who[h], 0, OldWho[h], 0, maxOutputNodes);

        //Initialize biases of the hidden and output nodes
        for (int h = 0; h < maxHiddenNodes; h++)
            bh[h] = smallRandomValue(rand);
        for (int o = 0; o < maxOutputNodes; o++)
            bo[o] = smallRandomValue(rand);
    }

    /** Returns a value drawn uniformly from [-0.1, 0.1). */
    private static double smallRandomValue(Random rand)
    {
        return (2 * (rand.nextDouble() / 10.0)) - 0.1;
    }

    /**
     * Fills {@code ci}/{@code at} with {@code maxPointsForClassA} points inside the unit circle
     * (target (1, 0)) and {@code maxPointsForClassB} points outside it (target (0, 1)), drawn
     * uniformly from [-2, 2) x [-2, 2). A point whose class quota is already full is discarded
     * and redrawn, which guarantees the requested count for each class.
     */
    private static void generateTrainingSet(Random rand)
    {
        ci = new double[maxTrainingSetPoints][maxInputNodes];
        at = new double[maxTrainingSetPoints][maxOutputNodes];

        int classACount = 0;
        int classBCount = 0;
        while (classACount < maxPointsForClassA || classBCount < maxPointsForClassB)
        {
            double[] point = { 4 * rand.nextDouble() - 2.0, 4 * rand.nextDouble() - 2.0 };
            boolean inside = isCircle(point);
            if (inside ? classACount >= maxPointsForClassA : classBCount >= maxPointsForClassB)
                continue; //This class is already full; draw another point

            int index = classACount + classBCount;
            ci[index][0] = point[0];
            ci[index][1] = point[1];
            at[index][0] = inside ? 1.0 : 0.0;
            at[index][1] = inside ? 0.0 : 1.0;
            if (inside)
                classACount++;
            else
                classBCount++;
        }
    }

    /** Writes every training point as an "x, y" CSV line. */
    private static void writeTrainingSet(BufferedWriter out) throws IOException
    {
        for (int i = 0; i < maxTrainingSetPoints; i++)
            out.write(String.format(Locale.ROOT, "%2.4f, %2.4f\n", ci[i][0], ci[i][1]));
    }

    /**
     * Runs {@code maxEpochIndex} epochs of per-sample back-propagation over the training set,
     * writing "epoch, sum squared error" after each epoch.
     */
    private static void train(BufferedWriter errorOut) throws IOException
    {
        for (int epochIndex = 0; epochIndex < maxEpochIndex; epochIndex++)
        {
            error = 0.0;
            for (int trainingSetIndex = 0; trainingSetIndex < maxTrainingSetPoints; trainingSetIndex++)
                error += trainOnPoint(trainingSetIndex);
            errorOut.write(String.format(Locale.ROOT, "%d, %2.4f\n", epochIndex, error));
        }
    }

    /**
     * Performs one forward pass and one back-propagation update for training point
     * {@code index}.
     *
     * @return the squared output error summed over the output nodes, measured before the update
     */
    private static double trainOnPoint(int index)
    {
        //Feed forward
        feedForward(ci[index]);

        //Compute error between targeted and measured output activation values
        double squaredError = 0.0;
        for (int o = 0; o < maxOutputNodes; o++)
        {
            ea[o] = at[index][o] - ao[o];
            squaredError += ea[o] * ea[o];
        }

        //Output error terms: (target - output) * sigmoid'(net), where sigmoid' = a * (1 - a)
        for (int o = 0; o < maxOutputNodes; o++)
            dout[o] = ea[o] * ao[o] * (1 - ao[o]);

        //Hidden error terms, back-propagated through Who before Who is updated
        for (int h = 0; h < maxHiddenNodes; h++)
        {
            double bpInput = 0.0;
            for (int o = 0; o < maxOutputNodes; o++)
                bpInput += Who[h][o] * dout[o];
            dh[h] = bpInput * ah[h] * (1 - ah[h]);
        }

        //Weight updates with classical momentum:
        //W(t+1) = W(t) + stepSize * gradient + momentum * (W(t) - W(t-1))
        for (int h = 0; h < maxHiddenNodes; h++)
        {
            for (int o = 0; o < maxOutputNodes; o++)
            {
                dWho[h][o] = ah[h] * dout[o] * stepSize;
                double current = Who[h][o];
                Who[h][o] = current + dWho[h][o] + momentum * (current - OldWho[h][o]);
                OldWho[h][o] = current;
            }
        }
        for (int i = 0; i < maxInputNodes; i++)
        {
            for (int h = 0; h < maxHiddenNodes; h++)
            {
                dWih[i][h] = ai[i] * dh[h] * stepSize;
                double current = Wih[i][h];
                Wih[i][h] = current + dWih[i][h] + momentum * (current - OldWih[i][h]);
                OldWih[i][h] = current;
            }
        }

        //Bias updates (no momentum)
        for (int h = 0; h < maxHiddenNodes; h++)
            bh[h] += dh[h] * stepSize;
        for (int o = 0; o < maxOutputNodes; o++)
            bo[o] += dout[o] * stepSize;

        return squaredError;
    }

    /**
     * Classifies {@code maxTestPoints} random points from [-2, 2) x [-2, 2) with the trained
     * network, writing each correctly classified point to the file for its class.
     *
     * @return the number of misclassified test points
     */
    private static int testNetwork(Random rand, BufferedWriter classAOut, BufferedWriter classBOut)
            throws IOException
    {
        int incorrect = 0;
        for (int i = 0; i < maxTestPoints; i++)
        {
            double[] inputPoint = { 4 * rand.nextDouble() - 2.0, 4 * rand.nextDouble() - 2.0 };
            boolean inside = isCircle(inputPoint);
            int expected = inside ? classA : classB;
            if (feedForward(inputPoint) != expected)
            {
                incorrect++;
                continue;
            }
            BufferedWriter out = inside ? classAOut : classBOut;
            out.write(String.format(Locale.ROOT, "%2.4f, %2.4f\n", inputPoint[0], inputPoint[1]));
        }
        return incorrect;
    }

    /**
     * Propagates {@code point} through the network, storing the activations in {@code ai},
     * {@code ah} and {@code ao}. The network arrays must already be allocated.
     *
     * @param point input coordinates; must have at least {@code maxInputNodes} elements
     * @return {@code classA} if output node 0 is more active than output node 1, otherwise
     *         {@code classB}
     */
    static int feedForward(double[] point)
    {
        for (int i = 0; i < maxInputNodes; i++)
            ai[i] = point[i];

        //Pass the input vector forward
        for (int h = 0; h < maxHiddenNodes; h++)
        {
            double netInput = bh[h];
            for (int i = 0; i < maxInputNodes; i++)
                netInput += Wih[i][h] * ai[i];
            ah[h] = sigmoid(netInput);
        }

        //Pass the hidden vector forward
        for (int o = 0; o < maxOutputNodes; o++)
        {
            double netInput = bo[o];
            for (int h = 0; h < maxHiddenNodes; h++)
                netInput += Who[h][o] * ah[h];
            ao[o] = sigmoid(netInput);
        }

        // Targets are symmetric one-hot vectors, so pick the more active output. A fixed
        // 0.99 / 0.01 confidence threshold would label every uncertain point as class B.
        return ao[0] > ao[1] ? classA : classB;
    }

    /** Logistic activation 1 / (1 + e^-x). */
    private static double sigmoid(double x)
    {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Returns true when {@code point} lies strictly inside the unit circle centred at the origin. */
    static boolean isCircle(double[] point)
    {
        return point[0] * point[0] + point[1] * point[1] < 1;
    }
}