001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016import java.util.List;
017
018/**
019 * Class to run over a single dataset with NumPy broadcasting to promote shapes
020 * which have lower rank and outputs to a second dataset
021 */
022public class SingleInputBroadcastIterator extends IndexIterator {
023        private int[] maxShape;
024        private int[] aShape;
025        private final Dataset aDataset;
026        private final Dataset oDataset;
027        private int[] aStride;
028        private int[] oStride;
029
030        final private int endrank;
031
032        /**
033         * position in dataset
034         */
035        private final int[] pos;
036        private final int[] aDelta;
037        private final int[] oDelta; // this being non-null means output is different from inputs
038        private final int aStep, oStep;
039        private int aMax;
040        private int aStart, oStart;
041        private final boolean outputA;
042
043        /**
044         * Index in array
045         */
046        public int aIndex, oIndex;
047
048        /**
049         * Current value in array
050         */
051        public double aDouble;
052
053        /**
054         * Current value in array
055         */
056        public long aLong;
057
058        private boolean asDouble = true;
059
060        /**
061         * @param a
062         * @param o (can be null for new dataset, or a)
063         */
064        public SingleInputBroadcastIterator(Dataset a, Dataset o) {
065                this(a, o, false);
066        }
067
068        /**
069         * @param a
070         * @param o (can be null for new dataset, or a)
071         * @param createIfNull (by default, can create float or complex datasets)
072         */
073        public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull) {
074                this(a, o, createIfNull, false, true);
075        }
076
077        /**
078         * @param a
079         * @param o (can be null for new dataset, or a)
080         * @param createIfNull
081         * @param allowInteger if true, can create integer datasets
082         * @param allowComplex if true, can create complex datasets
083         */
084        @SuppressWarnings("deprecation")
085        public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) {
086                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef());
087
088                checkItemSize(a, o);
089
090                maxShape = fullShapes.remove(0);
091
092                oStride = null;
093                if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) {
094                        throw new IllegalArgumentException("Output does not match broadcasted shape");
095                }
096                aShape = fullShapes.remove(0);
097
098                int rank = maxShape.length;
099                endrank = rank - 1;
100
101                aDataset = a.reshape(aShape);
102                aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape);
103                outputA = o == a;
104                if (outputA) {
105                        oStride = aStride;
106                        oDelta = null;
107                        oStep = 0;
108                        oDataset = aDataset;
109                } else if (o != null) {
110                        oStride = BroadcastUtils.createBroadcastStrides(o, maxShape);
111                        oDelta = new int[rank];
112                        oStep = o.getElementsPerItem();
113                        oDataset = o;
114                } else if (createIfNull) {
115                        int is = aDataset.getElementsPerItem();
116                        int dt = aDataset.getDType();
117                        if (aDataset.isComplex() && !allowComplex) {
118                                is = 1;
119                                dt = DTypeUtils.getBestFloatDType(dt);
120                        } else if (!aDataset.hasFloatingPointElements() && !allowInteger) {
121                                dt = DTypeUtils.getBestFloatDType(dt);
122                        }
123                        oDataset = DatasetFactory.zeros(is, maxShape, dt);
124                        oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape);
125                        oDelta = new int[rank];
126                        oStep = oDataset.getElementsPerItem();
127                } else {
128                        oDelta = null;
129                        oStep = 0;
130                        oDataset = o;
131                }
132
133                pos = new int[rank];
134                aDelta = new int[rank];
135                aStep = aDataset.getElementsPerItem();
136                for (int j = endrank; j >= 0; j--) {
137                        aDelta[j] = aStride[j] * aShape[j];
138                        if (oDelta != null) {
139                                oDelta[j] = oStride[j] * maxShape[j];
140                        }
141                }
142                if (endrank < 0) {
143                        aMax = aStep;
144                } else {
145                        aMax = Integer.MIN_VALUE; // use max delta
146                        for (int j = endrank; j >= 0; j--) {
147                                if (aDelta[j] > aMax) {
148                                        aMax = aDelta[j];
149                                }
150                        }
151                }
152                aStart = aDataset.getOffset();
153                aMax += aStart;
154                oStart = oDelta == null ? 0 : oDataset.getOffset();
155                asDouble = aDataset.hasFloatingPointElements();
156                reset();
157        }
158
159        /**
160         * @return true if output from iterator is double
161         */
162        public boolean isOutputDouble() {
163                return asDouble;
164        }
165
166        /**
167         * Set to output doubles
168         * @param asDouble
169         */
170        public void setOutputDouble(boolean asDouble) {
171                if (this.asDouble != asDouble) {
172                        this.asDouble = asDouble;
173                        storeCurrentValues();
174                }
175        }
176
177        private static void checkItemSize(Dataset a, Dataset o) {
178                final int isa = a.getElementsPerItem();
179                if (o != null) {
180                        final int iso = o.getElementsPerItem();
181                        if (isa != 1 && iso != isa) {
182                                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
183                        }
184                }
185        }
186
187        @Override
188        public int[] getShape() {
189                return maxShape;
190        }
191
192        @Override
193        public boolean hasNext() {
194                int j = endrank;
195                int oldA = aIndex;
196                for (; j >= 0; j--) {
197                        pos[j]++;
198                        aIndex += aStride[j];
199                        if (oDelta != null)
200                                oIndex += oStride[j];
201                        if (pos[j] >= maxShape[j]) {
202                                pos[j] = 0;
203                                aIndex -= aDelta[j]; // reset these dimensions
204                                if (oDelta != null)
205                                        oIndex -= oDelta[j];
206                        } else {
207                                break;
208                        }
209                }
210                if (j == -1) {
211                        if (endrank >= 0) {
212                                aIndex = aMax;
213                                return false;
214                        }
215                        aIndex += aStep;
216                        if (oDelta != null)
217                                oIndex += oStep;
218                }
219                if (outputA) {
220                        oIndex = aIndex;
221                }
222
223                if (aIndex == aMax)
224                        return false;
225
226                if (oldA != aIndex) {
227                        if (asDouble) {
228                                aDouble = aDataset.getElementDoubleAbs(aIndex);
229                        } else {
230                                aLong = aDataset.getElementLongAbs(aIndex);
231                        }
232                }
233
234                return true;
235        }
236
237        /**
238         * @return output dataset (can be null)
239         */
240        public Dataset getOutput() {
241                return oDataset;
242        }
243
244        @Override
245        public int[] getPos() {
246                return pos;
247        }
248
249        @Override
250        public void reset() {
251                for (int i = 0; i <= endrank; i++)
252                        pos[i] = 0;
253
254                if (endrank >= 0) {
255                        pos[endrank] = -1;
256                        aIndex = aStart - aStride[endrank];
257                        oIndex = oStart - (oStride == null ? 0 : oStride[endrank]);
258                } else {
259                        aIndex = -aStep;
260                        oIndex = -oStep;
261                }
262
263                // for zero-ranked datasets
264                if (aIndex == 0) {
265                        storeCurrentValues();
266                        if (aMax == aIndex)
267                                aMax++;
268                }
269        }
270
271        private void storeCurrentValues() {
272                if (aIndex >= 0) {
273                        if (asDouble) {
274                                aDouble = aDataset.getElementDoubleAbs(aIndex);
275                        } else {
276                                aLong = aDataset.getElementLongAbs(aIndex);
277                        }
278                }
279        }
280}