Line data Source code
1 : //# CASA - Common Astronomy Software Applications (http://casa.nrao.edu/)
2 : //# Copyright (C) Associated Universities, Inc. Washington DC, USA 2011, All
3 : //# rights reserved.
4 : //# Copyright (C) European Southern Observatory, 2011, All rights reserved.
5 : //#
6 : //# This library is free software; you can redistribute it and/or
7 : //# modify it under the terms of the GNU Lesser General Public
8 : //# License as published by the Free software Foundation; either
9 : //# version 2.1 of the License, or (at your option) any later version.
10 : //#
11 : //# This library is distributed in the hope that it will be useful,
12 : //# but WITHOUT ANY WARRANTY, without even the implied warranty of
13 : //# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 : //# Lesser General Public License for more details.
15 : //#
16 : //# You should have received a copy of the GNU Lesser General Public
17 : //# License along with this library; if not, write to the Free Software
18 : //# Foundation, Inc., 59 Temple Place, Suite 330, Boston,
19 : //# MA 02111-1307 USA
20 :
21 : #include <mstransform/TVI/StatWtClassicalDataAggregator.h>
22 :
23 : #include <casacore/casa/Arrays/Cube.h>
24 : #include <casacore/scimath/StatsFramework/ClassicalStatistics.h>
25 :
26 : #ifdef _OPENMP
27 : #include <omp.h>
28 : #endif
29 :
30 : using namespace casacore;
31 : using namespace std;
32 :
33 : namespace casa {
34 :
35 : namespace vi {
36 :
37 0 : StatWtClassicalDataAggregator::StatWtClassicalDataAggregator(
38 : ViImplementation2 *const vii,
39 : // shared_ptr<Bool>& mustComputeWtSp,
40 : const map<Int, vector<StatWtTypes::ChanBin>>& chanBins,
41 : std::shared_ptr<map<uInt, pair<uInt, uInt>>>& samples,
42 : StatWtTypes::Column column, Bool noModel,
43 : const map<uInt, Cube<Bool>>& chanSelFlags,
44 : shared_ptr<
45 : ClassicalStatistics<
46 : Double, Array<Float>::const_iterator,
47 : Array<Bool>::const_iterator
48 : >
49 : >& wtStats,
50 : shared_ptr<const pair<Double, Double>> wtrange, Bool combineCorr,
51 : shared_ptr<
52 : StatisticsAlgorithm<
53 : Double, Array<Float>::const_iterator, Array<Bool>::const_iterator,
54 : Array<Double>::const_iterator
55 : >
56 : >& statAlg, Int minSamp
57 0 : ) : StatWtDataAggregator(
58 : vii, chanBins, samples, column, noModel, chanSelFlags, /* mustComputeWtSp,*/
59 : wtStats, wtrange, combineCorr, statAlg, minSamp
60 0 : ) {}
61 :
62 0 : StatWtClassicalDataAggregator::~StatWtClassicalDataAggregator() {}
63 :
64 0 : void StatWtClassicalDataAggregator::aggregate() {
65 : // Drive NEXT LOWER layer's ViImpl to gather data into allvis:
66 : // Assumes all sub-chunks in the current chunk are to be used
67 : // for the variance calculation
68 : // Essentially, we are sorting the incoming data into
69 : // allvis, to enable a convenient variance calculation
70 0 : _variances.clear();
71 0 : auto* vb = _vii->getVisBuffer();
72 0 : std::map<StatWtTypes::BaselineChanBin, Cube<Complex>> data;
73 0 : std::map<StatWtTypes::BaselineChanBin, Cube<Bool>> flags;
74 0 : std::map<StatWtTypes::BaselineChanBin, Vector<Double>> exposures;
75 0 : IPosition blc(3, 0);
76 0 : auto trc = blc;
77 0 : auto initChanSelTemplate = True;
78 0 : Cube<Bool> chanSelFlagTemplate, chanSelFlags;
79 0 : auto firstTime = True;
80 : // we cannot know the spw until we are in the subchunks loop
81 0 : Int spw = -1;
82 0 : for (_vii->origin(); _vii->more(); _vii->next()) {
83 0 : if (_checkFirstSubChunk(spw, firstTime, vb)) {
84 0 : return;
85 : }
86 0 : if (! _mustComputeWtSp) {
87 0 : _mustComputeWtSp.reset(
88 : new Bool(
89 0 : vb->existsColumn(VisBufferComponent2::WeightSpectrum)
90 0 : )
91 : );
92 : }
93 0 : const auto& ant1 = vb->antenna1();
94 0 : const auto& ant2 = vb->antenna2();
95 : // [nCorr, nFreq, nRows)
96 0 : const auto& dataCube = _dataCube(vb);
97 0 : const auto& flagCube = vb->flagCube();
98 0 : const auto dataShape = dataCube.shape();
99 0 : const auto& exposureVector = vb->exposure();
100 0 : const auto nrows = vb->nRows();
101 0 : const auto npol = dataCube.nrow();
102 : const auto resultantFlags = _getResultantFlags(
103 : chanSelFlagTemplate, chanSelFlags, initChanSelTemplate,
104 : spw, flagCube
105 0 : );
106 0 : auto bins = _chanBins.find(spw)->second;
107 0 : StatWtTypes::BaselineChanBin blcb;
108 0 : blcb.spw = spw;
109 0 : IPosition dataCubeBLC(3, 0);
110 0 : auto dataCubeTRC = dataCube.shape() - 1;
111 0 : for (rownr_t row=0; row<nrows; ++row) {
112 0 : dataCubeBLC[2] = row;
113 0 : dataCubeTRC[2] = row;
114 0 : blcb.baseline = _baseline(ant1[row], ant2[row]);
115 0 : auto citer = bins.cbegin();
116 0 : auto cend = bins.cend();
117 0 : for (; citer!=cend; ++citer) {
118 0 : dataCubeBLC[1] = citer->start;
119 0 : dataCubeTRC[1] = citer->end;
120 0 : blcb.chanBin.start = citer->start;
121 0 : blcb.chanBin.end = citer->end;
122 0 : auto dataSlice = dataCube(dataCubeBLC, dataCubeTRC);
123 0 : auto flagSlice = resultantFlags(dataCubeBLC, dataCubeTRC);
124 0 : if (data.find(blcb) == data.end()) {
125 0 : data[blcb] = dataSlice;
126 0 : flags[blcb] = flagSlice;
127 0 : exposures[blcb] = Vector<Double>(1, exposureVector[row]);
128 : }
129 : else {
130 0 : auto myshape = data[blcb].shape();
131 0 : auto nplane = myshape[2];
132 0 : auto nchan = myshape[1];
133 0 : data[blcb].resize(npol, nchan, nplane+1, True);
134 0 : flags[blcb].resize(npol, nchan, nplane+1, True);
135 0 : exposures[blcb].resize(nplane+1, True);
136 0 : trc = myshape - 1;
137 : // because we've extended the cube by one plane since
138 : // myshape was determined.
139 0 : ++trc[2];
140 0 : blc[2] = trc[2];
141 0 : data[blcb](blc, trc) = dataSlice;
142 0 : flags[blcb](blc, trc) = flagSlice;
143 0 : exposures[blcb][trc[2]] = exposureVector[row];
144 : }
145 : }
146 : }
147 : }
148 0 : _computeVariances(data, flags, exposures);
149 : }
150 :
151 0 : void StatWtClassicalDataAggregator::weightSingleChanBin(
152 : Matrix<Float>& wtmat, Int nrows
153 : ) const {
154 0 : Vector<Int> ant1, ant2, spws;
155 0 : Vector<Double> exposures;
156 0 : _vii->antenna1(ant1);
157 0 : _vii->antenna2(ant2);
158 0 : _vii->spectralWindows(spws);
159 0 : _vii->exposure(exposures);
160 : // There is only one spw in a chunk
161 0 : auto spw = *spws.begin();
162 0 : StatWtTypes::BaselineChanBin blcb;
163 0 : blcb.spw = spw;
164 0 : for (Int i=0; i<nrows; ++i) {
165 0 : auto bins = _chanBins.find(spw)->second;
166 0 : blcb.baseline = _baseline(ant1[i], ant2[i]);
167 0 : blcb.chanBin = bins[0];
168 0 : auto variances = _variances.find(blcb)->second;
169 0 : if (_combineCorr) {
170 0 : wtmat.column(i) = exposures[i]/variances[0];
171 : }
172 : else {
173 0 : auto corr = 0;
174 0 : for (const auto variance: variances) {
175 0 : wtmat(corr, i) = exposures[i]/variance;
176 0 : ++corr;
177 : }
178 : }
179 : }
180 0 : }
181 :
182 0 : void StatWtClassicalDataAggregator::_computeVariances(
183 : const map<StatWtTypes::BaselineChanBin, Cube<Complex>>& data,
184 : const map<StatWtTypes::BaselineChanBin, Cube<Bool>>& flags,
185 : const map<StatWtTypes::BaselineChanBin, Vector<Double>>& exposures
186 : ) const {
187 0 : auto diter = data.cbegin();
188 0 : auto dend = data.cend();
189 0 : const auto nActCorr = diter->second.shape()[0];
190 0 : const auto ncorr = _combineCorr ? 1 : nActCorr;
191 : // spw will be the same for all members
192 0 : const auto& spw = data.begin()->first.spw;
193 0 : vector<StatWtTypes::BaselineChanBin> keys(data.size());
194 0 : auto idx = 0;
195 0 : for (; diter!=dend; ++diter, ++idx) {
196 0 : const auto& blcb = diter->first;
197 0 : keys[idx] = blcb;
198 0 : _variances[blcb].resize(ncorr);
199 : }
200 0 : auto n = keys.size();
201 : #ifdef _OPENMP
202 0 : #pragma omp parallel for
203 : // cout << "WARN OMP PARALLEL LOOPING IS OFF FOR DEBUGGING" << endl;
204 : #endif
205 : for (size_t i=0; i<n; ++i) {
206 : auto blcb = keys[i];
207 : auto dataForBLCB = data.find(blcb)->second;
208 : auto flagsForBLCB = flags.find(blcb)->second;
209 : auto exposuresForBLCB = exposures.find(blcb)->second;
210 : for (ssize_t corr=0; corr<ncorr; ++corr) {
211 : IPosition start(3, 0);
212 : auto end = dataForBLCB.shape() - 1;
213 : if (! _combineCorr) {
214 : start[0] = corr;
215 : end[0] = corr;
216 : }
217 : Slicer slice(start, end, Slicer::endIsLast);
218 : _variances[blcb][corr]
219 : = _varianceComputer->computeVariance(
220 : dataForBLCB(slice), flagsForBLCB(slice),
221 : exposuresForBLCB, spw
222 : );
223 : }
224 : }
225 0 : }
226 :
227 0 : void StatWtClassicalDataAggregator::weightSpectrumFlags(
228 : Cube<Float>& wtsp, Cube<Bool>& flagCube, Bool& checkFlags,
229 : const Vector<Int>& ant1, const Vector<Int>& ant2, const Vector<Int>& spws,
230 : const Vector<Double>& exposures, const Vector<rownr_t>&
231 : ) const {
232 0 : Slicer slice(IPosition(3, 0), flagCube.shape(), Slicer::endIsLength);
233 0 : auto sliceStart = slice.start();
234 0 : auto sliceEnd = slice.end();
235 0 : auto nrows = ant1.size();
236 0 : for (size_t i=0; i<nrows; ++i) {
237 0 : sliceStart[2] = i;
238 0 : sliceEnd[2] = i;
239 0 : StatWtTypes::BaselineChanBin blcb;
240 0 : blcb.baseline = _baseline(ant1[i], ant2[i]);
241 0 : auto spw = spws[i];
242 0 : blcb.spw = spw;
243 0 : auto bins = _chanBins.find(spw)->second;
244 0 : for (const auto& bin: bins) {
245 0 : sliceStart[1] = bin.start;
246 0 : sliceEnd[1] = bin.end;
247 0 : blcb.chanBin = bin;
248 0 : auto variances = _variances.find(blcb)->second;
249 0 : auto ncorr = variances.size();
250 0 : Vector<Double> weights(ncorr);
251 0 : for (size_t corr=0; corr<ncorr; ++corr) {
252 0 : if (! _combineCorr) {
253 0 : sliceStart[0] = corr;
254 0 : sliceEnd[0] = corr;
255 : }
256 0 : weights[corr] = variances[corr] == 0
257 0 : ? 0 : exposures[i]/variances[corr];
258 0 : slice.setStart(sliceStart);
259 0 : slice.setEnd(sliceEnd);
260 0 : _updateWtSpFlags(
261 0 : wtsp, flagCube, checkFlags, slice, weights[corr]
262 : );
263 : }
264 : }
265 : }
266 0 : }
267 :
268 : }
269 :
270 : }
|