/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions;

import dr.math.GammaFunction;
import dr.math.MathUtils;
import dr.math.distributions.GammaDistribution;
import dr.math.distributions.MultivariateDistribution;
import dr.math.distributions.WishartStatistics;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;

/**
 * Wishart distribution over square {@code dim x dim} matrices, parameterised by degrees of
 * freedom {@code df} and a scale matrix {@code S}.
 *
 * <p>The inverse scale matrix is cached both as a {@link Matrix} and as a row-major flattened
 * array (element {@code (i, j)} at index {@code i * dim + j}); {@link #logPdf(double[])} expects
 * its argument flattened the same way.
 *
 * <p>The {@link #WishartDistribution(int)} constructor builds an improper "uninformative"
 * variant with {@code df = 0}, no scale matrix and a zero normalising constant; its log density
 * reduces to {@code -(dim + 1) / 2 * log|W|}, and it cannot be sampled from.
 */
public class WishartDistribution implements MultivariateDistribution, WishartStatistics {
    public static final String TYPE = "Wishart";

    /** Degrees of freedom; 0 marks the improper uninformative variant. */
    private double df;
    /** Number of rows (and columns) of the matrices this distribution is defined over. */
    private int dim;
    /** Scale matrix S (held by reference); null for the uninformative variant. */
    private double[][] scaleMatrix;
    /** Inverse scale matrix flattened row-major; null for the uninformative variant. */
    private double[] Sinv;
    /** Inverse scale matrix; null for the uninformative variant. */
    private Matrix SinvMat;
    /** Cached log normalising constant; 0 for the uninformative variant. */
    private double logNormalizationConstant;

    /**
     * Creates a proper Wishart distribution.
     *
     * @param df          degrees of freedom
     * @param scaleMatrix square scale matrix S; stored by reference, not copied
     */
    public WishartDistribution(double df, double[][] scaleMatrix) {
        this.df = df;
        this.scaleMatrix = scaleMatrix;
        this.dim = scaleMatrix.length;
        this.SinvMat = new Matrix(scaleMatrix).inverse();
        double[][] inverse = this.SinvMat.toComponents();
        this.Sinv = new double[this.dim * this.dim];
        for (int row = 0; row < this.dim; ++row) {
            System.arraycopy(inverse[row], 0, this.Sinv, row * this.dim, this.dim);
        }
        this.computeNormalizationConstant();
    }

    /**
     * Creates the improper uninformative Wishart variant ({@code df = 0}, no scale matrix).
     *
     * @param dim number of rows (and columns) of the matrices
     */
    public WishartDistribution(int dim) {
        this.df = 0.0;
        this.scaleMatrix = null;
        this.dim = dim;
        this.logNormalizationConstant = 0.0;
    }

    /** Recomputes the cached log normalising constant from the current df and scale matrix. */
    private void computeNormalizationConstant() {
        this.logNormalizationConstant =
                WishartDistribution.computeNormalizationConstant(new Matrix(this.scaleMatrix), this.df, this.dim);
    }

    /**
     * Computes the log normalising constant of a Wishart density:
     * {@code -df/2 * log|S| - df*dim/2 * log 2 - log Gamma_dim(df/2)}, where {@code Gamma_dim}
     * is the multivariate gamma function.
     *
     * @param matrix scale matrix S
     * @param df     degrees of freedom; 0 denotes the improper uninformative variant, for which
     *               0 is returned
     * @param dim    number of rows (and columns) of the matrices
     * @return the log normalising constant
     * @throws IllegalArgumentException if the determinant of {@code matrix} cannot be computed
     */
    public static double computeNormalizationConstant(Matrix matrix, double df, int dim) {
        if (df == 0.0) {
            // Improper uninformative variant: no normalisation.
            return 0.0;
        }
        double logConstant;
        try {
            logConstant = -df / 2.0 * Math.log(matrix.determinant());
        } catch (IllegalDimension e) {
            // Previously printed and continued with a wrong constant; fail loudly instead.
            throw new IllegalArgumentException("Cannot compute determinant of Wishart scale matrix", e);
        }
        logConstant -= df * dim / 2.0 * Math.log(2.0);
        // log Gamma_dim(df/2) = dim(dim-1)/4 * log(pi) + sum_{i=1..dim} lnGamma((df + 1 - i) / 2)
        logConstant -= dim * (dim - 1) / 4.0 * Math.log(Math.PI);
        for (int i = 1; i <= dim; ++i) {
            logConstant -= GammaFunction.lnGamma((df + 1.0 - i) / 2.0);
        }
        return logConstant;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Returns the scale matrix by reference (not a copy).
     *
     * @return the scale matrix, or null for the uninformative variant
     */
    @Override
    public double[][] getScaleMatrix() {
        return this.scaleMatrix;
    }

    /**
     * Not implemented for this distribution.
     *
     * @return always null
     */
    @Override
    public double[] getMean() {
        return null;
    }

    /**
     * Debugging aid: draws 100000 samples and prints the Monte Carlo average of the top-left
     * 2x2 entries to standard error. Requires {@code dim >= 2} and a proper distribution.
     */
    public void testMe() {
        final int sampleCount = 100000;
        double sum00 = 0.0;
        double sum01 = 0.0;
        double sum10 = 0.0;
        double sum11 = 0.0;
        for (int i = 0; i < sampleCount; ++i) {
            double[][] draw = this.nextWishart();
            sum00 += draw[0][0];
            sum01 += draw[0][1];
            sum10 += draw[1][0];
            sum11 += draw[1][1];
        }
        System.err.println("S1: " + sum00 / sampleCount);
        System.err.println("S2: " + sum01 / sampleCount);
        System.err.println("S3: " + sum10 / sampleCount);
        System.err.println("S4: " + sum11 / sampleCount);
    }

    @Override
    public double getDF() {
        return this.df;
    }

    /**
     * Draws a random matrix from this distribution.
     *
     * @return a freshly allocated {@code dim x dim} sample
     * @throws IllegalStateException if this is the uninformative variant (no scale matrix)
     */
    public double[][] nextWishart() {
        if (this.scaleMatrix == null) {
            throw new IllegalStateException(
                    "Cannot sample from an uninformative (improper) Wishart distribution");
        }
        return WishartDistribution.nextWishart(this.df, this.scaleMatrix);
    }

    /**
     * Draws a Wishart matrix via the Bartlett decomposition: {@code W = (L A)(L A)^T}, where
     * {@code L} is the lower Cholesky factor of the scale matrix and {@code A} is lower
     * triangular with {@code sqrt(Gamma((df - i)/2, 0.5))} on the diagonal and standard normal
     * draws below it.
     *
     * @param df    degrees of freedom; should exceed {@code dim - 1} so every diagonal gamma
     *              shape is positive
     * @param scale square scale matrix; only its upper triangle (including the diagonal) is read
     * @return a freshly allocated {@code dim x dim} sample
     * @throws RuntimeException if the Cholesky decomposition of the scale matrix fails
     */
    public static double[][] nextWishart(double df, double[][] scale) {
        int dim = scale.length;

        // Bartlett factor A: strictly-lower entries first (row by row), then the diagonal.
        double[][] bartlett = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < i; ++j) {
                bartlett[i][j] = MathUtils.nextGaussian();
            }
        }
        for (int i = 0; i < dim; ++i) {
            // NOTE(review): assumes nextGamma(shape, rate), so this is sqrt(chi-squared(df - i)).
            bartlett[i][i] = Math.sqrt(MathUtils.nextGamma((df - i) * 0.5, 0.5));
        }

        // Symmetric copy of the scale matrix built from its upper triangle.
        double[][] symmetricScale = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = i; j < dim; ++j) {
                double value = scale[i][j];
                symmetricScale[j][i] = value;
                symmetricScale[i][j] = value;
            }
        }

        double[][] cholesky;
        try {
            cholesky = new CholeskyDecomposition(symmetricScale).getL();
        } catch (IllegalDimension e) {
            throw new RuntimeException("Numerical exception in WishartDistribution", e);
        }

        // factor = L * A
        double[][] factor = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                double sum = 0.0;
                for (int k = 0; k < dim; ++k) {
                    sum += cholesky[i][k] * bartlett[k][j];
                }
                factor[i][j] = sum;
            }
        }

        // sample = factor * factor^T
        double[][] sample = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                double sum = 0.0;
                for (int k = 0; k < dim; ++k) {
                    sum += factor[i][k] * factor[j][k];
                }
                sample[i][j] = sum;
            }
        }
        return sample;
    }

    /**
     * Log density of a flattened (row-major) matrix. Four-element inputs use the closed-form
     * 2x2 path; everything else goes through {@link #logPdfSlow(double[])}.
     */
    @Override
    public double logPdf(double[] x) {
        if (x.length == 4) {
            return WishartDistribution.logPdf2D(x, this.Sinv, this.df, this.dim, this.logNormalizationConstant);
        }
        return this.logPdfSlow(x);
    }

    /**
     * Log density computed with general {@link Matrix} operations.
     *
     * @param x matrix flattened into {@code dim * dim} entries
     * @return the log density
     */
    public double logPdfSlow(double[] x) {
        Matrix matrix = new Matrix(x, this.dim, this.dim);
        return WishartDistribution.logPdf(matrix, this.SinvMat, this.df, this.dim, this.logNormalizationConstant);
    }

    /**
     * Closed-form log density for a 2x2 matrix.
     *
     * @param x                        the matrix flattened row-major: {x00, x01, x10, x11}
     * @param scaleInverse             inverse scale matrix flattened row-major, or null for the
     *                                 uninformative variant (trace term omitted, matching
     *                                 {@link #logPdf(Matrix, Matrix, double, int, double)})
     * @param df                       degrees of freedom
     * @param dim                      matrix dimension (2 for this path)
     * @param logNormalizationConstant log normalising constant added to the result
     * @return the log density, or negative infinity if {@code x} is not positive definite
     */
    public static double logPdf2D(double[] x, double[] scaleInverse, double df, int dim,
                                  double logNormalizationConstant) {
        double det = x[0] * x[3] - x[1] * x[2];
        // A symmetric 2x2 matrix is positive definite iff its leading entry and determinant are
        // both positive; a positive determinant alone also admits negative-definite matrices.
        if (det <= 0.0 || x[0] <= 0.0) {
            return Double.NEGATIVE_INFINITY;
        }
        double logDensity = 0.5 * (df - dim - 1.0) * Math.log(det);
        if (scaleInverse != null) {
            // trace(S^{-1} X) written out for row-major 2x2 arrays.
            double trace = scaleInverse[0] * x[0] + scaleInverse[1] * x[2]
                    + scaleInverse[2] * x[1] + scaleInverse[3] * x[3];
            logDensity -= 0.5 * trace;
        }
        return logDensity + logNormalizationConstant;
    }

    /**
     * General log density: {@code (df - dim - 1)/2 * log|X| - trace(S^{-1} X)/2 + logC}.
     *
     * @param matrix                   the matrix X
     * @param matrix2                  inverse scale matrix, or null to omit the trace term
     *                                 (uninformative variant)
     * @param df                       degrees of freedom
     * @param dim                      matrix dimension
     * @param logNormalizationConstant log normalising constant added to the result
     * @return the log density, or negative infinity if {@code log|X|} is infinite or NaN
     * @throws IllegalArgumentException on a dimension mismatch between the matrices
     */
    public static double logPdf(Matrix matrix, Matrix matrix2, double df, int dim,
                                double logNormalizationConstant) {
        double logDensity;
        try {
            double logDet = matrix.logDeterminant();
            if (Double.isInfinite(logDet) || Double.isNaN(logDet)) {
                return Double.NEGATIVE_INFINITY;
            }
            logDensity = logDet * 0.5 * (df - dim - 1.0);
            if (matrix2 != null) {
                Matrix product = matrix2.product(matrix);
                for (int i = 0; i < dim; ++i) {
                    logDensity -= 0.5 * product.component(i, i);
                }
            }
        } catch (IllegalDimension e) {
            // Previously printed and returned a partial (wrong) value; fail loudly instead.
            throw new IllegalArgumentException("Dimension mismatch in Wishart log density", e);
        }
        return logDensity + logNormalizationConstant;
    }

    /** Prints the fast (2x2) and slow log densities for one matrix so they can be compared. */
    public static void testBivariateMethod() {
        System.out.println("Testing new computations ...");
        WishartDistribution wishart = new WishartDistribution(5.0, new double[][]{{2.0, -0.5}, {-0.5, 2.0}});
        double[] x = new double[]{4.0, 1.0, 1.0, 3.0};
        System.out.println("Fast logPdf = " + wishart.logPdf(x));
        System.out.println("Slow logPdf = " + wishart.logPdfSlow(x));
    }

    /** Demo: compares 1-D Wishart log densities with gamma log densities and runs the 2x2 check. */
    public static void main(String[] args) {
        WishartDistribution wishart = new WishartDistribution(2.0, new double[][]{{500.0}});
        GammaDistribution gamma = new GammaDistribution(0.001, 1000.0);
        double[] x = new double[]{1.0};
        System.out.println("Wishart, df=2, scale = 500, PDF(1.0): " + wishart.logPdf(x));
        System.out.println("Gamma, shape = 1/1000, scale = 1000, PDF(1.0): " + gamma.logPdf(x[0]));

        // A 1-D Wishart(df, s) equals Gamma(shape = df/2, scale = 2s): df=4, s=5 -> Gamma(2, 10).
        wishart = new WishartDistribution(4.0, new double[][]{{5.0}});
        gamma = new GammaDistribution(2.0, 10.0);
        x = new double[]{1.0};
        System.out.println("Wishart, df=4, scale = 5, PDF(1.0): " + wishart.logPdf(x));
        System.out.println("Gamma, shape = 2, scale = 10, PDF(1.0): " + gamma.logPdf(x[0]));

        wishart = new WishartDistribution(1);
        x = new double[]{0.1};
        System.out.println("Wishart, uninformative, PDF(0.1): " + wishart.logPdf(x));
        x = new double[]{1.0};
        System.out.println("Wishart, uninformative, PDF(1.0): " + wishart.logPdf(x));
        x = new double[]{10.0};
        System.out.println("Wishart, uninformative, PDF(10.0): " + wishart.logPdf(x));
        WishartDistribution.testBivariateMethod();
    }
}

