
package neuralnets; 

import java.util.*;
import java.io.*;
/**
 * Original implementation by Adam Lee.
 *
 * Added the city-distance code and a hard-coded choice between randomly
 * placed and fixed (circular) city layouts.
 *
 * @author Dustin Stevens-Baier
 */
public class TSP {

	/** Simulation type: cities evenly spaced on a circle. */
	private static final int CIRCULAR = 0;
	/** Simulation type: cities placed with uniformly random coordinates. */
	private static final int RANDOM = 1;

	/** Number of cities, which is also the number of tour positions. */
	private final int n = 20;
	/** Maximum number of update sweeps performed by {@link #converge()}. */
	private final int ITERATIONS = 10000;
	private final Random rand;
	/** Connection strengths, indexed {@code w[position][city][position][city]}. */
	private final double w[][][][] = new double[n][n][n][n];
	/** Activation matrix, indexed {@code a[position][city]}. */
	private final double a[][] = new double[n][n];
	/** Euclidean distance between every pair of cities. */
	private final double d[][] = new double[n][n];
	/** Network energy recorded after each update sweep. */
	private final double energy[] = new double[ITERATIONS];
	/** Step size of the activation update rule. */
	private final double step = 1.0;
	/** Upper level at which the update rule's growth factor (a - m)(M - a) vanishes. */
	private final double M = 1.0;
	/** Lower level at which the update rule's growth factor (a - m)(M - a) vanishes. */
	private final double m = (-1.0)/((double)n - 1.0);
	/** Activation at or above which a tour position counts as settled. */
	private final double THRESHOLD = 0.6;
	private final City cities[];
	private final Tour tour[];

	/**
	 * Builds a city layout and runs the network until it settles. The results
	 * are written to the console and to city2.dat, city2.html, Final2.html and
	 * Tour2.html in the working directory.
	 *
	 * @param simType {@code 0} for a circular layout, {@code 1} for random cities
	 * @throws IllegalArgumentException if {@code simType} is neither value
	 */
	public TSP(int simType)
	{
		cities = new City[n];
		tour = new Tour[n];
		rand = new Random();

		if (simType == CIRCULAR) {
			generateCircularCities();
		} else if (simType == RANDOM) {
			generateRandomCities();
		} else {
			throw new IllegalArgumentException("Unknown simType " + simType
					+ "; expected " + CIRCULAR + " (circular) or " + RANDOM + " (random)");
		}
		processCityData(); // print the data to console, datfile, and html
		fillDistances();
		initConnections();
		initActivations();
		converge();
		processActivationData("Final");
		buildTour();
		processTourData();
	}

	/** Places the cities evenly on a circle of radius 0.25 centred at (0.5, 0.5). */
	private void generateCircularCities() {
		for (int i = 0; i < n; i++) {
			// Derive each angle from i directly instead of accumulating a running sum.
			double angle = (Math.PI * 2) * i / n;
			cities[i] = new City(Math.cos(angle) * .25 + .5, Math.sin(angle) * .25 + .5);
		}
	}

	/** Places each city at coordinates drawn from {@link Random#nextDouble()}. */
	private void generateRandomCities() {
		for (int i = 0; i < n; i++) {
			double x = rand.nextDouble();
			double y = rand.nextDouble();
			cities[i] = new City(x, y);
		}
	}

	/** Fills {@code d} with the Euclidean distance between every pair of cities. */
	private void fillDistances() {
		for (int i = 0; i < n; i++) {
			for (int j = 0; j < n; j++) {
				double dx = cities[i].getX() - cities[j].getX();
				double dy = cities[i].getY() - cities[j].getY();
				d[i][j] = Math.sqrt(dx * dx + dy * dy);
			}
		}
	}

	/**
	 * Initialises the connection strengths {@code w[i][x][j][y]} between unit
	 * (position i, city x) and unit (position j, city y).
	 */
	private void initConnections() {
		final double base = 1.0 / (n * n);
		for (int i = 0; i < n; i++) {
			// The tour is closed: position n-1 is adjacent to position 0.
			int next = (i + 1) % n;
			int prev = (i - 1 + n) % n;
			for (int x = 0; x < n; x++) {
				for (int j = 0; j < n; j++) {
					for (int y = 0; y < n; y++) {
						boolean sameCity = (y == x);
						boolean samePosition = (j == i);
						if (sameCity && samePosition) {
							// Self connection
							w[i][x][j][y] = base - 2.0 / n;
						} else if (sameCity || samePosition) {
							// Row/column inhibition: same city elsewhere, or another city here
							w[i][x][j][y] = base - 1.0 / n;
						} else if (j == next || j == prev) {
							// Distance penalty between different cities at adjacent positions
							w[i][x][j][y] = base - d[x][y] / n;
						} else {
							// All other connections
							w[i][x][j][y] = base;
						}
					}
				}
			}
		}
	}

	/** Seeds every activation with a tiny random value in [0, 1e-9). */
	private void initActivations() {
		for (int i = 0; i < n; i++) {
			for (int x = 0; x < n; x++) {
				a[i][x] = 1e-9 * rand.nextDouble();
			}
		}
	}

	/**
	 * Repeatedly sweeps over all units, updating activations in place. Stops
	 * when every position has some activation at or above THRESHOLD, or after
	 * ITERATIONS sweeps. The energy after each sweep is stored in {@code energy}.
	 */
	private void converge() {
		int index = 0;
		int settledPositions = 0;

		while (index < ITERATIONS && settledPositions < n) {
			settledPositions = 0;
			for (int i = 0; i < n; i++) {
				boolean settled = false;
				for (int x = 0; x < n; x++) {
					double net = 0.0;
					for (int j = 0; j < n; j++) {
						for (int y = 0; y < n; y++) {
							net += w[i][x][j][y] * a[j][y];
						}
					}
					double current = a[i][x];
					// Updated in place, so later units in this sweep see the new value.
					a[i][x] = current + step * (current - m) * (M - current) * net;
					if (a[i][x] >= THRESHOLD) {
						settled = true;
					}
				}
				if (settled) {
					settledPositions++;
				}
			}
			updateEnergy(index);
			index++;
		}
		System.out.println("done converging, iterations: " + index);
		if (index > 0) {
			System.out.println("final energy: " + energy[index - 1]);
		}
	}

	/**
	 * Stores the quadratic network energy E = -1/2 * sum(w * a * a) in
	 * {@code energy[index]}.
	 *
	 * @param index sweep number; must be below ITERATIONS
	 */
	private void updateEnergy(int index) {
		double energySum = 0.0;
		for (int i = 0; i < n; i++) {
			for (int x = 0; x < n; x++) {
				for (int j = 0; j < n; j++) {
					for (int y = 0; y < n; y++) {
						energySum += a[i][x] * a[j][y] * w[i][x][j][y];
					}
				}
			}
		}
		energy[index] = -0.5 * energySum;
	}

	/**
	 * Derives a tour from the activations and stores it in {@code tour}. For
	 * each city it estimates a position (an activation-weighted centre of mass
	 * around its strongest position), then orders the cities by that estimate.
	 */
	private void buildTour() {
		final int base = 2; // half-width of the centre-of-mass window, in positions

		for (int j = 0; j < n; j++) {
			// Position at which city j is most strongly active.
			int iMax = 0;
			double maxAct = Double.NEGATIVE_INFINITY;
			for (int i = 0; i < n; i++) {
				if (a[i][j] > maxAct) {
					maxAct = a[i][j];
					iMax = i;
				}
			}

			// Symmetric window iMax-base..iMax+base. Activations are read at the
			// wrapped position, but weighted by the unwrapped k. A window that
			// straddles position 0 therefore gives a centre near 0, not one
			// pulled toward the middle.
			double weightedSum = 0.0;
			double totalWeight = 0.0;
			for (int k = iMax - base; k <= iMax + base; k++) {
				double act = a[(n + k) % n][j];
				if (act > 0) {
					weightedSum += k * act;
					totalWeight += act;
				}
			}
			// Fall back to the winning position when no activation is positive.
			double centre = (totalWeight > 0) ? weightedSum / totalWeight : iMax;
			tour[j] = new Tour(centre, j);
		}

		// Stable sort, matching the ordering of the bubble sort it replaces.
		Arrays.sort(tour, Comparator.comparingDouble(Tour::getCenterOfMass));
		System.out.println("Done with tour");
	}

	/** Prints the city coordinates and writes them to city2.dat and city2.html. */
	private void processCityData() {
		try (FileWriter datWriter = new FileWriter("city2.dat");
				FileWriter htmlWriter = new FileWriter("city2.html")) {
			htmlWriter.write("<html><body><table border=1><tr><td colspan=3 bgcolor=lightblue>\n");
			htmlWriter.write("<center><b>City Positions</b></center></td></tr>\n");
			htmlWriter.write("<tr><td bgcolor=silver><b>City</b></td>\n");
			htmlWriter.write("<td bgcolor=silver><b>X</b></td>\n");
			htmlWriter.write("<td bgcolor=silver><b>Y</b></td></tr>\n");
			for (int i = 0; i < n; i++) {
				double x = cities[i].getX();
				double y = cities[i].getY();
				int cityNumber = i + 1;
				System.out.println( cityNumber + "\tx: " + x + "\ty: " + y );
				datWriter.write(x + "," + y + "\n");
				htmlWriter.write("<tr><td>" + cityNumber + "</td>\n");
				htmlWriter.write("<td>" + x + "</td>\n");
				htmlWriter.write("<td>" + y + "</td></tr>\n");
			}
			htmlWriter.write("</table></body></html>\n");
		}
		catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Prints the activation matrix and writes it as an HTML table to
	 * {@code level + "2.html"}. Rows are positions and columns are cities,
	 * both numbered from 1.
	 *
	 * @param level label used in the file name and the table title
	 */
	private void processActivationData(String level) {
		try (FileWriter htmlWriter = new FileWriter(level + "2.html")) {
			htmlWriter.write("<html><body><table border=1><tr><td colspan=" + (n + 1) + " bgcolor=lightblue>\n");
			htmlWriter.write("<center><b>" + level + " Activations</b></center></td></tr>\n");
			htmlWriter.write("<tr><td>&nbsp;</td>");
			for (int col = 1; col <= n; col++) {
				htmlWriter.write("<td bgcolor=silver>" + col + "</td>\n");
			}
			htmlWriter.write("</tr>\n");
			for (int i = 0; i < n; i++) {
				htmlWriter.write("<tr><td bgcolor=silver>" + (i + 1) + "</td>\n");
				for (int j = 0; j < n; j++) {
					htmlWriter.write("<td>" + a[i][j] + "</td>\n");
					System.out.print(a[i][j] + "\t");
				}
				htmlWriter.write("</tr>\n");
				System.out.print("\n");
			}
			htmlWriter.write("</table></body></html>\n");
		}
		catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Writes the ordered tour to Tour2.html. The table includes per-leg and
	 * running distances, plus the closing leg back to the first city. Cities
	 * are numbered from 1, as in city2.html.
	 */
	private void processTourData() {
		try (FileWriter htmlWriter = new FileWriter("Tour2.html")) {
			htmlWriter.write("<html><body><table border=1><tr><td colspan=6 bgcolor=lightblue>\n\n");
			htmlWriter.write("<center><b>Tour Results</b></center></td></tr>\n");
			htmlWriter.write("<tr><td bgcolor=silver>Step</td><td bgcolor=silver>City</td>\n");
			htmlWriter.write("<td bgcolor=silver>X coord</td><td bgcolor=silver>Y coord</td>\n");
			htmlWriter.write("<td bgcolor=silver>Incremental Distance</td>\n");
			htmlWriter.write("<td bgcolor=silver>Total Distance</td></tr>\n");

			double totDistance = 0.0;
			for (int i = 0; i < n; i++) {
				int cityIndex = tour[i].getCityIndex();
				String incCell = "N/A";
				if (i > 0) {
					double incDistance = d[tour[i - 1].getCityIndex()][cityIndex];
					totDistance += incDistance;
					incCell = String.valueOf(incDistance);
				}
				writeTourRow(htmlWriter, i + 1, cityIndex, incCell, totDistance);
			}
			// The tour is closed, so finish with the leg back to the starting city.
			if (n > 1) {
				int first = tour[0].getCityIndex();
				double returnDistance = d[tour[n - 1].getCityIndex()][first];
				totDistance += returnDistance;
				writeTourRow(htmlWriter, n + 1, first, String.valueOf(returnDistance), totDistance);
			}
			htmlWriter.write("</table></body></html>");
		}
		catch (IOException e) {
			e.printStackTrace();
		}
	}

	/** Writes one row of the tour table for the given 0-based city index. */
	private void writeTourRow(FileWriter out, int stepNumber, int cityIndex,
			String incCell, double total) throws IOException {
		City city = cities[cityIndex];
		out.write("<tr><td>" + stepNumber + "</td>\n");
		out.write("<td>" + (cityIndex + 1) + "</td>\n");
		out.write("<td>" + city.getX() + "</td>\n");
		out.write("<td>" + city.getY() + "</td>\n");
		out.write("<td>" + incCell + "</td>\n");
		out.write("<td>" + total + "</td></tr>\n");
	}

	/**
	 * Runs one simulation.
	 *
	 * @param args optional first argument selects the layout:
	 *             {@code 0} (default) circular, {@code 1} random
	 */
	public static void main(String[] args) {
		int simType = (args.length > 0) ? Integer.parseInt(args[0].trim()) : CIRCULAR;
		new TSP(simType);
	}
}

/**
 * A city location given by its x and y coordinates.
 */
class City {
	/** X coordinate of the city. */
	double x;
	/** Y coordinate of the city. */
	double y;

	City(double xCoord, double yCoord) {
		x = xCoord;
		y = yCoord;
	}

	/** @return the x coordinate supplied at construction */
	double getX() { return x; }

	/** @return the y coordinate supplied at construction */
	double getY() { return y; }
}

/**
 * Pairs a city index with an estimated tour position (its centre of mass).
 * Sorting instances by {@link #getCenterOfMass()} gives the visiting order.
 */
class Tour {

    /** Position estimate from the center of mass calculation. */
    double centerOfMass;

    /** Index of the corresponding city. */
    int cityIdx;

    /**
     * @param centerOfMass estimated position of the city within the tour
     * @param cityIdx      index of the city
     */
    Tour(double centerOfMass, int cityIdx) {
        this.centerOfMass = centerOfMass;
        this.cityIdx = cityIdx;
    }

    /** @return the estimated tour position */
    double getCenterOfMass() {
    	return centerOfMass;
    }

    /** @return the city index */
    int getCityIndex() {
    	return cityIdx;
    }

    /** @param center new estimated tour position */
    void setCenterOfMass(double center) {
    	this.centerOfMass = center;
    }

    /** @param index new city index */
    void setCityIndex( int index) {
    	this.cityIdx = index;
    }
}

