Line data Source code
1 : //# CASA - Common Astronomy Software Applications (http://casa.nrao.edu/)
2 : //# Copyright (C) Associated Universities, Inc. Washington DC, USA 2011, All
3 : //# rights reserved.
4 : //# Copyright (C) European Southern Observatory, 2011, All rights reserved.
5 : //#
6 : //# This library is free software; you can redistribute it and/or
7 : //# modify it under the terms of the GNU Lesser General Public
8 : //# License as published by the Free software Foundation; either
9 : //# version 2.1 of the License, or (at your option) any later version.
10 : //#
11 : //# This library is distributed in the hope that it will be useful,
12 : //# but WITHOUT ANY WARRANTY, without even the implied warranty of
13 : //# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 : //# Lesser General Public License for more details.
15 : //#
16 : //# You should have received a copy of the GNU Lesser General Public
17 : //# License along with this library; if not, write to the Free Software
18 : //# Foundation, Inc., 59 Temple Place, Suite 330, Boston,
19 : //# MA 02111-1307 USA
20 :
#include <mstransform/TVI/StatWtDataAggregator.h>

#include <utility>

#include <casacore/casa/Arrays/ArrayMath.h>
#include <casacore/casa/Arrays/Cube.h>
// debug
#include <casacore/casa/IO/ArrayIO.h>

#ifdef _OPENMP
#include <omp.h>
#endif
31 :
32 : using namespace casacore;
33 : using namespace std;
34 :
35 : namespace casa {
36 :
37 : namespace vi {
38 :
39 0 : StatWtDataAggregator::StatWtDataAggregator(
40 : ViImplementation2 *const vii,
41 : const map<Int, vector<StatWtTypes::ChanBin>>& chanBins,
42 : std::shared_ptr<map<uInt, pair<uInt, uInt>>>& samples,
43 : StatWtTypes::Column column, Bool noModel,
44 : const map<uInt, Cube<Bool>>& chanSelFlags,
45 : std::shared_ptr<
46 : casacore::ClassicalStatistics<casacore::Double,
47 : casacore::Array<casacore::Float>::const_iterator,
48 : casacore::Array<casacore::Bool>::const_iterator>
49 : >& wtStats,
50 : shared_ptr<const pair<Double, Double>> wtrange,
51 : Bool combineCorr,
52 : shared_ptr<
53 : StatisticsAlgorithm<
54 : Double, Array<Float>::const_iterator,
55 : Array<Bool>::const_iterator, Array<Double>::const_iterator
56 : >
57 : >& statAlg, Int minSamp
58 0 : ) : _vii(vii), _chanBins(chanBins), _samples(samples),
59 : _varianceComputer(
60 0 : new StatWtVarianceAndWeightCalculator(statAlg, samples, minSamp)
61 : ),
62 : _column(column),_noModel(noModel), _chanSelFlags(chanSelFlags),
63 0 : _wtStats(wtStats), _wtrange(wtrange), _combineCorr(combineCorr) {}
64 :
65 :
66 0 : StatWtDataAggregator::~StatWtDataAggregator() {}
67 :
68 0 : Bool StatWtDataAggregator::mustComputeWtSp() const {
69 0 : return *_mustComputeWtSp;
70 : }
71 :
72 0 : void StatWtDataAggregator::setMustComputeWtSp(
73 : std::shared_ptr<casacore::Bool> mcwp
74 : ) {
75 0 : _mustComputeWtSp = mcwp;
76 0 : }
77 :
78 0 : StatWtTypes::Baseline StatWtDataAggregator::_baseline(
79 : uInt ant1, uInt ant2
80 : ) {
81 0 : return StatWtTypes::Baseline(min(ant1, ant2), max(ant1, ant2));
82 : }
83 :
84 0 : Bool StatWtDataAggregator::_checkFirstSubChunk(
85 : Int& spw, Bool& firstTime, const VisBuffer2 * const vb
86 : ) const {
87 0 : if (! firstTime) {
88 : // this chunk has already been checked, it has not
89 : // been processed previously
90 0 : return False;
91 : }
92 0 : const auto& rowIDs = vb->rowIds();
93 0 : if (_processedRowIDs.find(rowIDs[0]) == _processedRowIDs.end()) {
94 : // haven't processed this chunk
95 0 : _processedRowIDs.insert(rowIDs[0]);
96 : // the spw is the same for all subchunks, so it only needs to
97 : // be set once
98 0 : spw = *vb->spectralWindows().begin();
99 0 : if (_samples->find(spw) == _samples->end()) {
100 0 : (*_samples)[spw].first = 0;
101 0 : (*_samples)[spw].second = 0;
102 : }
103 0 : firstTime = False;
104 0 : return False;
105 : }
106 : else {
107 : // this chunk has been processed, this can happen at the end
108 : // when the last chunk is processed twice
109 0 : return True;
110 : }
111 : }
112 :
113 0 : const Cube<Complex> StatWtDataAggregator::_dataCube(
114 : const VisBuffer2 *const vb
115 : ) const {
116 0 : switch (_column) {
117 0 : case StatWtTypes::CORRECTED:
118 0 : return vb->visCubeCorrected();
119 0 : case StatWtTypes::DATA:
120 0 : return vb->visCube();
121 0 : case StatWtTypes::RESIDUAL:
122 0 : if (_noModel) {
123 0 : return vb->visCubeCorrected();
124 : }
125 : else {
126 0 : return vb->visCubeCorrected() - vb->visCubeModel();
127 : }
128 0 : case StatWtTypes::RESIDUAL_DATA:
129 0 : if(_noModel) {
130 0 : return vb->visCube();
131 : }
132 : else {
133 0 : return vb->visCube() - vb->visCubeModel();
134 : }
135 0 : default:
136 0 : ThrowCc("Logic error: column type not handled");
137 : }
138 : }
139 :
140 0 : Cube<Bool> StatWtDataAggregator::_getResultantFlags(
141 : Cube<Bool>& chanSelFlagTemplate, Cube<Bool>& chanSelFlags,
142 : Bool& initTemplate, Int spw, const Cube<Bool>& flagCube
143 : ) const {
144 0 : if (_chanSelFlags.find(spw) == _chanSelFlags.cend()) {
145 : // no selection of channels to ignore
146 0 : return flagCube;
147 : }
148 0 : if (initTemplate) {
149 : // this can be done just once per chunk because all the rows
150 : // in the chunk are guaranteed to have the same spw
151 : // because each subchunk is guaranteed to have a single
152 : // data description ID.
153 0 : chanSelFlagTemplate = _chanSelFlags.find(spw)->second;
154 0 : initTemplate = False;
155 : }
156 0 : auto dataShape = flagCube.shape();
157 0 : chanSelFlags.resize(dataShape, False);
158 0 : auto ncorr = dataShape[0];
159 0 : auto nrows = dataShape[2];
160 0 : IPosition start(3, 0);
161 0 : IPosition end = dataShape - 1;
162 0 : Slicer sl(start, end, Slicer::endIsLast);
163 0 : for (uInt corr=0; corr<ncorr; ++corr) {
164 0 : start[0] = corr;
165 0 : end[0] = corr;
166 0 : for (Int row=0; row<nrows; ++row) {
167 0 : start[2] = row;
168 0 : end[2] = row;
169 0 : sl.setStart(start);
170 0 : sl.setEnd(end);
171 0 : chanSelFlags(sl) = chanSelFlagTemplate;
172 : }
173 : }
174 0 : return flagCube || chanSelFlags;
175 : }
176 :
177 0 : void StatWtDataAggregator::_updateWtSpFlags(
178 : Cube<Float>& wtsp, Cube<Bool>& flags, Bool& checkFlags,
179 : const Slicer& slice, Float wt
180 : ) const {
181 : // writable array reference
182 0 : auto flagSlice = flags(slice);
183 0 : if (*_mustComputeWtSp) {
184 : // writable array reference
185 0 : auto wtSlice = wtsp(slice);
186 0 : wtSlice = wt;
187 : // update global stats before we potentially flag data
188 0 : auto mask = ! flagSlice;
189 0 : _wtStats->addData(wtSlice.begin(), mask.begin(), wtSlice.size());
190 : }
191 0 : else if (! allTrue(flagSlice)) {
192 : // we don't need to compute WEIGHT_SPECTRUM, and the slice isn't
193 : // entirely flagged, so we need to update the WEIGHT column stats
194 0 : _wtStats->addData(Array<Float>(IPosition(1, 1), wt).begin(), 1);
195 : }
196 0 : if (
197 0 : wt == 0
198 0 : || (_wtrange && (wt < _wtrange->first || wt > _wtrange->second))
199 : ) {
200 0 : if (*_mustComputeWtSp) {
201 0 : wtsp(slice) = 0;
202 : }
203 0 : checkFlags = True;
204 0 : flagSlice = True;
205 : }
206 :
207 0 : }
208 :
209 : }
210 :
211 : }
|