/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.xml.Reportable;
import java.util.List;

/**
 * Logs how far the analytic gradient of a {@link GradientWrtParameterProvider} deviates from a
 * numerically estimated reference gradient.
 *
 * <p>One {@link NumberColumn} is produced per requested {@link Statistic}. Each column is labelled
 * {@code <parameterId>.<statisticName>}. The analytic and numerical gradients are computed at most
 * once per log row and shared by all columns of that row.
 */
public class GradientErrorLogger implements Loggable, Reportable {

    private final GradientWrtParameterProvider source;
    private final List<Statistic> statistics;

    /** Whether {@link #gradient} and {@link #reference} are valid for the current log row. */
    private boolean gradientKnown = false;
    /** Analytic gradient obtained from {@code source.getGradientLogDensity()}. */
    private double[] gradient;
    /** Numerical gradient used as the reference for comparison. */
    private double[] reference;

    /**
     * Creates a logger for the given gradient provider.
     *
     * @param gradientWrtParameterProvider provider whose analytic gradient is checked
     * @param list statistics to log; one column is produced per entry, in list order
     */
    public GradientErrorLogger(GradientWrtParameterProvider gradientWrtParameterProvider,
                               List<Statistic> list) {
        this.source = gradientWrtParameterProvider;
        this.statistics = list;
    }

    /**
     * Evaluates {@code statistic} on the cached gradients.
     *
     * <p>If the cache has been invalidated, both gradients are recomputed first.
     *
     * @param statistic the discrepancy measure to evaluate
     * @return the value of {@code statistic} for the current analytic/numerical gradient pair
     */
    private double getStatisticValue(Statistic statistic) {
        if (!gradientKnown) {
            gradient = source.getGradientLogDensity();
            // NOTE(review): the infinite bounds and null arguments presumably request an
            // unconstrained numerical estimate -- confirm against CheckGradientNumerically.
            reference = new GradientWrtParameterProvider.CheckGradientNumerically(
                    source, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, null, null)
                    .getNumericalGradient();
            gradientKnown = true;
        }
        return statistic.getStatistic(gradient, reference);
    }

    /**
     * Builds one numeric column per configured statistic.
     *
     * @return columns labelled {@code <parameterId>.<statisticName>}, in statistic-list order
     */
    @Override
    public LogColumn[] getColumns() {
        LogColumn[] columns = new LogColumn[statistics.size()];
        for (int i = 0; i < columns.length; ++i) {
            final Statistic statistic = statistics.get(i);
            final boolean firstColumn = (i == 0);
            String label = source.getParameter().getId() + "." + statistic.getName();
            columns[i] = new NumberColumn(label) {
                @Override
                public double getDoubleValue() {
                    // Reading the first column invalidates the cache so that the gradients are
                    // recomputed once per log row. This assumes column 0 is read before the
                    // other columns of a row.
                    if (firstColumn) {
                        GradientErrorLogger.this.gradientKnown = false;
                    }
                    return GradientErrorLogger.this.getStatisticValue(statistic);
                }
            };
        }
        return columns;
    }

    /**
     * Reports are not supported by this logger.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public String getReport() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Returns the largest {@code |a[i] - b[i]| / scale} over all indices.
     *
     * <p>Returns 0 for empty arrays. A NaN in any component propagates to the result, so a broken
     * gradient entry is reported rather than silently skipped.
     *
     * @param a first vector
     * @param b second vector, same length as {@code a}
     * @param scale divisor applied to every absolute difference
     * @return the maximum scaled absolute component-wise difference
     */
    private static double maxAbsDifference(double[] a, double[] b, double scale) {
        assert a.length == b.length;
        double max = 0.0;
        for (int i = 0; i < a.length; ++i) {
            // Math.max propagates NaN, unlike a plain '>' comparison which would skip it.
            max = Math.max(max, Math.abs(a[i] - b[i]) / scale);
        }
        return max;
    }

    /**
     * Returns the dot product of two equal-length vectors.
     *
     * @param a first vector
     * @param b second vector, same length as {@code a}
     * @return {@code sum(a[i] * b[i])}, or 0 for empty arrays
     */
    private static double innerProduct(double[] a, double[] b) {
        assert a.length == b.length;
        double sum = 0.0;
        for (int i = 0; i < a.length; ++i) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** Discrepancy measures between an analytic gradient and a reference gradient. */
    public enum Statistic {
        /** Largest absolute component-wise difference. */
        MAX_ERROR_ABSOLUTE("maxErrorAbsolute") {
            @Override
            double getStatistic(double[] gradient, double[] reference) {
                return GradientErrorLogger.maxAbsDifference(gradient, reference, 1.0);
            }
        },

        /**
         * Largest absolute component-wise difference divided by
         * {@code ||gradient|| * ||reference||}.
         */
        MAX_ERROR_RELATIVE("maxErrorRelative") {
            @Override
            double getStatistic(double[] gradient, double[] reference) {
                // NOTE(review): the divisor is the product of both norms, not a single norm or
                // their geometric mean. The result is therefore not invariant to rescaling both
                // gradients -- confirm this normalisation is intended.
                double scale = Math.sqrt(GradientErrorLogger.innerProduct(gradient, gradient)
                        * GradientErrorLogger.innerProduct(reference, reference));
                return GradientErrorLogger.maxAbsDifference(gradient, reference, scale);
            }
        },

        /** Angle, in radians within {@code [0, pi]}, between the two gradient vectors. */
        ANGLE("angle") {
            @Override
            double getStatistic(double[] gradient, double[] reference) {
                double dot = GradientErrorLogger.innerProduct(gradient, reference);
                double gradientNorm = Math.sqrt(GradientErrorLogger.innerProduct(gradient, gradient));
                double referenceNorm = Math.sqrt(GradientErrorLogger.innerProduct(reference, reference));
                double cosine = dot / (gradientNorm * referenceNorm);
                // Rounding can push the cosine of (nearly) parallel vectors just outside [-1, 1],
                // where acos returns NaN, so clamp it. A NaN cosine (e.g. from a zero vector)
                // still propagates.
                return Math.acos(Math.max(-1.0, Math.min(1.0, cosine)));
            }
        };

        private final String name;

        Statistic(String name) {
            this.name = name;
        }

        /**
         * Computes this statistic.
         *
         * @param gradient analytic gradient
         * @param reference reference (numerical) gradient, same length as {@code gradient}
         * @return the statistic value
         */
        abstract double getStatistic(double[] gradient, double[] reference);

        /**
         * Returns the identifier used in column labels and by {@link #parse(String)}.
         *
         * @return the statistic's identifier
         */
        public final String getName() {
            return name;
        }

        /**
         * Looks up a statistic by its {@link #getName() name}, ignoring case.
         *
         * @param string the name to look up; may be {@code null}
         * @return the matching statistic, or {@code null} if none matches
         */
        public static Statistic parse(String string) {
            for (Statistic statistic : values()) {
                if (statistic.name.equalsIgnoreCase(string)) {
                    return statistic;
                }
            }
            return null;
        }
    }
}

