/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of SMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package smt.iter.prec;

import mt.DenseVector;
import mt.Matrix;
import mt.Vector;
import smt.FlexCompRowMatrix;
import smt.SparseVector;

/**
 * SSOR preconditioner. Uses symmetric successive overrelaxation as a
 * preconditioner. Meant for symmetric matrices. For best performance, omega
 * must be carefully chosen (strictly between 0 and 2).
 * <p>
 * Splitting the matrix as <code>A = L + D + U</code> (strictly lower,
 * diagonal and strictly upper parts), the preconditioner is
 * 
 * <pre>
 * M = omega / (2 - omega) * (D / omega + L) D^{-1} (D / omega + U)
 * </pre>
 * 
 * which is itself symmetric when <code>A</code> is symmetric
 * (<code>U = L^T</code>). A private copy of the matrix is kept; call
 * {@link #setMatrix(Matrix)} to refresh it.
 */
public class SSOR implements Preconditioner {

	/**
	 * Overrelaxation parameter, strictly between 0 and 2
	 */
	private double omega;

	/**
	 * All diagonal entries of the matrix (zero where the matrix has none)
	 */
	private final double[] diag;

	/**
	 * Holds a copy of the matrix A. Every row stores an explicit diagonal
	 * entry, possibly zero
	 */
	private final FlexCompRowMatrix F;

	/**
	 * For each row, the position of the diagonal entry within that row's
	 * sparse index/data arrays. Since the indices are sorted, positions before
	 * it hold the strictly lower part and positions after it the strictly
	 * upper part
	 */
	private final int[] diagInd;

	/**
	 * Constructor for SSOR
	 * 
	 * @param A
	 *            Matrix to create preconditioner for. Must be square. Not
	 *            modified
	 * @param omega
	 *            Overrelaxation parameter. Strictly between 0 and 2.
	 * @throws IllegalArgumentException
	 *             if <code>A</code> is not square, or if <code>omega</code>
	 *             is not strictly between 0 and 2 (including NaN)
	 */
	public SSOR(Matrix A, double omega) {
		int n = A.numRows();
		// The triangular sweeps only make sense for a square matrix; fail
		// early with a clear message instead of an obscure index error later
		if (n != A.numColumns())
			throw new IllegalArgumentException(
				"SSOR requires a square matrix, got " + n + "x" + A.numColumns());

		setOmega(omega);

		F = new FlexCompRowMatrix(n, n);
		diag = new double[n];
		diagInd = new int[n];

		setMatrix(A);
	}

	/**
	 * Constructor for SSOR. Uses <code>omega=1</code>
	 * 
	 * @param A
	 *            Matrix to create preconditioner for. Must be square. Not
	 *            modified
	 */
	public SSOR(Matrix A) {
		this(A, 1);
	}

	/**
	 * Sets the overrelaxation parameter
	 * 
	 * @param omega
	 *            Overrelaxation parameter. Strictly between 0 and 2.
	 * @throws IllegalArgumentException
	 *             if <code>omega</code> is not strictly between 0 and 2, or
	 *             is NaN
	 */
	public void setOmega(double omega) {
		// Written as a negated range test: every comparison involving NaN is
		// false, so this form rejects NaN as well
		if (!(omega > 0 && omega < 2))
			throw new IllegalArgumentException("omega must be between 0 and 2");

		this.omega = omega;
	}

	/**
	 * Copies the given matrix into this preconditioner and caches its
	 * diagonal entries together with their positions in each sparse row.
	 * 
	 * @param A
	 *            Matrix to create preconditioner for, with the same size as
	 *            the one given at construction. Not modified
	 * @throws IllegalArgumentException
	 *             if the size of <code>A</code> differs from the cached size
	 */
	public void setMatrix(Matrix A) {
		if (F.numRows() != A.numRows() || F.numColumns() != A.numColumns())
			throw new IllegalArgumentException("Cached size differs from matrix size");

		F.set(A);
		int n = F.numRows();

		// Store the diagonals and their positions within each sparse row
		for (int i = 0; i < n; ++i) {
			// Ensure an explicitly stored diagonal entry (even if it is zero),
			// so that the binary search below finds it
			if (F.get(i, i) == 0)
				F.set(i, i, 0.);

			SparseVector row = (SparseVector) F.getRow(i);
			int ind = smt.util.Arrays.binarySearch(row.getIndex(), i, 0, row.used());
			diag[i] = row.getData()[ind];
			diagInd[i] = ind;
		}
	}

	/**
	 * Applies the preconditioner, computing <code>x = M^{-1} b</code> by a
	 * forward sweep, a diagonal scaling and a backward sweep. Rows whose
	 * diagonal entry is zero are left unchanged by all three steps, to avoid
	 * dividing by zero.
	 * 
	 * @param b
	 *            Right-hand side
	 * @param x
	 *            Output vector; must be a <code>DenseVector</code>.
	 *            Overwritten with the result
	 * @return <code>x</code>
	 * @throws IllegalArgumentException
	 *             if <code>x</code> is not a <code>DenseVector</code>
	 */
	public Vector apply(Vector b, Vector x) {
		if (!(x instanceof DenseVector))
			throw new IllegalArgumentException("x must be a DenseVector");

		// Copy b over to x, and get a dense array representation
		x.set(b);
		double[] xd = ((DenseVector) x).getData();
		int n = F.numRows();

		// Forward sweep: solves (D/omega + L) y = b in place
		for (int i = 0; i < n; ++i) {
			// Skip rows with a zero diagonal to avoid dividing by zero
			if (diag[i] == 0)
				continue;

			SparseVector row = (SparseVector) F.getRow(i);
			double[] data = row.getData();
			int[] index = row.getIndex();

			// Positions before the diagonal hold the strictly lower part L
			double sum = 0;
			for (int j = 0; j < diagInd[i]; ++j)
				sum += data[j] * xd[index[j]];
			xd[i] = (omega / diag[i]) * (xd[i] - sum);
		}

		// Diagonal step: solves (omega/(2-omega)) D^{-1} z = y, i.e.
		// z = ((2-omega)/omega) D y
		double scale = (2 - omega) / omega;
		for (int i = 0; i < n; ++i) {
			// Rows with a zero diagonal are left as they are, consistent with
			// the two sweeps
			if (diag[i] != 0.)
				xd[i] = scale * diag[i] * xd[i];
		}

		// Backward sweep: solves (D/omega + U) x = z in place
		for (int i = n - 1; i >= 0; --i) {
			// Skip rows with a zero diagonal to avoid dividing by zero
			if (diag[i] == 0)
				continue;

			SparseVector row = (SparseVector) F.getRow(i);
			double[] data = row.getData();
			int[] index = row.getIndex();
			int length = row.used();

			// Positions after the diagonal hold the strictly upper part U
			double sum = 0;
			for (int j = diagInd[i] + 1; j < length; ++j)
				sum += data[j] * xd[index[j]];
			xd[i] = (omega / diag[i]) * (xd[i] - sum);
		}

		return x;
	}

	/**
	 * Applies the transposed preconditioner. For a symmetric matrix the SSOR
	 * preconditioner is symmetric, so this simply delegates to
	 * {@link #apply(Vector, Vector)}. For a nonsymmetric matrix the result is
	 * <code>M^{-1} b</code>, not <code>M^{-T} b</code>.
	 * 
	 * @param b
	 *            Right-hand side
	 * @param x
	 *            Output vector; must be a <code>DenseVector</code>
	 * @return <code>x</code>
	 */
	public Vector transApply(Vector b, Vector x) {
		return apply(b, x);
	}

}
