001package org.opengion.penguin.math.statistics;
002
003import java.util.Arrays;
004
005import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
006
007/**
008 * apache.commons.mathを利用したOLS重回帰計算のクラスです。
009 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。
010 * c0の切片を考慮するかどうかはnoInterceptで決めます。
011 * 
012 */
013public class HybsMultiRegression implements HybsRegression {
014        private double cnst[];                  // 各係数(xの種類+1になる?)
015        private double rsquare;                 // 決定係数
016        private boolean noIntercept;    //切片を利用するかどうか
017
018        /**
019         * コンストラクタ。
020         * 与えた二次元データを元に重回帰を計算します。
021         * xデータとして二次元配列を与えます。
022         * noInterceptで切片有り無しを選択します。
023         * 
024         * @param in_x 説明変数
025         * @param in_y 目的変数
026         * @param noIntercept 切片利用有無(trueで利用しない)
027         */
028        public HybsMultiRegression( final double[][] in_x, final double[] in_y, final boolean noIntercept ) {
029                train( in_x, in_y, noIntercept );
030        }
031
032        /**
033         * 与えた二次元データを元に重回帰を計算します。
034         * xデータとして二次元配列を与えます。
035         * noInterceptで切片有り無しを選択します。
036         * 
037         * @param in_x 説明変数
038         * @param in_y 目的変数
039         * @param noIntercept 切片利用有無(trueで利用しない)
040         */
041        private void train( final double[][] in_x, final double[] in_y, final boolean noIntercept ) {
042                this.noIntercept = noIntercept;
043
044                // ここで重回帰計算
045                final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
046                regression.setNoIntercept(noIntercept);
047        regression.newSampleData(in_y, in_x);
048
049                cnst    = regression.estimateRegressionParameters();
050                rsquare = regression.calculateRSquared();
051        }
052
053        /**
054         * 係数をセットした配列を返します。
055         *
056         * @return 係数の配列
057         */
058        @Override
059        public double[] getCoefficient() {
060                return Arrays.copyOf( cnst,cnst.length );
061        }
062
063        /**
064         * 決定係数の取得。
065         * @return 決定係数
066         */
067        @Override
068        public double getRSquare() {
069                return rsquare;
070        }
071
072        /**
073         * 計算( c0 + c1x1...)を行う。
074         * noInterceptによってc0の利用を決める。
075         * xの大きさが足りない場合は0を返す。
076         * 
077         * @param in_x 必要な大きさの変数配列
078         * @return 計算結果
079         */
080        @Override
081        public double predict( final double... in_x ) {
082                double rtn = 0;
083                final int itr = noIntercept ? 0 : 1;
084                if( in_x.length < cnst.length-itr ) {
085                        return rtn;
086                }
087
088                for( int i=0; i < in_x.length; i++ ) {
089                        rtn = rtn + in_x[i] * cnst[i+itr];
090                }
091                if( !noIntercept ) { rtn = rtn + cnst[0]; }
092
093                return rtn;
094        }
095
096        //************** ここまでが本体 **************
097        /**
098         * ここからテスト用mainメソッド 。
099         *
100         * @param args 引数
101         */
102        public static void main( final String[] args ) {
103                // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより
104                final double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 };
105                double[][] x = new double[10][];
106                x[0] = new double[] { 165, 65 };
107                x[1] = new double[] { 170, 68 };
108                x[2] = new double[] { 172, 70 };
109                x[3] = new double[] { 175, 65 };
110                x[4] = new double[] { 170, 80 };
111                x[5] = new double[] { 172, 85 };
112                x[6] = new double[] { 183, 78 };
113                x[7] = new double[] { 187, 79 };
114                x[8] = new double[] { 180, 95 };
115                x[9] = new double[] { 185, 97 };
116
117                final HybsMultiRegression mr = new HybsMultiRegression(x,y,true);
118
119                System.out.println( mr.getRSquare() );
120                System.out.println( Arrays.toString( mr.getCoefficient()) );
121
122                System.out.println( mr.predict( new double[] { 169,85 } ));
123        }
124}
125