/*
 * Decompiled with CFR 0.152.
 */
package stats.glm;

import cern.colt.list.DoubleArrayList;
import cern.colt.list.IntArrayList;
import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.doublealgo.Transform;
import cern.colt.matrix.impl.AbstractMatrix;
import cern.colt.matrix.impl.AbstractMatrix1D;
import cern.colt.matrix.linalg.Algebra;
import com.imsl.math.Sfun;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.biojava.stats.svm.ItemValue;
import org.biojava.stats.svm.SVMTarget;
import org.biojava.stats.svm.TrainingEvent;
import org.biojava.stats.svm.TrainingListener;
import stats.glm.BasisFunction;
import stats.glm.BasisSource;
import stats.glm.GLMClassificationModel;
import stats.glm.GLMRegressionModel;
import stats.glm.GLMTrainer;
import stats.glm.SLMTrainingContext;

/**
 * Trains sparse generalized linear models ({@link GLMRegressionModel} and
 * {@link GLMClassificationModel}) with variational relevance vector machine
 * (VRVM) style updates.
 *
 * <p>Both trainers keep a <em>working set</em> of basis functions drawn from a
 * {@link BasisSource}. Every cycle they:
 * <ol>
 * <li>solve for the Gaussian posterior over the weights;</li>
 * <li>re-estimate one precision hyperparameter (alpha) per basis function;</li>
 * <li>drop basis functions whose alpha reaches a pruning threshold.</li>
 * </ol>
 * If the working set falls below {@link #setMinWorkingSet(int)}, more basis
 * functions are drawn from the source.
 *
 * <p>Progress is written to {@code System.out}. The configuration lives in
 * plain mutable fields with no synchronization.
 */
public class VRVMTrainer
implements GLMTrainer {
    private static final DoubleFactory1D mf1d = DoubleFactory1D.dense;
    private static final DoubleFactory2D mf2d = DoubleFactory2D.dense;
    private static final Algebra alg = Algebra.DEFAULT;

    /** Classification: a basis function whose re-estimated alpha reaches this value is pruned. */
    private static final double CLASSIFICATION_PRUNE_ALPHA = 1000.0;
    /** Regression: a basis function whose re-estimated alpha reaches this value is pruned. */
    private static final double REGRESSION_PRUNE_ALPHA = 499000.0;
    /** Relative change of the largest alpha below which a classification cycle counts as stagnant. */
    private static final double STAGNATION_TOLERANCE = 0.01;
    /** Classification reacts once more than this many consecutive cycles were stagnant. */
    private static final int STAGNATION_PATIENCE = 5;
    /** Amount by which the (per-call) working-set bounds grow when classification stagnates. */
    private static final int STAGNATION_GROWTH = 5;

    private int maxCycles = 2000;
    private double initialAlpha = 1.0;
    private int cleaningCycles = 100;
    private int initialBasisSet = 0;
    private int maxWorkingSet = Integer.MAX_VALUE;
    private int minWorkingSet = 0;
    private boolean unityHack = false;
    private double unityHackThreshold = 1000.0;
    private boolean resetAlphaHack = false;

    /** Sets the maximum number of update cycles (default 2000). */
    public void setMaxCycles(int n) {
        this.maxCycles = n;
    }

    /** Sets the alpha given to every newly added basis function (default 1.0). */
    public void setInitialAlpha(double d) {
        this.initialAlpha = d;
    }

    /** Sets how many final cycles may not add new basis functions to the working set (default 100). */
    public void setCleaningCycles(int n) {
        this.cleaningCycles = n;
    }

    /** Sets the upper bound on the working-set size used when drawing basis functions (default unbounded). */
    public void setMaxWorkingSet(int n) {
        this.maxWorkingSet = n;
    }

    /** Sets the size below which more basis functions are drawn from the source (default 0: never). */
    public void setMinWorkingSet(int n) {
        this.minWorkingSet = n;
    }

    /**
     * Classification only: number of basis functions drawn before the first cycle.
     * 0 (the default) means "draw up to the maximum working-set size". Regression ignores this.
     */
    public void setInitialBasisSet(int n) {
        this.initialBasisSet = n;
    }

    /** Classification only: cap the alpha of the first working-set entry at the unity-hack threshold. */
    public void setUnityHack(boolean bl) {
        this.unityHack = bl;
    }

    /** Sets the cap used by {@link #setUnityHack(boolean)} (default 1000.0). */
    public void setUnityHackThreshold(double d) {
        this.unityHackThreshold = d;
    }

    /** Classification only: when the working set grows, reset every existing alpha to the initial alpha. */
    public void setResetAlphaHack(boolean bl) {
        this.resetAlphaHack = bl;
    }

    public GLMRegressionModel trainRegression(SVMTarget sVMTarget, BasisSource basisSource, TrainingListener trainingListener) {
        VRVMTrainContext vRVMTrainContext = new VRVMTrainContext(sVMTarget, basisSource, trainingListener);
        return vRVMTrainContext.train();
    }

    public GLMClassificationModel trainClassification(SVMTarget sVMTarget, BasisSource basisSource, TrainingListener trainingListener) {
        VRVMClassContext vRVMClassContext = new VRVMClassContext(sVMTarget, basisSource, trainingListener);
        return vRVMClassContext.train();
    }

    /**
     * Appends every item of {@code target.itemTargets()} to {@code items}.
     * Returns their values in the same order, in a vector sized by
     * {@code target.items().size()}.
     */
    private static DoubleMatrix1D readTargets(SVMTarget target, List<Object> items) {
        DoubleMatrix1D values = mf1d.make(target.items().size());
        int index = 0;
        Iterator it = target.itemTargets().iterator();
        while (it.hasNext()) {
            ItemValue itemValue = (ItemValue) it.next();
            items.add(itemValue.getItem());
            values.set(index++, itemValue.getValue());
        }
        return values;
    }

    /** Fills column {@code j} of {@code phi} with {@code basis.get(j)} evaluated on every item, for all {@code j >= fromColumn}. */
    private static void fillColumns(DoubleMatrix2D phi, List<BasisFunction> basis, List<Object> items, int fromColumn) {
        for (int j = fromColumn; j < phi.columns(); ++j) {
            BasisFunction function = basis.get(j);
            for (int i = 0; i < phi.rows(); ++i) {
                phi.setQuick(i, j, function.evaluate(items.get(i)));
            }
        }
    }

    /** Returns {@code sum_i coeff_i * phi.viewRow(i)}, accumulated in row order. */
    private static DoubleMatrix1D weightedRowSum(DoubleMatrix2D phi, DoubleMatrix1D coeff) {
        DoubleMatrix1D sum = mf1d.make(phi.columns());
        for (int i = 0; i < phi.rows(); ++i) {
            Transform.plusMult(sum, phi.viewRow(i), coeff.get(i));
        }
        return sum;
    }

    /** Returns a new vector holding the entries of {@code v} at the indices in {@code keep}. */
    private static DoubleMatrix1D selectEntries(DoubleMatrix1D v, IntArrayList keep) {
        DoubleMatrix1D out = mf1d.make(keep.size());
        for (int k = 0; k < keep.size(); ++k) {
            out.setQuick(k, v.get(keep.get(k)));
        }
        return out;
    }

    /** Returns a new matrix holding the columns of {@code m} at the indices in {@code keep}. */
    private static DoubleMatrix2D selectColumns(DoubleMatrix2D m, IntArrayList keep) {
        DoubleMatrix2D out = mf2d.make(m.rows(), keep.size());
        for (int k = 0; k < keep.size(); ++k) {
            out.viewColumn(k).assign(m.viewColumn(keep.get(k)));
        }
        return out;
    }

    /** Returns a new list holding the elements of {@code list} at the indices in {@code keep}. */
    private static <T> List<T> selectElements(List<T> list, IntArrayList keep) {
        List<T> out = new ArrayList<T>(keep.size());
        for (int k = 0; k < keep.size(); ++k) {
            out.add(list.get(keep.get(k)));
        }
        return out;
    }

    /** Returns a copy of {@code v} padded with {@code fill} up to {@code newSize} (which must be at least {@code v.size()}). */
    private static DoubleMatrix1D extend(DoubleMatrix1D v, int newSize, double fill) {
        DoubleMatrix1D out = mf1d.make(newSize, fill);
        out.viewPart(0, v.size()).assign(v);
        return out;
    }

    /** Returns a copy of {@code phi} with zero columns appended up to {@code newColumns}. */
    private static DoubleMatrix2D widenColumns(DoubleMatrix2D phi, int newColumns) {
        DoubleMatrix2D out = mf2d.make(phi.rows(), newColumns);
        out.viewPart(0, 0, phi.rows(), phi.columns()).assign(phi);
        return out;
    }

    /** Adds {@code b} (assumed non-negative) to {@code a}, clamping at {@link Integer#MAX_VALUE} instead of overflowing. */
    private static int saturatingAdd(int a, int b) {
        long sum = (long) a + b;
        return sum > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) sum;
    }

    /**
     * Per-call state for {@link #trainClassification}.
     *
     * <p>The weight mean uses {@code 2t - 1} (i.e. {@code t - 1/2} after the 0.5 scale).
     * NOTE(review): this presumably assumes 0/1 targets; confirm against callers.
     */
    private class VRVMClassContext
    implements SLMTrainingContext {
        private int cycle = 0;
        private final SVMTarget starget;
        private final BasisSource basisSource;
        private final TrainingListener listener;
        private final TrainingEvent tevent;
        private List<BasisFunction> workingSet;
        private DoubleMatrix1D weights;

        VRVMClassContext(SVMTarget sVMTarget, BasisSource basisSource, TrainingListener trainingListener) {
            this.starget = sVMTarget;
            this.basisSource = basisSource;
            this.listener = trainingListener;
            this.tevent = new TrainingEvent(this);
        }

        /**
         * Runs the training loop and returns the frozen model.
         * {@code weights} is kept aligned with {@code workingSet} at every listener callback and in the result.
         */
        GLMClassificationModel train() {
            // Per-call copies: stagnation handling widens these without rewriting the
            // trainer's configuration (which previously also overflowed Integer.MAX_VALUE).
            int maxWs = maxWorkingSet;
            int minWs = minWorkingSet;
            HashMap<BasisFunction, Integer> ages = new HashMap<BasisFunction, Integer>();
            List<Object> items = new ArrayList<Object>();
            DoubleMatrix1D targets = readTargets(this.starget, items);
            int n = targets.size();

            DoubleMatrix1D signedTargets = mf1d.make(n);
            for (int i = 0; i < n; ++i) {
                signedTargets.setQuick(i, 2.0 * targets.get(i) - 1.0);
            }

            this.workingSet = new ArrayList<BasisFunction>();
            int initialSize = initialBasisSet == 0 ? maxWs : initialBasisSet;
            while (this.basisSource.hasNext(this) && this.workingSet.size() < initialSize) {
                BasisFunction basis = this.basisSource.next(this);
                ages.put(basis, Integer.valueOf(0));
                this.workingSet.add(basis);
            }
            DoubleMatrix2D phi = mf2d.make(n, this.workingSet.size());
            fillColumns(phi, this.workingSet, items, 0);
            DoubleMatrix1D alpha = mf1d.make(this.workingSet.size(), initialAlpha);
            DoubleMatrix1D xi = mf1d.make(n, 1.0);
            double aPrior = 1.0E-6;
            double bPrior = 1.0E-6;
            this.weights = mf1d.make(this.workingSet.size());

            double previousMaxAlpha = 0.0;
            int stagnantCycles = 0;
            boolean workingSetChanged = false;
            while (this.cycle < maxCycles) {
                workingSetChanged = false;
                int m = alpha.size();
                DoubleMatrix2D sigma = solvePosterior(phi, alpha, xi, signedTargets);
                // <w w^T> = mu mu^T + Sigma (Transform.plus mutates only the fresh outer product).
                DoubleMatrix2D wwT = Transform.plus(alg.multOuter(this.weights, this.weights, null), sigma);

                // Re-estimate alpha_j = aTilde / bTilde_j and mark the survivors.
                // NOTE(review): textbook VRVM uses <w_j^2> = mu_j^2 + Sigma_jj (wwT diagonal);
                // this uses mu_j^2 only. Kept as-is; the pruning threshold may depend on it.
                double aTilde = aPrior + 0.5;
                double maxAlpha = 0.0;
                IntArrayList keep = new IntArrayList();
                for (int j = 0; j < m; ++j) {
                    double w = this.weights.get(j);
                    double a = aTilde / (bPrior + w * w / 2.0);
                    if (j == 0 && unityHack) {
                        a = Math.min(unityHackThreshold, a);
                    }
                    maxAlpha = Math.max(maxAlpha, a);
                    alpha.set(j, a);
                    if (a < CLASSIFICATION_PRUNE_ALPHA) {
                        keep.add(j);
                    } else {
                        ages.remove(this.workingSet.get(j));
                    }
                }

                // Variational parameters xi_n = sqrt(phi_n^T <w w^T> phi_n), on the pre-pruning basis.
                for (int i = 0; i < n; ++i) {
                    DoubleMatrix1D row = phi.viewRow(i);
                    xi.set(i, Math.sqrt(alg.mult(alg.mult(wwT, row), row)));
                }

                if (keep.size() < m) {
                    alpha = selectEntries(alpha, keep);
                    phi = selectColumns(phi, keep);
                    this.workingSet = selectElements(this.workingSet, keep);
                    this.weights = selectEntries(this.weights, keep);
                    workingSetChanged = true;
                }

                if (this.workingSet.size() < minWs && this.basisSource.hasNext(this) && this.cycle < maxCycles - cleaningCycles) {
                    System.out.print('+');
                    int oldSize = this.workingSet.size();
                    while (this.workingSet.size() < maxWs && this.basisSource.hasNext(this)) {
                        BasisFunction basis = this.basisSource.next(this);
                        ages.put(basis, Integer.valueOf(this.cycle + 1));
                        this.workingSet.add(basis);
                    }
                    int newSize = this.workingSet.size();
                    alpha = resetAlphaHack ? mf1d.make(newSize, initialAlpha) : extend(alpha, newSize, initialAlpha);
                    phi = widenColumns(phi, newSize);
                    fillColumns(phi, this.workingSet, items, oldSize);
                    // New bases have not been fitted yet; report a zero weight until the next solve.
                    this.weights = extend(this.weights, newSize, 0.0);
                    workingSetChanged |= newSize > oldSize;
                }

                System.out.println("Cycle: " + this.cycle);
                System.out.println("Working set size: " + this.workingSet.size());
                System.out.println("Max alpha: " + maxAlpha);
                boolean stagnant = Math.abs((maxAlpha - previousMaxAlpha) / maxAlpha) < STAGNATION_TOLERANCE;
                stagnantCycles = stagnant ? stagnantCycles + 1 : 0;
                if (stagnantCycles > STAGNATION_PATIENCE) {
                    System.out.print("Stagnation: ");
                    if (minWs > 0) {
                        System.out.println("adding some extra model elements next time round.");
                        maxWs = saturatingAdd(maxWs, STAGNATION_GROWTH);
                        minWs = saturatingAdd(minWs, STAGNATION_GROWTH);
                    } else {
                        System.out.println("giving up.");
                        this.cycle = maxCycles;
                    }
                    stagnantCycles = 0;
                }
                this.listener.trainingCycleComplete(this.tevent);
                ++this.cycle;
                previousMaxAlpha = maxAlpha;
            }
            if (workingSetChanged) {
                // The last cycle pruned/added bases after solving; re-solve so the model's
                // weights correspond to the final working set.
                solvePosterior(phi, alpha, xi, signedTargets);
            }
            System.out.println("\n\nBasis function ages:");
            for (Integer age : ages.values()) {
                System.out.println(age.toString());
            }
            this.listener.trainingComplete(this.tevent);
            return this.freezeModel();
        }

        /**
         * Computes Sigma = (diag(alpha) + sum_n 2 lambda(xi_n) phi_n phi_n^T)^-1.
         * Sets {@code weights} to 0.5 * Sigma * sum_n signedTargets_n phi_n and returns Sigma.
         */
        private DoubleMatrix2D solvePosterior(DoubleMatrix2D phi, DoubleMatrix1D alpha, DoubleMatrix1D xi, DoubleMatrix1D signedTargets) {
            int m = alpha.size();
            DoubleMatrix2D precision = mf2d.diagonal(alpha);
            DoubleMatrix2D outer = mf2d.make(m, m);
            for (int i = 0; i < phi.rows(); ++i) {
                DoubleMatrix1D row = phi.viewRow(i);
                alg.multOuter(row, row, outer);
                Transform.plusMult(precision, outer, 2.0 * this.lambda(xi.get(i)));
            }
            DoubleMatrix2D sigma = alg.inverse(precision);
            this.weights = alg.mult(sigma, weightedRowSum(phi, signedTargets));
            Transform.mult(this.weights, 0.5);
            return sigma;
        }

        public GLMClassificationModel freezeModel() {
            return new GLMClassificationModel(this.workingSet, this.weights);
        }

        /** Copies {@code doubleMatrix1D} into a new n x 1 matrix. Currently unused. */
        DoubleMatrix2D copy2D(DoubleMatrix1D doubleMatrix1D) {
            int n = doubleMatrix1D.size();
            DoubleMatrix2D doubleMatrix2D = mf2d.make(n, 1);
            for (int i = 0; i < n; ++i) {
                doubleMatrix2D.setQuick(i, 0, doubleMatrix1D.getQuick(i));
            }
            return doubleMatrix2D;
        }

        /** Views a single-row or single-column matrix as a vector. Currently unused. */
        DoubleMatrix1D view1D(DoubleMatrix2D doubleMatrix2D) {
            if (doubleMatrix2D.rows() == 1) {
                return doubleMatrix2D.viewRow(0);
            }
            if (doubleMatrix2D.columns() == 1) {
                return doubleMatrix2D.viewColumn(0);
            }
            throw new RuntimeException("Matrix must be a vector.");
        }

        /** Logistic sigmoid 1 / (1 + e^-d), despite the name. Currently unused. */
        private double logit(double d) {
            return 1.0 / (1.0 + Math.exp(-d));
        }

        /** Log-odds log(d / (1 - d)), i.e. the inverse of {@link #logit}. Currently unused. */
        private double invlogit(double d) {
            return Math.log(d / (1.0 - d));
        }

        /** Derivative of the logistic sigmoid, e^-d / (1 + e^-d)^2. Currently unused. */
        private double diflogit(double d) {
            double d2 = Math.exp(-d);
            return d2 / Math.pow(1.0 + d2, 2.0);
        }

        /**
         * Jaakkola-Jordan coefficient lambda(xi) = tanh(xi/2) / (4 xi).
         * At xi == 0 the formula is 0/0; the limit 1/8 is returned instead of NaN.
         */
        private double lambda(double d) {
            if (d == 0.0) {
                return 0.125;
            }
            return 0.25 / d * Sfun.tanh(d / 2.0);
        }

        /** Not supported by this trainer. */
        public double getDeviation() {
            throw new UnsupportedOperationException();
        }

        public int getCurrentCycle() {
            return this.cycle;
        }

        /** Returns the live working-set list (not a copy). */
        public List getBasisList() {
            return this.workingSet;
        }

        /** Returns the current weight of {@code basisFunction}, or 0.0 if it is not in the working set. */
        public double getWeightForBasis(BasisFunction basisFunction) {
            int n = this.workingSet.indexOf(basisFunction);
            if (n < 0) {
                return 0.0;
            }
            return this.weights.get(n);
        }

        public SVMTarget getTarget() {
            return this.starget;
        }
    }

    /**
     * Per-call state for {@link #trainRegression}.
     * Uses the trainer's working-set bounds directly; it does not use {@code initialBasisSet},
     * the unity hack or the reset-alpha hack.
     */
    private class VRVMTrainContext
    implements SLMTrainingContext {
        private int cycle = 0;
        private final SVMTarget starget;
        private final BasisSource basisSource;
        private final TrainingListener listener;
        private final TrainingEvent tevent;
        private List<BasisFunction> workingSet;
        private DoubleMatrix1D weights;
        private DoubleMatrix2D sigma;

        VRVMTrainContext(SVMTarget sVMTarget, BasisSource basisSource, TrainingListener trainingListener) {
            this.starget = sVMTarget;
            this.basisSource = basisSource;
            this.listener = trainingListener;
            this.tevent = new TrainingEvent(this);
        }

        /**
         * Runs the training loop.
         * Returns a model built from the final working set, weights, weight covariance and noise precision.
         */
        GLMRegressionModel train() {
            HashMap<BasisFunction, Integer> ages = new HashMap<BasisFunction, Integer>();
            List<Object> items = new ArrayList<Object>();
            DoubleMatrix1D targets = readTargets(this.starget, items);
            int n = targets.size();

            this.workingSet = new ArrayList<BasisFunction>();
            while (this.basisSource.hasNext(this) && this.workingSet.size() < maxWorkingSet) {
                BasisFunction basis = this.basisSource.next(this);
                ages.put(basis, Integer.valueOf(0));
                this.workingSet.add(basis);
            }
            DoubleMatrix2D phi = mf2d.make(n, this.workingSet.size());
            fillColumns(phi, this.workingSet, items, 0);
            DoubleMatrix1D alpha = mf1d.make(this.workingSet.size(), initialAlpha);
            double beta = 0.1;      // current estimate of the noise precision
            double dPrior = 1.0E-6; // Gamma prior (c, d) on beta
            double cPrior = 1.0E-6;
            double bPrior = 1.0E-6; // Gamma prior (a, b) on each alpha
            double aPrior = 1.0E-6;

            boolean workingSetChanged = false;
            while (this.cycle < maxCycles) {
                workingSetChanged = false;
                int m = alpha.size();
                DoubleMatrix1D phiTt = weightedRowSum(phi, targets);
                solvePosterior(phi, alpha, beta, phiTt);
                // <w w^T> = mu mu^T + Sigma (Transform.plus mutates only the fresh outer product).
                DoubleMatrix2D wwT = Transform.plus(alg.multOuter(this.weights, this.weights, null), this.sigma);

                double aTilde = aPrior + 0.5;
                // Was (double)((n + 1) / 2): integer division silently dropped 0.5 for even n.
                // NOTE(review): the likelihood alone contributes n/2; confirm the +1 is intended.
                double cTilde = cPrior + (n + 1) / 2.0;
                double dTilde = dPrior;
                for (int i = 0; i < n; ++i) {
                    double t = targets.get(i);
                    dTilde += t * t / 2.0;
                }
                dTilde -= alg.mult(this.weights, phiTt);
                for (int i = 0; i < n; ++i) {
                    DoubleMatrix1D row = phi.viewRow(i);
                    dTilde += alg.mult(alg.mult(wwT, row), row) / 2.0;
                }

                // NOTE(review): textbook VRVM uses <w_j^2> = mu_j^2 + Sigma_jj; this uses mu_j^2 only.
                double maxAlpha = 0.0;
                IntArrayList keep = new IntArrayList();
                for (int j = 0; j < m; ++j) {
                    double w = this.weights.get(j);
                    double a = aTilde / (bPrior + w * w / 2.0);
                    maxAlpha = Math.max(maxAlpha, a);
                    alpha.set(j, a);
                    if (a < REGRESSION_PRUNE_ALPHA) {
                        keep.add(j);
                    } else {
                        ages.remove(this.workingSet.get(j));
                    }
                }
                beta = cTilde / dTilde;
                System.out.println("maxAlpha = " + maxAlpha);

                if (keep.size() < m) {
                    alpha = selectEntries(alpha, keep);
                    phi = selectColumns(phi, keep);
                    this.workingSet = selectElements(this.workingSet, keep);
                    this.weights = selectEntries(this.weights, keep);
                    workingSetChanged = true;
                }

                if (this.workingSet.size() < minWorkingSet && this.basisSource.hasNext(this) && this.cycle < maxCycles - cleaningCycles) {
                    int oldSize = this.workingSet.size();
                    while (this.workingSet.size() < maxWorkingSet && this.basisSource.hasNext(this)) {
                        BasisFunction basis = this.basisSource.next(this);
                        ages.put(basis, Integer.valueOf(this.cycle + 1));
                        this.workingSet.add(basis);
                    }
                    int newSize = this.workingSet.size();
                    alpha = extend(alpha, newSize, initialAlpha);
                    phi = widenColumns(phi, newSize);
                    fillColumns(phi, this.workingSet, items, oldSize);
                    // New bases have not been fitted yet; report a zero weight until the next solve.
                    this.weights = extend(this.weights, newSize, 0.0);
                    workingSetChanged |= newSize > oldSize;
                }
                ++this.cycle;
                this.listener.trainingCycleComplete(this.tevent);
            }
            if (workingSetChanged) {
                // The last cycle pruned/added bases after solving; re-solve so weights and
                // Sigma correspond to the final working set.
                solvePosterior(phi, alpha, beta, weightedRowSum(phi, targets));
            }
            System.out.println("\n\nBasis function ages:");
            for (Integer age : ages.values()) {
                System.out.println(age.toString());
            }
            this.listener.trainingComplete(this.tevent);
            return new GLMRegressionModel(this.workingSet, this.weights, this.sigma, beta);
        }

        /**
         * Sets {@code sigma} = (beta * sum_n phi_n phi_n^T + diag(alpha))^-1.
         * Sets {@code weights} = beta * Sigma * phiTt.
         */
        private void solvePosterior(DoubleMatrix2D phi, DoubleMatrix1D alpha, double beta, DoubleMatrix1D phiTt) {
            int m = alpha.size();
            DoubleMatrix2D precision = mf2d.make(m, m);
            DoubleMatrix2D outer = mf2d.make(m, m);
            for (int i = 0; i < phi.rows(); ++i) {
                DoubleMatrix1D row = phi.viewRow(i);
                alg.multOuter(row, row, outer);
                Transform.plus(precision, outer);
            }
            Transform.mult(precision, beta);
            Transform.plus(precision, mf2d.diagonal(alpha));
            this.sigma = alg.inverse(precision);
            this.weights = alg.mult(this.sigma, phiTt);
            Transform.mult(this.weights, beta);
        }

        public int getCurrentCycle() {
            return this.cycle;
        }

        /** Returns the live working-set list (not a copy). */
        public List getBasisList() {
            return this.workingSet;
        }

        /** Returns the current weight of {@code basisFunction}, or 0.0 if it is not in the working set. */
        public double getWeightForBasis(BasisFunction basisFunction) {
            int n = this.workingSet.indexOf(basisFunction);
            if (n < 0) {
                return 0.0;
            }
            return this.weights.get(n);
        }

        public SVMTarget getTarget() {
            return this.starget;
        }

        /** Not supported by this trainer. */
        public double getDeviation() {
            throw new UnsupportedOperationException();
        }

        /** Not supported: this context produces a regression model. */
        public GLMClassificationModel freezeModel() {
            throw new UnsupportedOperationException();
        }
    }
}

