/*
 * Decompiled with CFR 0.152.
 */
package jitk.spline;

import jitk.spline.ThinPlateR2LogRSplineKernelTransform;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.NormOps_DDRM;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Iteratively refines an estimate {@code x} so that {@code xfm(x)} approaches a target point,
 * i.e. numerically inverts a {@link ThinPlateR2LogRSplineKernelTransform} at a single point.
 *
 * <p>The caller drives the iteration: it supplies the Jacobian ({@link #setJacobian}) and the
 * transformed estimate ({@link #setEstimateXfm}); this class only applies the transform itself
 * inside {@link #armijoCondition(double, double)}. Instances hold mutable working matrices and
 * are not designed for concurrent use.
 */
public class TransformInverseGradientDescent {
    /** Dimensionality of the points handled by this solver. */
    int ndims;
    /** Transform being inverted; only {@code apply} is used here (during the line search). */
    ThinPlateR2LogRSplineKernelTransform xfm;
    /** Jacobian supplied by the caller (presumably evaluated at the current estimate). */
    DMatrixRMaj jacobian;
    /** {@code jacobian * dir}, refreshed by the computeDirection variants. */
    DMatrixRMaj directionalDeriv;
    /** 1x1 matrix holding {@code dir^T * jacobian * dir}; refreshed only by {@link #computeDirection()}. */
    DMatrixRMaj descentDirectionMag;
    /** Current search direction (ndims x 1). */
    DMatrixRMaj dir;
    /** Residual {@code target - estimateXfm} (ndims x 1). */
    DMatrixRMaj errorV;
    /** Current estimate of the preimage (ndims x 1). */
    DMatrixRMaj estimate;
    /** Transformed estimate supplied by the caller (ndims x 1). */
    DMatrixRMaj estimateXfm;
    /** Point whose preimage is sought (ndims x 1). */
    DMatrixRMaj target;
    /** Largest absolute residual component; a large sentinel until {@link #updateError()} runs. */
    double error = 9999.0;
    /** Step size used by {@link #oneIteration(boolean)}. */
    double stepSz = 1.0;
    // NOTE(review): maxIters, eps and beta are not read anywhere in this class (the beta parameter
    // of backtrackingLineSearch shadows the field); presumably package-level callers read them.
    int maxIters = 20;
    double eps = 1.0E-6;
    double beta = 0.7;
    protected static Logger logger = LoggerFactory.getLogger(TransformInverseGradientDescent.class);

    /**
     * Creates a solver for {@code ndims}-dimensional points.
     *
     * @param ndims number of dimensions
     * @param xfm the transform to invert
     */
    public TransformInverseGradientDescent(int ndims, ThinPlateR2LogRSplineKernelTransform xfm) {
        this.ndims = ndims;
        this.xfm = xfm;
        this.dir = new DMatrixRMaj(ndims, 1);
        this.errorV = new DMatrixRMaj(ndims, 1);
        this.directionalDeriv = new DMatrixRMaj(ndims, 1);
        this.descentDirectionMag = new DMatrixRMaj(1, 1);
    }

    /** Sets the {@code eps} field (not read within this class). */
    public void setEps(double eps) {
        this.eps = eps;
    }

    /** Sets the step size used by {@link #oneIteration(boolean)}. */
    public void setStepSize(double stepSize) {
        this.stepSz = stepSize;
    }

    /**
     * Sets the Jacobian used by the direction computations.
     *
     * @param mtx Jacobian entries, row-major (copied into a new matrix)
     */
    public void setJacobian(double[][] mtx) {
        this.jacobian = new DMatrixRMaj(mtx);
        logger.trace("setJacobian:\n{}", jacobian);
    }

    /**
     * Sets the target point.
     *
     * <p>NOTE(review): EJML's {@code setData(double[])} is believed to adopt the array by
     * reference rather than copy it, so later changes to {@code tgt} may be visible here; confirm.
     */
    public void setTarget(double[] tgt) {
        this.target = new DMatrixRMaj(ndims, 1);
        this.target.setData(tgt);
    }

    /** @return the residual vector {@code target - estimateXfm} (live reference) */
    public DMatrixRMaj getErrorVector() {
        return errorV;
    }

    /** @return the current search direction (live reference) */
    public DMatrixRMaj getDirection() {
        return dir;
    }

    /** @return the current Jacobian (live reference, may be null) */
    public DMatrixRMaj getJacobian() {
        return jacobian;
    }

    /**
     * Sets the current estimate. {@link #updateEstimate(double)} modifies this matrix in place;
     * see the aliasing note on {@link #setTarget(double[])}.
     */
    public void setEstimate(double[] est) {
        this.estimate = new DMatrixRMaj(ndims, 1);
        this.estimate.setData(est);
    }

    /** Sets the transformed estimate and recomputes the residual via {@link #updateError()}. */
    public void setEstimateXfm(double[] est) {
        this.estimateXfm = new DMatrixRMaj(ndims, 1);
        this.estimateXfm.setData(est);
        updateError();
    }

    /** @return the current estimate (live reference) */
    public DMatrixRMaj getEstimate() {
        return estimate;
    }

    /** @return the largest absolute residual component computed by the last {@link #updateError()} */
    public double getError() {
        return error;
    }

    /** Runs one iteration with the configured step size and refreshes the error. */
    public void oneIteration() {
        oneIteration(true);
    }

    /**
     * Computes a direction, steps along it by {@code stepSz}, and optionally refreshes the error.
     * Note that {@code estimateXfm} is not recomputed here; the caller must supply it.
     *
     * @param updateError whether to call {@link #updateError()} after the step
     */
    public void oneIteration(boolean updateError) {
        computeDirection();
        updateEstimate(stepSz);
        if (updateError) {
            updateError();
        }
    }

    /**
     * Computes a unit-length gradient-style direction into {@code dir}.
     *
     * <p>Forms {@code 2 * J^T * (J * estimate - errorV)}, normalizes it, stores {@code J * dir}
     * in {@code directionalDeriv}, then negates {@code dir}. Does not refresh
     * {@code descentDirectionMag}.
     */
    public void computeDirectionSteepest() {
        DMatrixRMaj tmp = new DMatrixRMaj(ndims, 1);
        logger.trace("\nerrorV:\n{}", errorV);
        // tmp = J * estimate - errorV
        CommonOps_DDRM.mult(jacobian, estimate, tmp);
        CommonOps_DDRM.subtractEquals(tmp, errorV);
        // dir = 2 * J^T * tmp
        CommonOps_DDRM.multTransA(2.0, jacobian, tmp, dir);
        normalizeDirection();
        CommonOps_DDRM.mult(jacobian, dir, directionalDeriv);
        // Negate so that dir points downhill.
        CommonOps_DDRM.scale(-1.0, dir);
    }

    /**
     * Computes a unit-length Newton-style direction: solves {@code J * dir = errorV}, normalizes
     * {@code dir}, then stores {@code J * dir} in {@code directionalDeriv} and
     * {@code dir^T * J * dir} in {@code descentDirectionMag}.
     */
    public void computeDirection() {
        if (!CommonOps_DDRM.solve(jacobian, errorV, dir)) {
            // The solver reports failure (e.g. singular Jacobian); dir may not be meaningful.
            logger.warn("computeDirection: linear solve failed; search direction may be invalid");
        }
        normalizeDirection();
        CommonOps_DDRM.mult(jacobian, dir, directionalDeriv);
        CommonOps_DDRM.multTransA(dir, directionalDeriv, descentDirectionMag);
        logger.debug("descentDirectionMag: {}", descentDirectionMag.get(0));
    }

    /**
     * Scales {@code dir} to unit Euclidean length. A zero (or NaN) norm leaves {@code dir}
     * unchanged instead of producing NaN/Inf entries.
     */
    private void normalizeDirection() {
        double norm = NormOps_DDRM.normP2(dir);
        if (norm > 0.0) {
            // divide(matrix, scalar) computes dir_i / norm; the divide(scalar, matrix) overload
            // would compute norm / dir_i instead.
            CommonOps_DDRM.divide(dir, norm);
        }
    }

    /**
     * Backtracking line search: starting from {@code t0}, multiplies the step by {@code beta}
     * until {@link #armijoCondition(double, double)} holds or {@code maxtries} shrinks are done.
     *
     * @param c sufficient-decrease constant passed to the Armijo test
     * @param beta shrink factor per failed try (shadows the {@code beta} field)
     * @param maxtries maximum number of shrinks
     * @param t0 initial step size
     * @return the accepted step, or {@code t0 * beta^maxtries} if no tested step was accepted
     */
    public double backtrackingLineSearch(double c, double beta, int maxtries, double t0) {
        int k;
        double t = t0;
        for (k = 0; k < maxtries && !armijoCondition(c, t); ++k) {
            t *= beta;
        }
        logger.trace("selected step size after {} tries", k);
        return t;
    }

    /**
     * Tests {@code f(x + t*dir) < f(x) + c*t*m}, where {@code f} is the squared distance of the
     * transformed point to the target and {@code m = 2*||target - estimateXfm||^2 *
     * descentDirectionMag}.
     *
     * <p>NOTE(review): {@code m} is presumably meant to be the directional derivative of
     * {@code f}; as written it is not the textbook Armijo slope, and {@code descentDirectionMag}
     * is only refreshed by {@link #computeDirection()}, not by {@link #computeDirectionSteepest()}.
     *
     * @param c sufficient-decrease constant
     * @param t candidate step size
     * @return whether the candidate step is accepted
     */
    public boolean armijoCondition(double c, double t) {
        double[] d = dir.data;
        double[] x = estimate.data;
        double[] xAp = new double[ndims];
        for (int i = 0; i < ndims; ++i) {
            xAp[i] = x[i] + t * d[i];
        }
        double[] phix = estimateXfm.data;
        double[] phixAp = xfm.apply(xAp);
        double fx = squaredError(phix);
        double fxAp = squaredError(phixAp);
        double m = sumSquaredErrorsDeriv(target.data, phix) * descentDirectionMag.get(0);
        logger.trace("   f( x )     : {}", fx);
        logger.trace("   f( x + ap ): {}", fxAp);
        logger.trace("   f( x ) + c * m * t: {}", fx + c * t * m);
        return fxAp < fx + c * t * m;
    }

    /**
     * @param x point in target space (first {@code ndims} entries are used)
     * @return squared Euclidean distance from {@code x} to the target
     */
    public double squaredError(double[] x) {
        double sum = 0.0;
        for (int i = 0; i < ndims; ++i) {
            double diff = x[i] - target.get(i);
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Moves the estimate in place: {@code estimate += stepSize * dir}.
     *
     * @param stepSize multiplier applied to the current direction
     */
    public void updateEstimate(double stepSize) {
        logger.trace("step size: {}", stepSize);
        logger.trace("estimate:\n{}", estimate);
        CommonOps_DDRM.addEquals(estimate, stepSize, dir);
        logger.trace("new estimate:\n{}", estimate);
    }

    /**
     * Moves the estimate by {@code dir}, clipping the move to length {@code stepSize}.
     *
     * <p>NOTE(review): when the norm exceeds {@code stepSize}, {@code dir} is scaled by a
     * <em>negative</em> factor (reversed); otherwise it is added unchanged. This sign asymmetry
     * is preserved from the original; confirm it is intended. This method also rescales
     * {@code dir} in place.
     *
     * @param stepSize maximum length of the move
     */
    public void updateEstimateNormBased(double stepSize) {
        logger.debug("step size: {}", stepSize);
        logger.trace("estimate:\n{}", estimate);
        double norm = NormOps_DDRM.normP2(dir);
        logger.debug("norm: {}", norm);
        if (norm > stepSize) {
            CommonOps_DDRM.scale(-stepSize / norm, dir);
        }
        CommonOps_DDRM.addEquals(estimate, dir);
        logger.trace("new estimate:\n{}", estimate);
    }

    /**
     * Recomputes {@code errorV = target - estimateXfm} and sets {@code error} to its largest
     * absolute component. Logs a warning and returns without changes if the target, the estimate
     * or the transformed estimate has not been set.
     */
    public void updateError() {
        if (estimate == null || estimateXfm == null || target == null) {
            logger.warn("Call to updateError with null target, estimate, or estimateXfm");
            return;
        }
        CommonOps_DDRM.subtract(target, estimateXfm, errorV);
        logger.trace("#########################");
        logger.trace("updateError, estimate   :\n{}", estimate);
        logger.trace("updateError, estimateXfm:\n{}", estimateXfm);
        logger.trace("updateError, target     :\n{}", target);
        logger.trace("updateError, error      :\n{}", errorV);
        logger.trace("#########################");
        error = Math.abs(errorV.get(0));
        for (int i = 1; i < ndims; ++i) {
            double component = Math.abs(errorV.get(i));
            if (component > error) {
                error = component;
            }
        }
    }

    /**
     * Returns {@code 2 * sum((y_i - x_i)^2)} over the first {@code ndims} entries. Despite the
     * name, this is twice the sum of squared differences, not a gradient.
     */
    private double sumSquaredErrorsDeriv(double[] y, double[] x) {
        double errDeriv = 0.0;
        for (int i = 0; i < ndims; ++i) {
            double diff = y[i] - x[i];
            errDeriv += diff * diff;
        }
        return 2.0 * errDeriv;
    }

    /**
     * @param y first vector; its length sets how many entries are compared
     * @param x second vector, at least as long as {@code y}
     * @return {@code sum((y_i - x_i)^2)}
     */
    public static double sumSquaredErrors(double[] y, double[] x) {
        double err = 0.0;
        for (int i = 0; i < y.length; ++i) {
            double diff = y[i] - x[i];
            err += diff * diff;
        }
        return err;
    }

    /**
     * Copies all elements of {@code vec} (in its internal storage order) into the start of
     * {@code array}.
     */
    public static void copyVectorIntoArray(DMatrixRMaj vec, double[] array) {
        System.arraycopy(vec.data, 0, array, 0, vec.getNumElements());
    }
}

