import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.Math;
import java.util.Locale;
import java.util.Random;

 

/**
 * Trains a small feed-forward neural network (2 inputs, 8 sigmoid hidden nodes, 2 sigmoid
 * outputs) with online back-propagation and momentum to decide whether a point drawn from the
 * square [-2, 2) x [-2, 2) lies inside the unit circle centred at the origin (class A) or on/outside
 * it (class B), then reports the misclassification rate on fresh random points.
 *
 * <p>Files written to the working directory:
 * <ul>
 *   <li>{@code error.csv} - epoch index and summed squared output error for that epoch</li>
 *   <li>{@code training_set.csv} - coordinates of every training point</li>
 *   <li>{@code classA_set.csv} / {@code classB_set.csv} - correctly classified test points</li>
 * </ul>
 *
 * <p>All state lives in static fields, so the class is not thread-safe.
 */
public class TwoClass
{
    // Network architecture
    static final int maxInputNodes = 2;
    static final int maxOutputNodes = 2;
    static final int maxHiddenNodes = 8;
    static final int maxPointsForClassA = 1000;
    static final int maxPointsForClassB = 1000;
    static final int maxTrainingSetPoints = maxPointsForClassA + maxPointsForClassB;
    static final int maxEpochIndex = 1000;
    static final int maxTestPoints = 100;

    static final int classA = 1;          // point lies inside the unit circle
    static final int classB = 2;          // point lies on or outside the unit circle

    static final double stepSize = 0.3;   // learning rate
    static final double momentum = 0.7;   // fraction of the previous weight change carried forward

    static double error;                  // summed squared output error over the current epoch

    static double[][] Wih;     // Connection strength matrix between input and hidden neurons
    static double[][] Who;     // Connection strength matrix between hidden and output neurons
    static double[][] OldWih;  // Input->hidden weights as they were before the most recent update
    static double[][] OldWho;  // Hidden->output weights as they were before the most recent update
    static double[][] dWih;    // Most recent change applied to Wih (reused as the momentum term)
    static double[][] dWho;    // Most recent change applied to Who (reused as the momentum term)

    static double[][] ci;      // Training set - coordinates of the input points
    static double[][] at;      // Training set - targeted activation values of the output nodes
    static double[] ai;        // Activation values of the input nodes
    static double[] ah;        // Activation values of the hidden nodes
    static double[] ao;        // Activation values of the output nodes
    static double[] bh;        // Bias values of the hidden nodes
    static double[] bo;        // Bias values of the output nodes

    static double[] ea;        // Error between targeted and actual activation of each output node
    static double[] dout;      // Delta factor for the Who matrix
    static double[] dh;        // Delta factor for the Wih matrix

    static BufferedWriter Error_File, Training_Set_File, ClassA_Set_File, ClassB_Set_File;

    /**
     * Runs the full simulation: builds a training set, trains the network, writes the CSV
     * outputs, and prints the percentage of misclassified test points.
     *
     * @param args ignored
     * @throws UncheckedIOException if any output file cannot be created or written
     */
    public static void main(String[] args)
    {
        Random rand = new Random();

        //-------------------------------------------------------------------------
        // PART 1 - training the network
        //-------------------------------------------------------------------------
        initializeNetwork(rand);
        generateTrainingSet(rand);

        // try-with-resources guarantees every writer is flushed and closed, even on failure.
        try (BufferedWriter errorOut = new BufferedWriter(new FileWriter("error.csv"));
             BufferedWriter trainingOut = new BufferedWriter(new FileWriter("training_set.csv"));
             BufferedWriter classAOut = new BufferedWriter(new FileWriter("classA_set.csv"));
             BufferedWriter classBOut = new BufferedWriter(new FileWriter("classB_set.csv")))
        {
            Error_File = errorOut;
            Training_Set_File = trainingOut;
            ClassA_Set_File = classAOut;
            ClassB_Set_File = classBOut;

            writeTrainingSet();
            train(rand);

            //-------------------------------------------------------------------------
            // PART 2 - testing the trained network
            //-------------------------------------------------------------------------
            int incorrectClassification = countMisclassifiedTestPoints(rand);

            System.out.println("End of Simulation");
            System.out.format("%2.2f percent of the test points were incorrectly classified%n",
                    (double) incorrectClassification / maxTestPoints * 100);
        }
        catch (IOException e)
        {
            // Fail loudly rather than continue with missing or truncated output files.
            throw new UncheckedIOException("Failed to write simulation output files", e);
        }
    }

    /**
     * Allocates all network arrays, draws small random initial weights and biases, and
     * zeroes the momentum/bookkeeping matrices (Java zero-initializes new arrays).
     */
    private static void initializeNetwork(Random rand)
    {
        Wih = new double[maxInputNodes][maxHiddenNodes];
        Who = new double[maxHiddenNodes][maxOutputNodes];
        OldWih = new double[maxInputNodes][maxHiddenNodes];
        OldWho = new double[maxHiddenNodes][maxOutputNodes];
        dWih = new double[maxInputNodes][maxHiddenNodes];
        dWho = new double[maxHiddenNodes][maxOutputNodes];

        ai = new double[maxInputNodes];
        ah = new double[maxHiddenNodes];
        bh = new double[maxHiddenNodes];
        ao = new double[maxOutputNodes];
        bo = new double[maxOutputNodes];
        ea = new double[maxOutputNodes];
        dout = new double[maxOutputNodes];
        dh = new double[maxHiddenNodes];

        // Small positive and negative starting weights break the symmetry between hidden nodes.
        for (int i = 0; i < maxInputNodes; i++)
            for (int h = 0; h < maxHiddenNodes; h++)
                Wih[i][h] = smallRandomWeight(rand);

        for (int h = 0; h < maxHiddenNodes; h++)
            for (int o = 0; o < maxOutputNodes; o++)
                Who[h][o] = smallRandomWeight(rand);

        for (int h = 0; h < maxHiddenNodes; h++)
            bh[h] = smallRandomWeight(rand);

        for (int o = 0; o < maxOutputNodes; o++)
            bo[o] = smallRandomWeight(rand);
    }

    /** Returns a value uniformly distributed in [-0.1, 0.1). */
    private static double smallRandomWeight(Random rand)
    {
        return 2 * (rand.nextDouble() / 10.0) - 0.1;
    }

    /**
     * Fills {@code ci}/{@code at} with exactly maxPointsForClassA points inside the unit
     * circle (target {1, 0}) and maxPointsForClassB points outside it (target {0, 1}).
     * Points are drawn uniformly from [-2, 2) x [-2, 2); draws for an already-full class are
     * rejected, which guarantees the requested count for each class.
     */
    private static void generateTrainingSet(Random rand)
    {
        ci = new double[maxTrainingSetPoints][maxInputNodes];
        at = new double[maxTrainingSetPoints][maxOutputNodes];

        int classACount = 0;
        int classBCount = 0;
        while (classACount < maxPointsForClassA || classBCount < maxPointsForClassB)
        {
            double[] point = { 4 * rand.nextDouble() - 2.0, 4 * rand.nextDouble() - 2.0 };
            boolean inside = isCircle(point);
            if (inside ? classACount >= maxPointsForClassA : classBCount >= maxPointsForClassB)
                continue;   // this class already has enough points

            int index = classACount + classBCount;
            ci[index][0] = point[0];
            ci[index][1] = point[1];
            at[index][0] = inside ? 1.0 : 0.0;
            at[index][1] = inside ? 0.0 : 1.0;
            if (inside)
                classACount++;
            else
                classBCount++;
        }
    }

    /** Writes the coordinates of every training point to {@code Training_Set_File}. */
    private static void writeTrainingSet() throws IOException
    {
        for (double[] point : ci)
            Training_Set_File.write(formatPoint(point[0], point[1]));
    }

    /**
     * Runs maxEpochIndex epochs of online back-propagation and records each epoch's summed
     * squared error to {@code Error_File}. Each epoch visits every training point exactly once
     * in a freshly shuffled order. The generated set is not interleaved: once class B is full,
     * the remaining class-A points are stored contiguously, so a fixed order would bias the
     * per-sample updates.
     */
    private static void train(Random rand) throws IOException
    {
        int[] order = new int[maxTrainingSetPoints];
        for (int p = 0; p < order.length; p++)
            order[p] = p;

        for (int epochIndex = 0; epochIndex < maxEpochIndex; epochIndex++)
        {
            shuffle(order, rand);
            error = 0.0;
            for (int p : order)
            {
                feedForward(ci[p]);

                // Error between targeted and measured output activations
                for (int o = 0; o < maxOutputNodes; o++)
                {
                    ea[o] = at[p][o] - ao[o];
                    error += ea[o] * ea[o];
                }

                backPropagate();
            }

            // Record the sum squared error after each epoch.
            Error_File.write(String.format(Locale.ROOT, "%d, %2.4f\n", epochIndex, error));
        }
    }

    /** Shuffles {@code values} in place (Fisher-Yates). */
    private static void shuffle(int[] values, Random rand)
    {
        for (int k = values.length - 1; k > 0; k--)
        {
            int swap = rand.nextInt(k + 1);
            int tmp = values[k];
            values[k] = values[swap];
            values[swap] = tmp;
        }
    }

    /**
     * Propagates the output error stored in {@code ea} back through the network and updates all
     * weights and biases. Relies on {@code ai}, {@code ah} and {@code ao} from the preceding
     * {@link #feedForward} call.
     *
     * <p>Weight changes use classic momentum: delta(t) = stepSize * grad + momentum * delta(t-1),
     * where the previous delta is kept in {@code dWho}/{@code dWih}.
     */
    private static void backPropagate()
    {
        // Output error terms: squared-error derivative times the sigmoid derivative.
        for (int o = 0; o < maxOutputNodes; o++)
            dout[o] = ea[o] * ao[o] * (1 - ao[o]);

        // Hidden error terms. Computed before Who is modified below, so they use this step's weights.
        for (int h = 0; h < maxHiddenNodes; h++)
        {
            double bpInput = 0.0;
            for (int o = 0; o < maxOutputNodes; o++)
                bpInput += Who[h][o] * dout[o];
            dh[h] = bpInput * ah[h] * (1 - ah[h]);
        }

        // Adjust weights between hidden and output nodes
        for (int h = 0; h < maxHiddenNodes; h++)
        {
            for (int o = 0; o < maxOutputNodes; o++)
            {
                dWho[h][o] = stepSize * dout[o] * ah[h] + momentum * dWho[h][o];
                OldWho[h][o] = Who[h][o];
                Who[h][o] += dWho[h][o];
            }
        }

        // Adjust weights between input and hidden nodes
        for (int i = 0; i < maxInputNodes; i++)
        {
            for (int h = 0; h < maxHiddenNodes; h++)
            {
                dWih[i][h] = stepSize * dh[h] * ai[i] + momentum * dWih[i][h];
                OldWih[i][h] = Wih[i][h];
                Wih[i][h] += dWih[i][h];
            }
        }

        // Adjust biases of the hidden and output nodes (plain gradient step, no momentum)
        for (int h = 0; h < maxHiddenNodes; h++)
            bh[h] += dh[h] * stepSize;

        for (int o = 0; o < maxOutputNodes; o++)
            bo[o] += dout[o] * stepSize;
    }

    /**
     * Classifies maxTestPoints fresh random points from [-2, 2) x [-2, 2). Each correctly
     * classified point is written to the file for its true class.
     *
     * @return number of points whose predicted class differs from the true class
     */
    private static int countMisclassifiedTestPoints(Random rand) throws IOException
    {
        int incorrect = 0;
        for (int t = 0; t < maxTestPoints; t++)
        {
            double[] inputPoint = { 4 * rand.nextDouble() - 2.0, 4 * rand.nextDouble() - 2.0 };
            boolean inside = isCircle(inputPoint);
            int expected = inside ? classA : classB;

            if (feedForward(inputPoint) != expected)
            {
                incorrect++;
                continue;
            }

            // Print out the point correctly classified to a file
            BufferedWriter out = inside ? ClassA_Set_File : ClassB_Set_File;
            out.write(formatPoint(inputPoint[0], inputPoint[1]));
        }
        return incorrect;
    }

    /**
     * Formats a 2-D point as one CSV row. Uses {@link Locale#ROOT} so the decimal separator
     * is always '.' and cannot clash with the ',' field separator.
     */
    private static String formatPoint(double x, double y)
    {
        return String.format(Locale.ROOT, "%2.4f, %2.4f\n", x, y);
    }

    /**
     * Propagates {@code point} through the network and stores the layer activations in
     * {@code ai}, {@code ah} and {@code ao}.
     *
     * @param point input coordinates; at least maxInputNodes entries are read, and it is not modified
     * @return classA if output node 0 is more active than output node 1, otherwise classB
     */
    static int feedForward(double point[])
    {
        // Load the input layer once, rather than once per hidden node.
        for (int i = 0; i < maxInputNodes; i++)
            ai[i] = point[i];

        // Pass the input vector forward
        for (int h = 0; h < maxHiddenNodes; h++)
        {
            double netInput = bh[h];
            for (int i = 0; i < maxInputNodes; i++)
                netInput += Wih[i][h] * ai[i];
            ah[h] = sigmoid(netInput);
        }

        // Pass the hidden vector forward
        for (int o = 0; o < maxOutputNodes; o++)
        {
            double netInput = bo[o];
            for (int h = 0; h < maxHiddenNodes; h++)
                netInput += Who[h][o] * ah[h];
            ao[o] = sigmoid(netInput);
        }

        // Pick the more active output node. A fixed confidence threshold would send every
        // uncertain point to class B and skew the error rate.
        return ao[0] > ao[1] ? classA : classB;
    }

    /** Logistic sigmoid 1 / (1 + e^-x). */
    private static double sigmoid(double x)
    {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Returns true if {@code point} lies strictly inside the unit circle centred at the origin.
     *
     * @param point coordinates; entries 0 and 1 are read
     */
    static boolean isCircle(double point[])
    {
        return point[0] * point[0] + point[1] * point[1] < 1;
    }
}

 

Hosted by www.Geocities.ws

1