/*
 * Decompiled with CFR 0.152.
 */
package ch.epfl.biop.sourceandconverter.register;

import bdv.viewer.Source;
import bdv.viewer.SourceAndConverter;
import ch.epfl.biop.sourceandconverter.exporter.CZTRange;
import ch.epfl.biop.sourceandconverter.exporter.ImagePlusGetter;
import ij.IJ;
import ij.ImagePlus;
import java.awt.geom.AffineTransform;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import mpicbg.ij.FeatureTransform;
import mpicbg.ij.SIFT;
import mpicbg.imagefeatures.Feature;
import mpicbg.imagefeatures.FloatArray2DSIFT;
import mpicbg.models.AffineModel2D;
import mpicbg.models.CoordinateTransform;
import mpicbg.models.NotEnoughDataPointsException;
import mpicbg.models.PointMatch;
import net.imglib2.FinalRealInterval;
import net.imglib2.RealInterval;
import net.imglib2.RealLocalizable;
import net.imglib2.RealPoint;
import net.imglib2.RealPositionable;
import net.imglib2.realtransform.AffineTransform3D;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.NumericType;
import sc.fiji.bdvpg.sourceandconverter.importer.EmptySourceAndConverterCreator;
import sc.fiji.bdvpg.sourceandconverter.transform.SourceAffineTransformer;
import sc.fiji.bdvpg.sourceandconverter.transform.SourceResampler;

/**
 * Registers a set of "moving" sources onto a set of "fixed" sources using SIFT features
 * extracted from a single 2D patch of both.
 *
 * <p>Both source sets are resampled onto the same patch. The patch spans
 * {@code [px, px + sx] x [py, py + sy]}, is one voxel thick starting at {@code z = pz}, and has
 * an isotropic voxel size of {@code pxSizeInCurrentUnit}. SIFT features are extracted from
 * each patch and matched, and an {@link AffineModel2D} is fitted to the matches with RANSAC.
 * The 2D model is then lifted to an {@link AffineTransform3D}. You can read it with
 * {@link #getAffineTransform()} or apply it to the moving sources with
 * {@link #getRegisteredSacs()}.
 *
 * <p>Typical use:
 * <ol>
 *   <li>construct the registration;</li>
 *   <li>optionally call {@link #setInterpolate(boolean)};</li>
 *   <li>call {@link #run()} and check its return value;</li>
 *   <li>on failure, read {@link #getErrorMessage()}.</li>
 * </ol>
 *
 * @param <FT> pixel type of the fixed sources
 * @param <MT> pixel type of the moving sources
 */
public class SIFTRegister<FT extends NativeType<FT> & NumericType<FT>, MT extends NativeType<MT> & NumericType<MT>> {

    /** Number of RANSAC iterations used when fitting the 2D affine model. */
    private static final int RANSAC_ITERATIONS = 1000;

    /** Fixed (reference) sources; the first one provides the number of mipmap levels. */
    SourceAndConverter<FT>[] sacs_fixed;
    /** Moving sources; the first one provides the number of mipmap levels and the source transform. */
    SourceAndConverter<MT>[] sacs_moving;
    /** Requested mipmap level of the fixed sources; {@link #run()} clamps it to the valid range. */
    int levelMipmapFixed;
    /** Requested mipmap level of the moving sources; {@link #run()} clamps it to the valid range. */
    int levelMipmapMoving;
    /** Whether the fixed patch is inverted before feature extraction. */
    final boolean invertFixed;
    /** Whether the moving patch is inverted before feature extraction. */
    final boolean invertMoving;
    /** Timepoint sampled from the moving sources. */
    int tpMoving;
    /** Timepoint sampled from the fixed sources. */
    int tpFixed;
    /** Transform computed by the last successful {@link #run()}; {@code null} otherwise. */
    AffineTransform3D affineTransformOut;
    /** Corner of the registration patch, in global coordinates. */
    double px;
    double py;
    double pz;
    /** Width (x) and height (y) of the registration patch, in global coordinates. */
    double sx;
    double sy;
    /** Voxel size used to resample the patch. */
    double pxSizeInCurrentUnit;
    /** Whether resampling interpolates; see {@link #setInterpolate(boolean)}. */
    boolean interpolate = false;
    /** Reason why the last {@link #run()} failed; empty when it succeeded or was never called. */
    String errorMessage = "";
    /** SIFT feature extraction parameters. */
    final FloatArray2DSIFT.Param paramSift;
    /** Ratio passed to {@link FeatureTransform#matchFeatures} as the {@code rod} argument. */
    final float rod;
    /** RANSAC maximal alignment error. */
    final double maxEpsilon;
    /** RANSAC minimal inlier ratio. */
    final double minInlierRatio;
    /** Minimal number of inliers required to accept the model. */
    final int minNumInliers;

    /**
     * Creates a registration between {@code sacs_fixed} and {@code sacs_moving}.
     *
     * @param sacs_fixed fixed sources (at least one is required by {@link #run()})
     * @param levelMipmapFixed requested mipmap level of the fixed sources
     * @param tpFixed timepoint of the fixed sources
     * @param invertFixed invert the fixed patch before feature extraction
     * @param sacs_moving moving sources (at least one is required by {@link #run()})
     * @param levelMipmapMoving requested mipmap level of the moving sources
     * @param tpMoving timepoint of the moving sources
     * @param invertMoving invert the moving patch before feature extraction
     * @param pxSizeInCurrentUnit voxel size of the resampled patch
     * @param px x of the patch corner
     * @param py y of the patch corner
     * @param pz z of the patch corner
     * @param sx patch width
     * @param sy patch height
     * @param paramSift SIFT parameters
     * @param rod ratio passed to {@link FeatureTransform#matchFeatures}
     * @param maxEpsilon RANSAC maximal alignment error
     * @param minInlierRatio RANSAC minimal inlier ratio
     * @param minNumInliers minimal number of inliers to accept the model
     */
    public SIFTRegister(SourceAndConverter<FT>[] sacs_fixed, int levelMipmapFixed, int tpFixed, boolean invertFixed, SourceAndConverter<MT>[] sacs_moving, int levelMipmapMoving, int tpMoving, boolean invertMoving, double pxSizeInCurrentUnit, double px, double py, double pz, double sx, double sy, FloatArray2DSIFT.Param paramSift, float rod, double maxEpsilon, double minInlierRatio, int minNumInliers) {
        this.sacs_fixed = sacs_fixed;
        this.sacs_moving = sacs_moving;
        this.pxSizeInCurrentUnit = pxSizeInCurrentUnit;
        this.px = px;
        this.py = py;
        this.pz = pz;
        this.sx = sx;
        this.sy = sy;
        this.levelMipmapFixed = levelMipmapFixed;
        this.levelMipmapMoving = levelMipmapMoving;
        this.tpFixed = tpFixed;
        this.tpMoving = tpMoving;
        this.paramSift = paramSift;
        this.rod = rod;
        this.maxEpsilon = maxEpsilon;
        this.minNumInliers = minNumInliers;
        this.minInlierRatio = minInlierRatio;
        this.invertFixed = invertFixed;
        this.invertMoving = invertMoving;
    }

    /**
     * Sets whether the sources are interpolated when they are resampled onto the patch.
     *
     * @param interpolate {@code true} to interpolate
     */
    public void setInterpolate(boolean interpolate) {
        this.interpolate = interpolate;
    }

    /**
     * Runs the registration.
     *
     * <p>On success, the result is available through {@link #getAffineTransform()}. On
     * failure, the reason is logged and stored in {@link #getErrorMessage()}, and the
     * transform is reset to {@code null}.
     *
     * @return {@code true} if a model was found
     */
    public boolean run() {
        this.errorMessage = "";
        this.affineTransformOut = null;
        if (isEmpty(this.sacs_fixed) || isEmpty(this.sacs_moving)) {
            return fail("Fixed and moving sources must both contain at least one source.");
        }
        this.levelMipmapFixed = clampMipmapLevel(this.levelMipmapFixed, this.sacs_fixed[0]);
        this.levelMipmapMoving = clampMipmapLevel(this.levelMipmapMoving, this.sacs_moving[0]);

        ImagePlus croppedMoving = getCroppedImage("Moving", this.sacs_moving, this.tpMoving, this.levelMipmapMoving);
        ImagePlus croppedFixed = getCroppedImage("Fixed", this.sacs_fixed, this.tpFixed, this.levelMipmapFixed);

        AffineModel2D model = fitPatchModel(croppedFixed, croppedMoving);
        if (model == null) {
            return false; // the reason has already been logged and recorded
        }

        AffineTransform3D atMoving = new AffineTransform3D();
        this.sacs_moving[0].getSpimSource().getSourceTransform(this.tpMoving, this.levelMipmapMoving, atMoving);
        this.affineTransformOut = computeOutputTransform(atMoving, convertToAffineTransform3D(model.createAffine()));
        return true;
    }

    /** Logs {@code message}, records it as the error message and returns {@code false}. */
    private boolean fail(String message) {
        IJ.log(message);
        this.errorMessage = message;
        return false;
    }

    /** Returns {@code true} when {@code array} is {@code null} or has no element. */
    private static boolean isEmpty(Object[] array) {
        return array == null || array.length == 0;
    }

    /** Clamps {@code level} to {@code [0, number of mipmap levels of sac - 1]}. */
    private static int clampMipmapLevel(int level, SourceAndConverter<?> sac) {
        int maxLevel = sac.getSpimSource().getNumMipmapLevels() - 1;
        return Math.max(0, Math.min(level, maxLevel));
    }

    /**
     * Extracts SIFT features from both patches, matches them and fits an affine model with RANSAC.
     *
     * <p>The patches are inverted in place first when requested. Matches are built with the
     * fixed features as the first argument. The fitted model is therefore presumably mapping
     * fixed patch pixels to moving patch pixels — TODO confirm against mpicbg.
     *
     * @param croppedFixed fixed patch
     * @param croppedMoving moving patch
     * @return the fitted model, or {@code null} when no acceptable model was found
     *         (the reason is then logged and recorded)
     */
    private AffineModel2D fitPatchModel(ImagePlus croppedFixed, ImagePlus croppedMoving) {
        if (this.invertFixed) {
            croppedFixed.getProcessor().invert();
        }
        if (this.invertMoving) {
            croppedMoving.getProcessor().invert();
        }

        SIFT ijSIFT = new SIFT(new FloatArray2DSIFT(this.paramSift));
        List<Feature> fixedFeatures = new ArrayList<>();
        List<Feature> movingFeatures = new ArrayList<>();
        ijSIFT.extractFeatures(croppedFixed.getProcessor(), fixedFeatures);
        IJ.log(fixedFeatures.size() + " features extracted for fixed image");
        ijSIFT.extractFeatures(croppedMoving.getProcessor(), movingFeatures);
        IJ.log(movingFeatures.size() + " features extracted for moving image");

        IJ.log("Identifying correspondence candidates using brute force ...");
        List<PointMatch> candidates = new ArrayList<>();
        FeatureTransform.matchFeatures(fixedFeatures, movingFeatures, candidates, this.rod);
        IJ.log(candidates.size() + " potentially corresponding features identified.");

        IJ.log("Filtering correspondence candidates by geometric consensus ...");
        List<PointMatch> inliers = new ArrayList<>();
        AffineModel2D model = new AffineModel2D();
        boolean modelFound;
        try {
            modelFound = model.filterRansac(candidates, inliers, RANSAC_ITERATIONS, this.maxEpsilon, this.minInlierRatio, this.minNumInliers);
        } catch (NotEnoughDataPointsException e) {
            // Too few candidates to even estimate an affine model.
            fail("No correspondences found.");
            return null;
        }
        // filterRansac signals a rejected model through its return value. Without this check, an
        // unfitted model could be accepted whenever minNumInliers <= 0.
        if (!modelFound || inliers.size() < this.minNumInliers) {
            fail("Not enough points found.");
            return null;
        }
        PointMatch.apply(inliers, model);
        IJ.log(inliers.size() + " corresponding features with an average displacement of " + PointMatch.meanDistance(inliers) + "px identified.");
        IJ.log("Estimated transformation model: " + model);
        return model;
    }

    /**
     * Lifts the patch-pixel model to a transform applicable to the moving sources.
     *
     * <p>Notation:
     * <ul>
     *   <li>{@code N}: the patch-pixel-to-global transform (scale by
     *       {@code pxSizeInCurrentUnit}, then shift by {@code (px, py, pz)});</li>
     *   <li>{@code R}: {@code patchModel};</li>
     *   <li>{@code S}: {@code atMoving}.</li>
     * </ul>
     *
     * <p>This builds {@code A = N·R·N⁻¹·S} column by column and returns {@code S·A⁻¹}. The
     * column-wise build relies on the linear part of {@code N} being a uniform scale.
     * Algebraically the result reduces to {@code N·R⁻¹·N⁻¹}. The explicit construction is kept
     * so that numerical results stay identical to previous versions.
     *
     * @param atMoving source transform of the first moving source (not modified)
     * @param patchModel 2D model lifted to 3D, in patch pixel coordinates (not modified)
     * @return the transform applied to the moving sources
     */
    private AffineTransform3D computeOutputTransform(AffineTransform3D atMoving, AffineTransform3D patchModel) {
        AffineTransform3D patchPixToGlobal = new AffineTransform3D();
        patchPixToGlobal.scale(this.pxSizeInCurrentUnit);
        patchPixToGlobal.translate(this.px, this.py, this.pz);
        AffineTransform3D globalToPatchPix = patchPixToGlobal.inverse();

        RealPoint uNorm = getMatrixAxis(globalToPatchPix, 0);
        RealPoint vNorm = getMatrixAxis(globalToPatchPix, 1);
        RealPoint wNorm = getMatrixAxis(globalToPatchPix, 2);

        // Columns 0..2: the axes of the moving source transform, expressed in patch pixel units.
        RealPoint[] columns = new RealPoint[4];
        for (int axis = 0; axis < 3; ++axis) {
            RealPoint sourceAxis = getMatrixAxis(atMoving, axis);
            columns[axis] = new RealPoint(prodScal(sourceAxis, uNorm), prodScal(sourceAxis, vNorm), prodScal(sourceAxis, wNorm));
        }
        // Column 3: the moving source origin, mapped into patch pixel coordinates.
        RealPoint origin = new RealPoint(atMoving.get(0, 3), atMoving.get(1, 3), atMoving.get(2, 3));
        globalToPatchPix.apply(origin, origin);
        columns[3] = origin;

        // Registered patch pixel -> global.
        AffineTransform3D registeredPatchPixToGlobal = patchPixToGlobal.copy().concatenate(patchModel);
        // Axis vectors are mapped by the linear part only; the origin by the full transform.
        AffineTransform3D registeredLinearPart = registeredPatchPixToGlobal.copy();
        for (int row = 0; row < 3; ++row) {
            registeredLinearPart.set(0.0, row, 3);
        }
        for (int axis = 0; axis < 3; ++axis) {
            registeredLinearPart.apply(columns[axis], columns[axis]);
        }
        registeredPatchPixToGlobal.apply(columns[3], columns[3]);

        // Assemble the 3x4 row-packed matrix whose columns are the four points above.
        double[] rowPacked = new double[12];
        for (int col = 0; col < 4; ++col) {
            for (int row = 0; row < 3; ++row) {
                rowPacked[4 * row + col] = columns[col].getDoublePosition(row);
            }
        }
        AffineTransform3D adjustedMovingTransform = new AffineTransform3D();
        adjustedMovingTransform.set(rowPacked);
        return atMoving.copy().concatenate(adjustedMovingTransform.inverse());
    }

    /**
     * Resamples {@code sacs} onto the registration patch and exports them as an {@link ImagePlus}.
     *
     * <p>The export contains one channel per source, a single slice, and timepoint {@code tp}.
     *
     * @param name title of the image
     * @param sacs sources to resample
     * @param tp timepoint to export
     * @param level mipmap level handed to the resampler
     * @return the exported patch
     */
    private <T extends NativeType<T> & NumericType<T>> ImagePlus getCroppedImage(String name, SourceAndConverter<T>[] sacs, int tp, int level) {
        // One-voxel-thick box covering the registration patch.
        RealInterval window = new FinalRealInterval(
                new double[]{this.px, this.py, this.pz},
                new double[]{this.px + this.sx, this.py + this.sy, this.pz + this.pxSizeInCurrentUnit});
        SourceAndConverter model = new EmptySourceAndConverterCreator("model", window, this.pxSizeInCurrentUnit, this.pxSizeInCurrentUnit, this.pxSizeInCurrentUnit).get();
        SourceResampler resampler = new SourceResampler(null, model, model.getSpimSource().getName(), false, false, this.interpolate, level);
        // Kept raw: the element type expected by ImagePlusGetter is defined outside this file.
        List resampled = Arrays.stream(sacs).map(resampler).collect(Collectors.toList());

        ArrayList<Integer> channels = new ArrayList<>(sacs.length);
        for (int i = 0; i < sacs.length; ++i) {
            channels.add(i);
        }
        ArrayList<Integer> slices = new ArrayList<>(1);
        slices.add(0);
        ArrayList<Integer> timepoints = new ArrayList<>(1);
        timepoints.add(tp);
        CZTRange range = new CZTRange(channels, slices, timepoints);
        return ImagePlusGetter.getImagePlus(name, resampled, 0, range, false, false, false, null);
    }

    /**
     * Returns the dot product of the first three coordinates of two points.
     *
     * @param pt1 first point (at least 3 dimensions)
     * @param pt2 second point (at least 3 dimensions)
     * @return {@code pt1 · pt2} over dimensions 0..2
     */
    public static double prodScal(RealPoint pt1, RealPoint pt2) {
        return pt1.getDoublePosition(0) * pt2.getDoublePosition(0) + pt1.getDoublePosition(1) * pt2.getDoublePosition(1) + pt1.getDoublePosition(2) * pt2.getDoublePosition(2);
    }

    /**
     * Returns column {@code axis} of the 3x4 matrix of {@code at3D} as a 3D point.
     *
     * <p>Columns 0..2 belong to the linear part; column 3 is the translation.
     *
     * @param at3D transform to read (not modified)
     * @param axis column index, 0..3
     * @return a new point holding that column
     */
    public RealPoint getMatrixAxis(AffineTransform3D at3D, int axis) {
        double[] m = at3D.getRowPackedCopy();
        RealPoint pt = new RealPoint(3);
        for (int row = 0; row < 3; ++row) {
            pt.setPosition(m[4 * row + axis], row);
        }
        return pt;
    }

    /**
     * Applies the registration transform to every moving source.
     *
     * @return one transformed source per moving source, in the same order
     * @throws IllegalStateException if {@link #run()} has not completed successfully
     */
    public SourceAndConverter[] getRegisteredSacs() {
        if (this.affineTransformOut == null) {
            throw new IllegalStateException("No registration transform available: run() has not completed successfully.");
        }
        SourceAffineTransformer sat = new SourceAffineTransformer(null, this.affineTransformOut);
        SourceAndConverter[] out = new SourceAndConverter[this.sacs_moving.length];
        for (int iCh = 0; iCh < out.length; ++iCh) {
            out[iCh] = sat.apply(this.sacs_moving[iCh]);
        }
        return out;
    }

    /**
     * Returns the transform computed by the last successful {@link #run()}.
     *
     * <p>This is the internal instance, not a copy.
     *
     * @return the transform, or {@code null} if no run has succeeded
     */
    public AffineTransform3D getAffineTransform() {
        return this.affineTransformOut;
    }

    /**
     * Returns the reason why the last {@link #run()} failed.
     *
     * @return the error message; empty when the last run succeeded or none was made
     */
    public String getErrorMessage() {
        return this.errorMessage;
    }

    /**
     * Embeds a 2D affine transform in 3D.
     *
     * <p>The transform acts on x and y; z is left unchanged.
     *
     * @param at 2D transform
     * @return the equivalent 3D transform
     */
    private static AffineTransform3D convertToAffineTransform3D(AffineTransform at) {
        double[] rowPacked = new double[]{
                at.getScaleX(), at.getShearX(), 0.0, at.getTranslateX(),
                at.getShearY(), at.getScaleY(), 0.0, at.getTranslateY(),
                0.0, 0.0, 1.0, 0.0};
        AffineTransform3D at3d = new AffineTransform3D();
        at3d.set(rowPacked);
        return at3d;
    }
}

