/*
 * Decompiled with CFR 0.152.
 */
package net.imglib2.img;

import ij.ImagePlus;
import java.util.AbstractList;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.LongStream;
import net.imagej.ImgPlus;
import net.imagej.axis.CalibratedAxis;
import net.imglib2.cache.Cache;
import net.imglib2.cache.ref.SoftRefLoaderCache;
import net.imglib2.img.Img;
import net.imglib2.img.NativeImg;
import net.imglib2.img.basictypeaccess.array.ArrayDataAccess;
import net.imglib2.img.basictypeaccess.array.ByteArray;
import net.imglib2.img.basictypeaccess.array.FloatArray;
import net.imglib2.img.basictypeaccess.array.IntArray;
import net.imglib2.img.basictypeaccess.array.ShortArray;
import net.imglib2.img.display.imagej.CalibrationUtils;
import net.imglib2.img.planar.PlanarImg;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.ARGBType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Fraction;

/**
 * Wraps an ImageJ {@link ImagePlus} as an {@link ImgPlus} backed by a {@link PlanarImg}.
 *
 * <p>The planes of the resulting image are not copied up front: each plane is requested from
 * the {@code ImagePlus} stack the first time it is accessed and is then kept in a cache
 * (see {@link ImagePlusLoader}).
 */
public class VirtualStackAdapter {
    /**
     * Wraps an 8-bit grayscale image.
     *
     * @param image the image to wrap; its type must be {@link ImagePlus#GRAY8}
     * @return an {@code ImgPlus} view of {@code image}
     * @throws IllegalArgumentException if {@code image} is not of type {@code GRAY8}
     */
    public static ImgPlus<UnsignedByteType> wrapByte(ImagePlus image) {
        return VirtualStackAdapter.internWrap(image, ImagePlus.GRAY8, new UnsignedByteType(), array -> new ByteArray((byte[])array));
    }

    /**
     * Wraps a 16-bit grayscale image.
     *
     * @param image the image to wrap; its type must be {@link ImagePlus#GRAY16}
     * @return an {@code ImgPlus} view of {@code image}
     * @throws IllegalArgumentException if {@code image} is not of type {@code GRAY16}
     */
    public static ImgPlus<UnsignedShortType> wrapShort(ImagePlus image) {
        return VirtualStackAdapter.internWrap(image, ImagePlus.GRAY16, new UnsignedShortType(), array -> new ShortArray((short[])array));
    }

    /**
     * Wraps a 32-bit floating point image.
     *
     * @param image the image to wrap; its type must be {@link ImagePlus#GRAY32}
     * @return an {@code ImgPlus} view of {@code image}
     * @throws IllegalArgumentException if {@code image} is not of type {@code GRAY32}
     */
    public static ImgPlus<FloatType> wrapFloat(ImagePlus image) {
        return VirtualStackAdapter.internWrap(image, ImagePlus.GRAY32, new FloatType(), array -> new FloatArray((float[])array));
    }

    /**
     * Wraps an RGB image, exposing each packed {@code int} pixel as an {@link UnsignedIntType}.
     *
     * @param image the image to wrap; its type must be {@link ImagePlus#COLOR_RGB}
     * @return an {@code ImgPlus} view of {@code image}
     * @throws IllegalArgumentException if {@code image} is not of type {@code COLOR_RGB}
     */
    public static ImgPlus<UnsignedIntType> wrapInt(ImagePlus image) {
        return VirtualStackAdapter.internWrap(image, ImagePlus.COLOR_RGB, new UnsignedIntType(), array -> new IntArray((int[])array));
    }

    /**
     * Wraps an RGB image, exposing each packed {@code int} pixel as an {@link ARGBType}.
     *
     * @param image the image to wrap; its type must be {@link ImagePlus#COLOR_RGB}
     * @return an {@code ImgPlus} view of {@code image}
     * @throws IllegalArgumentException if {@code image} is not of type {@code COLOR_RGB}
     */
    public static ImgPlus<ARGBType> wrapRGBA(ImagePlus image) {
        return VirtualStackAdapter.internWrap(image, ImagePlus.COLOR_RGB, new ARGBType(), array -> new IntArray((int[])array));
    }

    /**
     * Wraps an image of any supported type by dispatching on {@link ImagePlus#getType()}.
     *
     * @param image the image to wrap
     * @return an {@code ImgPlus} view of {@code image}; RGB images are wrapped via
     *         {@link #wrapRGBA(ImagePlus)}
     * @throws RuntimeException if the image type is none of {@code GRAY8}, {@code GRAY16},
     *         {@code GRAY32} or {@code COLOR_RGB}
     */
    public static ImgPlus<?> wrap(ImagePlus image) {
        switch (image.getType()) {
            case ImagePlus.GRAY8:
                return VirtualStackAdapter.wrapByte(image);
            case ImagePlus.GRAY16:
                return VirtualStackAdapter.wrapShort(image);
            case ImagePlus.GRAY32:
                return VirtualStackAdapter.wrapFloat(image);
            case ImagePlus.COLOR_RGB:
                return VirtualStackAdapter.wrapRGBA(image);
            default:
                throw new RuntimeException("Only 8, 16, 32-bit and RGB supported!");
        }
    }

    /**
     * Shared implementation of the {@code wrapXxx} methods.
     *
     * @param image the image to wrap
     * @param expectedType the {@code ImagePlus} type constant that {@code image} must have
     * @param type an instance of the pixel type of the resulting image
     * @param createArrayAccess converts the pixel array of one stack plane into a data access
     * @return an {@code ImgPlus} carrying the title and the non-trivial axes of {@code image}
     * @throws IllegalArgumentException if {@code image.getType() != expectedType}
     */
    // Raw PlanarImg/ImgPlus are used deliberately: the native type factory is only known as
    // NativeTypeFactory<T, ?>, so the linked type cannot be created without an unchecked step.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static <T extends NativeType<T>, A extends ArrayDataAccess<A>> ImgPlus<T> internWrap(ImagePlus image, int expectedType, T type, Function<Object, A> createArrayAccess) {
        if (image.getType() != expectedType) {
            throw new IllegalArgumentException("Unexpected ImagePlus type: expected " + expectedType + " but was " + image.getType());
        }
        ImagePlusLoader<A> loader = new ImagePlusLoader<>(image, createArrayAccess);
        long[] dimensions = VirtualStackAdapter.getNonTrivialDimensions(image);
        PlanarImg cached = new PlanarImg(loader, dimensions, new Fraction());
        cached.setLinkedType(type.getNativeTypeFactory().createLinkedType((NativeImg)cached));
        CalibratedAxis[] axes = CalibrationUtils.getNonTrivialAxes(image);
        return new ImgPlus((Img)cached, image.getTitle(), axes);
    }

    /**
     * Returns width and height followed by those of the channel, slice and frame counts
     * (in that order) that are greater than one.
     */
    private static long[] getNonTrivialDimensions(ImagePlus image) {
        LongStream xy = LongStream.of(image.getWidth(), image.getHeight());
        LongStream czt = LongStream.of(image.getNChannels(), image.getNSlices(), image.getNFrames());
        return LongStream.concat(xy, czt.filter(x -> x > 1L)).toArray();
    }

    /**
     * Read-only list view of the planes of an {@link ImagePlus}; element {@code i} is the data
     * access for stack plane {@code i + 1}. Planes are loaded through a cache on first access.
     */
    private static class ImagePlusLoader<A extends ArrayDataAccess<A>>
    extends AbstractList<A> {
        private final ImagePlus image;
        private final Cache<Integer, A> cache;
        private final Function<Object, A> arrayFactory;

        /**
         * @param image the image whose stack provides the planes
         * @param arrayFactory converts the pixel array of one plane into a data access
         */
        public ImagePlusLoader(ImagePlus image, Function<Object, A> arrayFactory) {
            this.arrayFactory = arrayFactory;
            this.image = image;
            this.cache = new SoftRefLoaderCache<Integer, A>().withLoader(this::load);
        }

        /**
         * Returns the data access for the plane at the given zero-based index, loading it via
         * the cache if necessary.
         *
         * @throws RuntimeException wrapping the {@link ExecutionException} if loading fails
         */
        @Override
        public A get(int key) {
            try {
                return this.cache.get(key);
            }
            catch (ExecutionException e) {
                throw new RuntimeException(e);
            }
        }

        /** Cache loader: fetches the pixels of one plane and wraps them as a data access. */
        private A load(Integer key) {
            // ImageJ stack positions are 1-based, list indices are 0-based.
            return this.arrayFactory.apply(this.image.getStack().getPixels(key + 1));
        }

        /** Returns the number of planes, i.e. the stack size of the wrapped image. */
        @Override
        public int size() {
            return this.image.getStackSize();
        }
    }
}

