/**
 *
 * @file c_refine_grad.c
 *
 * PaStiX refinement functions implementations.
 *
 * @copyright 2015-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
 *                      Univ. Bordeaux. All rights reserved.
 *
 * @version 6.4.0
 * @author Mathieu Faverge
 * @author Pierre Ramet
 * @author Xavier Lacoste
 * @author Theophile Terraz
 * @author Gregoire Pichon
 * @author Vincent Bridonneau
 * @date 2024-07-05
 * @generated from /build/pastix/src/pastix-6.4.0/refinement/z_refine_grad.c, normal z -> c, Thu Oct 23 06:51:46 2025
 *
 **/
#include "common.h"
#include "bcsc/bcsc.h"
#include "c_refine_functions.h"

/**
 *******************************************************************************
 *
 * @ingroup pastix_refine
 *
 * c_grad_smp - Refine the solution using the conjugate gradient method.
 *
 *******************************************************************************
 *
 * @param[in] pastix_data
 *          The PaStiX data structure that describes the solver instance.
 *
 * @param[inout] xp
 *          On entry, the initial guess of the solution.
 *          On exit, the refined solution vector.
 *
 * @param[in] bp
 *          The right hand side member (only one).
 *
 *******************************************************************************
 *
 * @return Number of iterations
 *
 *******************************************************************************/
pastix_int_t
c_grad_smp( pastix_data_t *pastix_data,
            pastix_rhs_t   xp,
            pastix_rhs_t   bp )
{
    struct c_solver     solver;
    Clock               timer;
    pastix_fixdbl_t     t_start = 0;
    pastix_fixdbl_t     t_stop  = 0;
    int                 iter    = 0;
    pastix_complex32_t *x       = (pastix_complex32_t *)(xp->b);
    pastix_complex32_t *b       = (pastix_complex32_t *)(bp->b);
    pastix_complex32_t *work    = NULL;
    float               normr;

    /* Zero the solver structure before handing it to c_refine_init() */
    memset( &solver, 0, sizeof solver );
    c_refine_init( &solver, pastix_data );

    /* The spsv step is applied only when STEP_NUMFACT is flagged in steps */
    const int          precond = ( pastix_data->steps & STEP_NUMFACT ) ? 1 : 0;
    const pastix_int_t n       = pastix_data->bcsc->n;
    const int          itermax = pastix_data->iparm[IPARM_ITERMAX];
    const float        eps     = pastix_data->dparm[DPARM_EPSILON_REFINEMENT];

    /* Work vectors: residual r, direction p, z = M^{-1} r, and q = A * p */
    pastix_complex32_t *r = solver.malloc( n * sizeof *r );
    pastix_complex32_t *p = solver.malloc( n * sizeof *p );
    pastix_complex32_t *z = solver.malloc( n * sizeof *z );
    pastix_complex32_t *q = solver.malloc( n * sizeof *q );

    /* Extra buffer handed to spsv: allocated only if IPARM_MIXED is set, NULL otherwise */
    if ( pastix_data->iparm[IPARM_MIXED] ) {
        work = solver.malloc( n * sizeof *work );
    }

    clockInit(timer);
    clockStart(timer);

    /* A zero right-hand-side norm is replaced by 1 so the relative residual stays defined */
    float normb = solver.norm( pastix_data, n, b );
    if ( normb == 0.0 ) {
        normb = 1;
    }
    const float normx = solver.norm( pastix_data, n, x );

    /* r0 = b - A * x; the product is skipped when the norm of x is not positive */
    solver.copy( pastix_data, n, b, r );
    if ( normx > 0.0 ) {
        solver.spmv( pastix_data, PastixNoTrans, -1.0, x, 1.0, r );
    }
    normr = solver.norm( pastix_data, n, r );
    float resid_b = normr / normb;

    /* z0 = M^{-1} r0 */
    solver.copy( pastix_data, n, r, z );
    if ( precond ) {
        solver.spsv( pastix_data, z, work );
    }

    /* p0 = z0 */
    solver.copy( pastix_data, n, z, p );

    /* Iterate until the relative residual reaches eps or itermax is hit */
    while ( ( iter < itermax ) && ( resid_b > eps ) )
    {
        clockStop(timer);
        t_start = clockGet();
        iter++;

        /* q = A * p */
        solver.spmv( pastix_data, PastixNoTrans, 1.0, p, 0.0, q );

        /* alpha = <r, z> / <A p, p> */
        const float rz_old = solver.dot( pastix_data, n, r, z );
        const float qp     = solver.dot( pastix_data, n, q, p );
        const float alpha  = rz_old / qp;

        /* x = x + alpha * p */
        solver.axpy( pastix_data, n, alpha, p, x );

        /* r = r - alpha * A * p */
        solver.axpy( pastix_data, n, -alpha, q, r );

        /* z = M^{-1} r */
        solver.copy( pastix_data, n, r, z );
        if ( precond ) {
            solver.spsv( pastix_data, z, work );
        }

        /* beta = <r_new, z_new> / <r_old, z_old> */
        const float rz_new = solver.dot( pastix_data, n, r, z );
        const float beta   = rz_new / rz_old;

        /* p = z + beta * p */
        solver.scal( pastix_data, n, beta, p );
        solver.axpy( pastix_data, n, 1.0, z, p );

        /* Relative residual used by the stopping test */
        normr   = solver.norm( pastix_data, n, r );
        resid_b = normr / normb;

        clockStop(timer);
        t_stop = clockGet();

        /* Per-iteration report: only on process 0 and only when verbosity is enabled */
        if ( pastix_data->procnum == 0 ) {
            if ( pastix_data->iparm[IPARM_VERBOSE] > PastixVerboseNot ) {
                solver.output_oneiter( t_start, t_stop, resid_b, iter );
            }
        }
        t_start = t_stop;
    }

    solver.output_final( pastix_data, resid_b, iter, t_stop, x, x );

    solver.free( r );
    solver.free( p );
    solver.free( z );
    solver.free( q );
    solver.free( work );

    return iter;
}
