/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of SMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package smt.iter.prec;

import java.util.Arrays;

import mt.DenseVector;
import mt.Matrix;
import mt.VectorEntry;
import smt.FlexCompRowMatrix;
import smt.SparseVector;

import java.util.Iterator;

/**
 * Partial preconditioner for decomposition based methods. Contains some
 * methods to make incomplete decomposition based preconditioners easier to
 * implement, such as ILU and ICC.
 */
abstract class DecompositionPreconditioner implements Preconditioner {

	/**
	 * The decomposition matrix. Filled by {@link #factor(Matrix)} and
	 * compacted by {@link #setMatrix(Matrix)}.
	 */
	FlexCompRowMatrix F;

	/**
	 * Cache of diagonal positions: diagInd[i] is the offset of entry (i, i)
	 * within the index/data arrays of row i of F. A negative value means the
	 * position is not cached and will be looked up by getDiagInd.
	 */
	int[] diagInd;

	/**
	 * Successive diagonal shift. Only stored here; presumably consumed by the
	 * subclass factor() implementations.
	 */
	double shift;

	/**
	 * Initial diagonal shift. Only stored here; presumably consumed by the
	 * subclass factor() implementations.
	 */
	double initShift;

	/**
	 * Constructor for DecompositionPreconditioner
	 * 
	 * @param A
	 *            Matrix to decompose. Not changed
	 * @param initShift
	 *            Initial diagonal shift
	 * @param shift
	 *            Successive diagonal shift
	 */
	public DecompositionPreconditioner(
		Matrix A,
		double initShift,
		double shift) {
		this.initShift = initShift;
		this.shift = shift;

		F = new FlexCompRowMatrix(A.numRows(), A.numColumns());

		// The diagonal cache must exist (and be marked empty, since a new
		// int[] holds 0, which getDiagInd would treat as a cached position)
		// before setMatrix/factor can use it
		diagInd = new int[A.numRows()];
		Arrays.fill(diagInd, -1);

		// NOTE(review): this runs the abstract factor() before any subclass
		// constructor body has executed, so subclass fields are still at
		// their default values during the first factorization
		setMatrix(A);
	}

	/**
	 * Constructor for DecompositionPreconditioner. Uses a default value of 0.1
	 * as successive shift, and no initial shift
	 * 
	 * @param A
	 *            Matrix to decompose. Not changed
	 */
	public DecompositionPreconditioner(Matrix A) {
		this(A, 0, 0.1);
	}

	/**
	 * Recomputes the decomposition for a new matrix of the same size as the
	 * one given at construction.
	 * 
	 * @param A
	 *            Matrix to decompose. Not changed
	 * @throws IllegalArgumentException
	 *             if A does not have the same dimensions as F
	 */
	public void setMatrix(Matrix A) {
		if (F.numRows() != A.numRows() || F.numColumns() != A.numColumns())
			throw new IllegalArgumentException("Decomposition size differs from matrix size");

		// Cached positions describe the previous structure of F; clear them so
		// any getDiagInd call made during factor() looks them up afresh
		resetDiagInd();

		// Subclass performs the factorization
		factor(A);

		// Compress the matrix, and rebuild the diagonal cache against the
		// compacted rows
		F.compact();
		resetDiagInd();
		for (int i = 0; i < F.numRows(); ++i)
			getDiagInd(i);
	}

	/**
	 * Marks every cached diagonal position as unknown, (re)allocating the
	 * cache if it is missing or does not match the row count of F.
	 */
	private void resetDiagInd() {
		if (diagInd == null || diagInd.length != F.numRows())
			diagInd = new int[F.numRows()];
		Arrays.fill(diagInd, -1);
	}

	/**
	 * This method should create the factor F of A. The diagInd cache is
	 * cleared before the call and rebuilt by the caller afterwards, so
	 * implementations may use getDiagInd; a cached position is only valid
	 * while the structure of that row of F is unchanged.
	 * 
	 * @param A
	 *            Matrix to factor. Not changed
	 */
	abstract void factor(Matrix A);

	/**
	 * F is lower-triangular, and it is solved for (forward substitution).
	 * 
	 * @param data
	 *            Initially the right hand side. Overwritten with solution
	 * @param diagDiv
	 *            True if the diagonal will be divided with
	 */
	void solveL(double[] data, boolean diagDiv) {
		for (int i = 0; i < F.numRows(); ++i) {

			SparseVector cur = (SparseVector) F.getRow(i);
			int[] curInd = cur.getIndex();
			double[] curDat = cur.getData();
			int used = cur.used();

			// Row indices are sorted, so the strictly lower part is the prefix
			// with index < i. The bound check stops rows without a diagonal or
			// upper entries (e.g. an empty row) from running off the array
			double sum = 0;
			int j = 0;
			for (; j < used && curInd[j] < i; ++j)
				sum += curDat[j] * data[curInd[j]];

			data[i] -= sum;

			// Divide by diagonal. The factorization guarantees its existence
			if (diagDiv)
				data[i] /= curDat[j];
		}
	}

	/**
	 * F is upper-triangular, and it is solved for (backward substitution).
	 * 
	 * @param data
	 *            Initially the right hand side. Overwritten with solution
	 * @param diagDiv
	 *            True if the diagonal will be divided with
	 */
	void solveU(double[] data, boolean diagDiv) {
		for (int i = F.numRows() - 1; i >= 0; --i) {

			SparseVector cur = (SparseVector) F.getRow(i);
			int[] curInd = cur.getIndex();
			double[] curDat = cur.getData();
			int used = cur.used();

			// Entries after the diagonal form the strictly upper part
			double sum = 0;
			for (int j = diagInd[i] + 1; j < used; ++j)
				sum += curDat[j] * data[curInd[j]];

			data[i] -= sum;

			// Divide by diagonal. The factorization guarantees its existence
			if (diagDiv)
				data[i] /= curDat[diagInd[i]];
		}
	}

	/**
	 * F is lower-triangular, and F <sup>T</sup> is solved for. Row i of F is
	 * column i of F <sup>T</sup>, so each solved value is subtracted from the
	 * remaining right hand side (column-oriented backward substitution).
	 * 
	 * @param data
	 *            Initially the right hand side. Overwritten with solution
	 * @param diagDiv
	 *            True if the diagonal will be divided with
	 */
	void solveLT(double[] data, boolean diagDiv) {
		for (int i = F.numRows() - 1; i >= 0; --i) {

			SparseVector cur = (SparseVector) F.getRow(i);
			int[] curInd = cur.getIndex();
			double[] curDat = cur.getData();
			int used = cur.used();

			// Solve at current position
			if (diagDiv)
				data[i] /= curDat[diagInd[i]];
			double val = data[i];

			// Move over to right hand side. Bounded as in solveL so a row with
			// only lower entries cannot run off the array
			for (int j = 0; j < used && curInd[j] < i; ++j)
				data[curInd[j]] -= curDat[j] * val;
		}
	}

	/**
	 * F is upper-triangular, and F <sup>T</sup> is solved for. Row i of F is
	 * column i of F <sup>T</sup>, so each solved value is subtracted from the
	 * remaining right hand side (column-oriented forward substitution).
	 * 
	 * @param data
	 *            Initially the right hand side. Overwritten with solution
	 * @param diagDiv
	 *            True if the diagonal will be divided with
	 */
	void solveUT(double[] data, boolean diagDiv) {
		for (int i = 0; i < F.numRows(); ++i) {

			SparseVector cur = (SparseVector) F.getRow(i);
			int[] curInd = cur.getIndex();
			double[] curDat = cur.getData();
			int used = cur.used();

			// Solve at current position
			if (diagDiv)
				data[i] /= curDat[diagInd[i]];
			double val = data[i];

			// Move over to right hand side
			for (int j = diagInd[i] + 1; j < used; ++j)
				data[curInd[j]] -= curDat[j] * val;
		}
	}

	/**
	 * Gets diagonal index from cache (if it exists there), or updates the
	 * cache with diagonal index.
	 * 
	 * @param row
	 *            Row to get diagonal index for
	 * @return diagInd[row]
	 */
	int getDiagInd(int row) {
		// Fetch the row once rather than once per argument
		SparseVector cur = (SparseVector) F.getRow(row);
		return getDiagInd(row, cur.getIndex(), cur.used());
	}

	/**
	 * Gets diagonal index from cache (if it exists there), or updates the
	 * cache with diagonal index. The cached value is not re-validated, so it
	 * must be cleared whenever the structure of the row changes.
	 * 
	 * @param row
	 *            Row to get diagonal index for
	 * @param rowInd
	 *            The sorted column indices of the row to search
	 * @param length
	 *            Number of valid entries at the start of rowInd
	 * @return diagInd[row]
	 */
	int getDiagInd(int row, int[] rowInd, int length) {
		if (diagInd[row] < 0)
			diagInd[row] = smt.util.Arrays.binarySearch(rowInd, row, 0, length);
		return diagInd[row];
	}

	/**
	 * Copies each entry visited by y's iterator into the same position of x.
	 * Positions of x that y does not visit are left unchanged; x is not
	 * zeroed first.
	 * 
	 * @param x
	 *            Dense destination vector, updated in place
	 * @param y
	 *            Sparse source vector
	 */
	void scatter(DenseVector x, SparseVector y) {
		Iterator iter = y.iterator();
		while (iter.hasNext()) {
			VectorEntry e = (VectorEntry) iter.next();
			x.set(e.index(), e.get());
		}
	}

	/**
	 * Sets each entry visited by x's iterator equal to the entry of y at the
	 * same index. Only entries already present in x are written; the other
	 * entries of y are ignored.
	 * 
	 * @param x
	 *            Sparse destination vector, updated in place
	 * @param y
	 *            Dense source vector
	 */
	void gather(SparseVector x, DenseVector y) {
		Iterator iter = x.iterator();
		while (iter.hasNext()) {
			VectorEntry e = (VectorEntry) iter.next();
			e.set(y.get(e.index()));
		}
	}

}
