package com.kloover.school.compvision;

/**
 * Finds and removes a minimum-energy vertical seam from an energy map (seam carving).
 *
 * <p>The constructor runs a bottom-up dynamic program: each cost-table cell holds the pixel's
 * own energy plus the cheapest cost of reaching it from one of its (up to three) neighbours in
 * the row above. {@link #findPath()} backtracks the cheapest path from the bottom row to the top,
 * and {@link #removeSeam} uses the marked path to carve that seam out of an image.
 */
public class SeamFinder {

	/** Marker written into the array returned by {@link #findPath()} at each seam position. */
	private static final double NEG_ONE = -1.0;

	/** Number of colour channels copied per pixel by {@link #removeSeam}. */
	private static final int CHANNELS = 3;

	private final int rows;
	private final int cols;

	/**
	 * Cumulative cost table: {@code costs[i][j]} is the minimum total energy of a path that starts
	 * in row 0, moves down one row at a time shifting at most one column per step, and ends at
	 * (i, j), counting every pixel on the path (including the top one).
	 */
	private final double[][] costs;

	/**
	 * Builds the cumulative cost table for the given energy map.
	 *
	 * @param energies rectangular energy map indexed {@code [row][column]}; it is only read during
	 *                 construction and is not retained
	 * @throws IllegalArgumentException if the map has no rows, no columns, or ragged rows
	 */
	public SeamFinder(double[][] energies) {
		if (energies.length == 0 || energies[0].length == 0) {
			throw new IllegalArgumentException("energies must have at least one row and one column");
		}
		rows = energies.length;
		cols = energies[0].length;
		for (int i = 1; i < rows; i++) {
			if (energies[i].length != cols) {
				throw new IllegalArgumentException("energies row " + i + " has length "
						+ energies[i].length + ", expected " + cols);
			}
		}

		costs = new double[rows][cols];
		// The top row has nothing above it, so its cost is just its own energy.
		System.arraycopy(energies[0], 0, costs[0], 0, cols);
		// Iterative (not recursive) so tall images cannot overflow the call stack.
		for (int i = 1; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				costs[i][j] = energies[i][j] + costs[i - 1][cheapestParent(i - 1, j)];
			}
		}
	}

	/**
	 * Returns the column in {@code row} among {@code col - 1}, {@code col}, {@code col + 1}
	 * (clipped to the image) with the lowest cumulative cost. Ties prefer the left neighbour,
	 * then the one directly above, then the right neighbour.
	 */
	private int cheapestParent(int row, int col) {
		int best = col;
		if (col > 0 && costs[row][col - 1] <= costs[row][best]) {
			best = col - 1;
		}
		if (col + 1 < cols && costs[row][col + 1] < costs[row][best]) {
			best = col + 1;
		}
		return best;
	}

	/**
	 * Traces the minimum-cost vertical seam.
	 *
	 * @return a fresh copy of the cumulative cost table in which each seam pixel (exactly one per
	 *         row) is replaced by {@code -1.0}; the internal table is never modified, so repeated
	 *         calls return the same seam
	 */
	public double[][] findPath() {
		double[][] marked = new double[rows][];
		for (int i = 0; i < rows; i++) {
			marked[i] = costs[i].clone();
		}
		int col = getMinIndex(costs[rows - 1]);
		marked[rows - 1][col] = NEG_ONE;
		for (int i = rows - 2; i >= 0; i--) {
			col = cheapestParent(i, col);
			marked[i][col] = NEG_ONE;
		}
		return marked;
	}

	/**
	 * Removes (or blanks out) the marked seam from an image.
	 *
	 * @param image  pixels indexed {@code [row][column][channel]}; the first three channels of each
	 *               pixel are copied
	 * @param nrgs   seam map such as the one returned by {@link #findPath()}; cells equal to
	 *               {@code -1.0} are seam pixels. Its dimensions define the output size.
	 * @param resize if {@code true}, seam pixels are dropped and each row shrinks by one column
	 *               (so each row must contain exactly one marker); if {@code false}, seam pixels
	 *               are replaced by all-zero pixels and the width is unchanged
	 * @return a new image array; {@code image} is not modified
	 */
	public double[][][] removeSeam(double[][][] image, double[][] nrgs, boolean resize) {
		int height = nrgs.length;
		int width = nrgs[0].length;
		// New arrays are zero-filled, which already provides the blank seam pixel when !resize.
		double[][][] carved = new double[height][resize ? width - 1 : width][CHANNELS];
		for (int i = 0; i < height; i++) {
			int out = 0;
			for (int j = 0; j < width; j++) {
				if (nrgs[i][j] != NEG_ONE) {
					for (int c = 0; c < CHANNELS; c++) {
						carved[i][out][c] = image[i][j][c];
					}
					out++;
				} else if (!resize) {
					out++;
				}
			}
		}
		return carved;
	}

	/** Returns the index of the first smallest element of a non-empty array. */
	private int getMinIndex(double... nums) {
		int lowIndex = 0;
		for (int i = 1; i < nums.length; i++) {
			if (nums[i] < nums[lowIndex]) {
				lowIndex = i;
			}
		}
		return lowIndex;
	}

}
