/*
 * Copyright (C) 2003, 2004 Bjørn-Ove Heimsund
 * 
 * This file is part of SMT.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package smt.iter.mixed;

import mt.Matrix;
import mt.Vector;
import smt.iter.IterativeSolverNotConvergedException;

/**
 * Uzawa algorithm using the CG method.
 * <p>
 * Solves the mixed system
 * 
 * <pre>
 *   A q + Bt u = f
 *   B q + C  u = g
 * </pre>
 * 
 * by eliminating <code>q</code> and applying conjugate gradients to the
 * reduced system for <code>u</code>. Each CG step needs one inner solve with
 * <code>A</code>, which is delegated to the inherited <code>solver</code>.
 * <p>
 * NOTE(review): CG is only guaranteed to work when <code>B</code> is the
 * transpose of <code>Bt</code> and the reduced operator
 * <code>B A^{-1} Bt - C</code> is symmetric positive definite. Neither
 * condition is checked here.
 */
public class CGUzawa extends AbstractMixedSolver {

	/**
	 * Temporary work vectors of size <i>m</i> (same as <code>q</code>):
	 * <code>ur</code> holds the right-hand side of the initial
	 * <code>A</code>-solve, <code>p = Bt*d</code> and <code>h = A\p</code>.
	 */
	private Vector ur, p, h;

	/**
	 * Temporary work vectors of size <i>n</i> (same as <code>u</code>):
	 * <code>r</code> is the residual of the second equation, <code>d</code>
	 * the search direction, and <code>t</code> scratch space.
	 */
	private Vector r, d, t;

	/**
	 * Constructor for CGUzawa
	 * 
	 * @param qTemplate
	 *            Vector of size <i>m</i>, used as template in preallocating
	 *            work vectors
	 * @param uTemplate
	 *            Vector of size <i>n</i>, used as template in preallocating
	 *            work vectors
	 */
	public CGUzawa(Vector qTemplate, Vector uTemplate) {
		super(qTemplate);

		ur = qTemplate.copy();
		p = qTemplate.copy();
		h = qTemplate.copy();

		r = uTemplate.copy();
		d = uTemplate.copy();
		t = uTemplate.copy();
	}

	/**
	 * Runs the CG-Uzawa iteration.
	 * <p>
	 * On entry <code>u</code> holds the initial guess. On exit <code>q</code>
	 * and <code>u</code> hold the computed solution. The iteration stops when
	 * the inherited <code>iter</code> reports convergence on the residual
	 * <code>r = g - B*q - C*u</code>.
	 * 
	 * @throws IterativeSolverNotConvergedException
	 *             propagated from the inner solver or the iteration monitor
	 */
	protected void solveI(
		Matrix A,
		Matrix B,
		Matrix Bt,
		Matrix C,
		Vector q,
		Vector u,
		Vector f,
		Vector g)
		throws IterativeSolverNotConvergedException {

		// Satisfy the first equation exactly for the initial u:
		// q = A \ (f - Bt*u)
		Bt.multAdd(-1., u, f, ur);
		solver.solve(A, ur, q);

		// Residual of the second equation: r = g - B*q - C*u
		C.multAdd(-1., u, g, t);
		B.multAdd(-1., q, t, r);
		d.set(-1., r); // d = -r

		double rr = r.dot(r);

		for (iter.setFirst(); !iter.converged(r); iter.next()) {
			Bt.mult(d, p); // p = Bt*d
			solver.solve(A, p, h); // h = A\p

			// Moving u by alpha*d (and q by -alpha*h, keeping the first
			// equation satisfied) changes r by alpha*(B*h - C*d). The exact
			// CG step therefore divides by d'*(B*A^{-1}*Bt - C)*d, which is
			// p'*h - d'*C*d (using B = Bt'). The C term must not be omitted.
			C.mult(d, t); // t = C*d
			double alpha = rr / (p.dot(h) - d.dot(t));

			u.add(alpha, d); // u = u + alpha*d
			q.add(-alpha, h); // q = q - alpha*h

			// Recompute the residual from scratch to avoid accumulating
			// update drift: r = g - B*q - C*u
			C.multAdd(-1., u, g, t);
			B.multAdd(-1., q, t, r);

			double rrNext = r.dot(r);
			double beta = rrNext / rr;
			rr = rrNext;

			d.add(-1., r, beta); // d = -r + beta*d
		}
	}

}
