/*
 * Created on August 3rd, 2006
 * 
 * Dustin Stevens-Baier
 *
 *
 * Original code by Adam Lee; modifications by Dustin.

 */
package neuralnets;

import java.io.*;

/**
 *
 * Implements the backpropagation training algorithm for a feed-forward neural
 * network with one hidden layer.
 */
public class BackProp {

	private int numInputs = 2;
	private int numHidden = 8;
	private int numOutputs = 2;
	private int trainingPoints = 1000;
	private int testingPoints = 100;
	private int epoch = 1000;
	private double[][] inputWeights = new double[numInputs][numHidden];
	private double[][] outputWeights = new double[numHidden][numOutputs];
	private Node[] inputs = new Node[numInputs];
	private Node[] hidden = new Node[numHidden];
	private Node[] outputs = new Node[numOutputs];
	private Point[] training = new Point[trainingPoints];
	private Point[] testing = new Point[testingPoints];
	// Maximum absolute error per output for a training sample to count as a success.
	private double threshold = 0.1;
	// Training success percentage, one entry per epoch (indexed by the epoch counter).
	private double[] rateOfSuccess = new double[epoch];

	/**
	 * Randomly initializes all connection weights and creates the input, hidden
	 * and output nodes, logging every initial value to {@code nodes.csv}.
	 *
	 * @throws IOException if {@code nodes.csv} cannot be written
	 */
	public void initializeNetwork() throws IOException {
		FileWriter nodeWriter = new FileWriter("nodes.csv");
		try {
			// Weights between the input nodes and the hidden nodes.
			nodeWriter.write("input weights\n");
			nodeWriter.write("input,hidden,weight\n");
			for (int x = 0; x < numInputs; x++) {
				for (int y = 0; y < numHidden; y++) {
					inputWeights[x][y] = smallRandomWeight();
					nodeWriter.write(x + "," + y + "," + inputWeights[x][y] + "\n");
				}
			}

			// Weights between the hidden nodes and the output nodes.
			nodeWriter.write("\noutput weights\n");
			nodeWriter.write("hidden,output,weight\n");
			for (int x = 0; x < numHidden; x++) {
				for (int y = 0; y < numOutputs; y++) {
					outputWeights[x][y] = smallRandomWeight();
					nodeWriter.write(x + "," + y + "," + outputWeights[x][y] + "\n");
				}
			}

			// Input nodes carry no bias; their activation is overwritten per sample.
			nodeWriter.write("\nInput Nodes\n");
			nodeWriter.write("number, bias\n");
			for (int x = 0; x < numInputs; x++) {
				inputs[x] = new Node(0.0, 0.0);
				nodeWriter.write(x + "," + inputs[x].getBias() + "\n");
			}

			nodeWriter.write("\nHidden Nodes\nnumber,bias\n");
			for (int x = 0; x < numHidden; x++) {
				hidden[x] = new Node();
				nodeWriter.write(x + "," + hidden[x].getBias() + "\n");
			}

			nodeWriter.write("\nOutput Nodes\nnumber,bias\n");
			for (int x = 0; x < numOutputs; x++) {
				outputs[x] = new Node();
				nodeWriter.write(x + "," + outputs[x].getBias() + "\n");
			}
		} finally {
			nodeWriter.close();
		}
	}

	/**
	 * Generates a fresh random training set (logged to {@code training.csv}) and
	 * trains the network on it for {@code epoch} passes using online
	 * backpropagation. The per-epoch success percentage is stored for
	 * {@link #outputSuccess()}.
	 *
	 * @throws IOException if {@code training.csv} cannot be written
	 */
	public void trainNetwork() throws IOException {
		double gain = 0.2;

		double[] hiddenDelta = new double[numHidden];
		double[] outputDelta = new double[numOutputs];
		double[] target = new double[numOutputs];

		FileWriter trPtWriter = new FileWriter("training.csv");
		try {
			trPtWriter.write("Training Points\nx,y\n");
			for (int i = 0; i < trainingPoints; i++) {
				training[i] = new Point();
				trPtWriter.write(training[i].toString() + "\n");
			}
		} finally {
			trPtWriter.close();
		}

		for (int x = 0; x < epoch; x++) {
			int success = 0;
			for (int p = 0; p < trainingPoints; p++) {
				Point point = training[p];
				feedForward(point);

				// Given point class (A or B) set target values for the output layer.
				if (point.getType() == 'A') {
					target[0] = 1.0;
					target[1] = 0.0;
				} else {
					target[0] = 0.0;
					target[1] = 1.0;
				}

				// A sample is a success when every output is within threshold of its target.
				if ((Math.abs(target[0] - outputs[0].getActivation()) < threshold)
						&& (Math.abs(target[1] - outputs[1].getActivation()) < threshold)) {
					success++;
				}

				// Output-layer deltas: error times sigmoid derivative a * (1 - a).
				for (int i = 0; i < numOutputs; i++) {
					double a = outputs[i].getActivation();
					outputDelta[i] = (target[i] - a) * a * (1.0 - a);
				}

				// Hidden-layer deltas, computed from the weights *before* they are updated.
				for (int y = 0; y < numHidden; y++) {
					double sum = 0.0;
					for (int z = 0; z < numOutputs; z++) {
						sum += outputDelta[z] * outputWeights[y][z];
					}
					double a = hidden[y].getActivation();
					hiddenDelta[y] = sum * a * (1.0 - a);
				}

				// Update weights (between input and hidden layers).
				for (int y = 0; y < numInputs; y++) {
					for (int z = 0; z < numHidden; z++) {
						inputWeights[y][z] += gain * hiddenDelta[z] * inputs[y].getActivation();
					}
				}

				// Update weights (between hidden and output layers).
				for (int y = 0; y < numHidden; y++) {
					for (int z = 0; z < numOutputs; z++) {
						outputWeights[y][z] += gain * outputDelta[z] * hidden[y].getActivation();
					}
				}

				// Update bias (hidden layer).
				for (int y = 0; y < numHidden; y++) {
					hidden[y].setBias(hidden[y].getBias() + (gain * hiddenDelta[y]));
				}

				// Update bias (output layer).
				for (int y = 0; y < numOutputs; y++) {
					outputs[y].setBias(outputs[y].getBias() + (gain * outputDelta[y]));
				}
			}
			// Percentage of training points classified within threshold during this epoch.
			rateOfSuccess[x] = 100.0 * success / trainingPoints;
		}
	}

	/**
	 * Classifies a fresh set of random test points with the trained network and
	 * writes each prediction, plus the overall percentage correct, to
	 * {@code testing.csv}.
	 *
	 * @throws IOException if {@code testing.csv} cannot be written
	 */
	public void testNetwork() throws IOException {
		for (int i = 0; i < testingPoints; i++) {
			testing[i] = new Point();
		}

		int missed = 0;
		FileWriter testWriter = new FileWriter("testing.csv");
		try {
			testWriter.write("X,Y,predicted,actual\n");
			for (int t = 0; t < testingPoints; t++) {
				feedForward(testing[t]);

				// NOTE(review): only output 0 is consulted, and it must exceed 1 - threshold
				// to predict 'A'; this is stricter than comparing outputs[0] with outputs[1].
				char predicted = (outputs[0].getActivation() > 1 - threshold) ? 'A' : 'B';
				char actual = testing[t].getType();

				testWriter.write(testing[t].getX() + ",");
				testWriter.write(testing[t].getY() + ",");
				testWriter.write(predicted + ",");
				testWriter.write(actual + "\n");
				if (actual != predicted) {
					missed++;
				}
			}
			// Scaled to a true percentage so the value matches its "% correct" label.
			double percentCorrect = 100.0 * (testingPoints - missed) / testingPoints;
			testWriter.write("\n% correct: " + percentCorrect + "\n");
		} finally {
			testWriter.close();
		}
	}

	/**
	 * Writes the per-epoch training success percentages to {@code success.csv},
	 * one value per line.
	 *
	 * @throws IOException if {@code success.csv} cannot be written
	 */
	public void outputSuccess() throws IOException {
		FileWriter successWriter = new FileWriter("success.csv");
		try {
			successWriter.write("Percent Correct per Epoch\n");
			for (int i = 0; i < rateOfSuccess.length; i++) {
				successWriter.write(rateOfSuccess[i] + "\n");
			}
		} finally {
			successWriter.close();
		}
	}

	/**
	 * Loads {@code p} into the input layer and propagates it through the network,
	 * leaving the resulting activations in {@code hidden} and {@code outputs}.
	 * Shared by training and testing so both evaluate every node identically.
	 */
	private void feedForward(Point p) {
		inputs[0].setActivation(p.getX());
		inputs[1].setActivation(p.getY());

		// Hidden layer: sigmoid(bias + sum of weighted inputs), for every node.
		for (int y = 0; y < numHidden; y++) {
			double net = hidden[y].getBias();
			for (int z = 0; z < numInputs; z++) {
				net += inputs[z].getActivation() * inputWeights[z][y];
			}
			hidden[y].setActivation(calcSigmoid(net));
		}

		// Output layer: sigmoid(bias + sum of weighted hidden activations).
		for (int y = 0; y < numOutputs; y++) {
			double net = outputs[y].getBias();
			for (int z = 0; z < numHidden; z++) {
				net += hidden[z].getActivation() * outputWeights[z][y];
			}
			outputs[y].setActivation(calcSigmoid(net));
		}
	}

	/** Returns a random initial weight uniformly distributed in [-0.01, 0.01). */
	private static double smallRandomWeight() {
		return 0.02 * Math.random() - 0.01;
	}

	/** Logistic sigmoid: 1 / (1 + e^-x). */
	private double calcSigmoid(double inpX) {
		return 1.0 / (1.0 + Math.exp(-inpX));
	}

	/**
	 * Builds, trains and evaluates a network, writing all CSV reports to the
	 * current working directory.
	 */
	public static void main(String[] args) throws IOException {
		BackProp testBackProp = new BackProp();
		testBackProp.initializeNetwork();
		testBackProp.trainNetwork();
		testBackProp.outputSuccess();
		testBackProp.testNetwork();
	}
}
 
//***********************************************************************

/**
 * A single network unit holding its current activation (output value) and the
 * bias term that is added to its weighted net input.
 */
class Node {
    /** Current output value of this node. */
    private double activation;
    /** Bias added to this node's net input. */
    private double bias;

    /**
     * Creates a node whose activation is random in [0, 1) and whose bias is a
     * small random value in [-0.01, 0.01).
     */
    public Node() {
        // Arguments evaluate left to right: activation is drawn first, then bias.
        this(Math.random(), 0.02 * Math.random() - 0.01);
    }

    /** Creates a node with the given activation and bias. */
    public Node(double activation, double bias) {
        this.activation = activation;
        this.bias = bias;
    }

    /** Returns the node's current activation. */
    public double getActivation() {
        return this.activation;
    }

    /** Replaces the node's current activation. */
    public void setActivation(double newActivation) {
        this.activation = newActivation;
    }

    /** Returns the node's bias. */
    public double getBias() {
        return this.bias;
    }

    /** Replaces the node's bias. */
    public void setBias(double newBias) {
        this.bias = newBias;
    }
}

//*************************************************************

/**
 * A 2-D sample point used to train and test the network.
 * <p>
 * A point is class {@code 'A'} when it lies inside or on the unit circle
 * ({@code x*x + y*y <= 1}) and class {@code 'B'} otherwise.
 */
class Point {
    private final double x;
    private final double y;

    /** Creates a point with both coordinates drawn uniformly from [-2, 2). */
    public Point() {
        x = 4 * Math.random() - 2;
        y = 4 * Math.random() - 2;
    }

    /** Creates a point at the given coordinates (useful for deterministic inputs). */
    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }

    /** Returns the x coordinate. */
    public double getX() {
        return x;
    }

    /** Returns the y coordinate. */
    public double getY() {
        return y;
    }

    /** Returns the coordinates formatted as {@code "x, y"}. */
    @Override
    public String toString() {
        // Double.toString yields the same text as the deprecated new Double(d).toString().
        return Double.toString(x) + ", " + Double.toString(y);
    }

    /**
     * Classifies this point.
     *
     * @return {@code 'A'} if the point is inside or on the unit circle,
     *         {@code 'B'} otherwise
     */
    public char getType() {
        // No separate [-1, 1] bounding-box check is needed: |x| > 1 or |y| > 1
        // already implies x*x + y*y > 1.
        return (x * x + y * y > 1) ? 'B' : 'A';
    }
}


