LCOV - code coverage report
Current view: top level - mstransform/TVI - StatWtClassicalDataAggregator.cc (source / functions) Hit Total Coverage
Test: casa_coverage.info Lines: 134 135 99.3 %
Date: 2023-10-25 08:47:59 Functions: 7 7 100.0 %

          Line data    Source code
       1             : //#  CASA - Common Astronomy Software Applications (http://casa.nrao.edu/)
       2             : //#  Copyright (C) Associated Universities, Inc. Washington DC, USA 2011, All
       3             : //#  rights reserved.
       4             : //#  Copyright (C) European Southern Observatory, 2011, All rights reserved.
       5             : //#
       6             : //#  This library is free software; you can redistribute it and/or
       7             : //#  modify it under the terms of the GNU Lesser General Public
       8             : //#  License as published by the Free Software Foundation; either
       9             : //#  version 2.1 of the License, or (at your option) any later version.
      10             : //#
      11             : //#  This library is distributed in the hope that it will be useful,
      12             : //#  but WITHOUT ANY WARRANTY; without even the implied warranty of
      13             : //#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      14             : //#  Lesser General Public License for more details.
      15             : //#
      16             : //#  You should have received a copy of the GNU Lesser General Public
      17             : //#  License along with this library; if not, write to the Free Software
      18             : //#  Foundation, Inc., 59 Temple Place, Suite 330, Boston,
      19             : //#  MA 02111-1307  USA
      20             : 
      21             : #include <mstransform/TVI/StatWtClassicalDataAggregator.h>
      22             : 
      23             : #include <casacore/casa/Arrays/Cube.h>
      24             : #include <casacore/scimath/StatsFramework/ClassicalStatistics.h>
      25             : 
      26             : #ifdef _OPENMP
      27             : #include <omp.h>
      28             : #endif
      29             : 
      30             : using namespace casacore;
      31             : using namespace std;
      32             : 
      33             : namespace casa {
      34             : 
      35             : namespace vi {
      36             : 
      37          12 : StatWtClassicalDataAggregator::StatWtClassicalDataAggregator(
      38             :     ViImplementation2 *const vii,
      39             :     // shared_ptr<Bool>& mustComputeWtSp,
      40             :     const map<Int, vector<StatWtTypes::ChanBin>>& chanBins,
      41             :     std::shared_ptr<map<uInt, pair<uInt, uInt>>>& samples,
      42             :     StatWtTypes::Column column, Bool noModel,
      43             :     const map<uInt, Cube<Bool>>& chanSelFlags,
      44             :     shared_ptr<
      45             :         ClassicalStatistics<
      46             :             Double, Array<Float>::const_iterator,
      47             :             Array<Bool>::const_iterator
      48             :         >
      49             :     >& wtStats,
      50             :     shared_ptr<const pair<Double, Double>> wtrange, Bool combineCorr,
      51             :     shared_ptr<
      52             :         StatisticsAlgorithm<
      53             :             Double, Array<Float>::const_iterator, Array<Bool>::const_iterator,
      54             :             Array<Double>::const_iterator
      55             :         >
      56             :     >& statAlg, Int minSamp
      57          12 : ) : StatWtDataAggregator(
      58             :        vii, chanBins, samples, column, noModel, chanSelFlags, /* mustComputeWtSp,*/
      59             :        wtStats, wtrange, combineCorr, statAlg, minSamp
      60          12 :     ) {}
      61             : 
      62          24 : StatWtClassicalDataAggregator::~StatWtClassicalDataAggregator() {}
      63             : 
      64         123 : void StatWtClassicalDataAggregator::aggregate() {
      65             :     // Drive NEXT LOWER layer's ViImpl to gather data into allvis:
      66             :     // Assumes all sub-chunks in the current chunk are to be used
      67             :     // for the variance calculation
      68             :     // Essentially, we are sorting the incoming data into
      69             :     // allvis, to enable a convenient variance calculation
      70         123 :     _variances.clear();
      71         123 :     auto* vb = _vii->getVisBuffer();
      72         123 :     std::map<StatWtTypes::BaselineChanBin, Cube<Complex>> data;
      73         123 :     std::map<StatWtTypes::BaselineChanBin, Cube<Bool>> flags;
      74         123 :     std::map<StatWtTypes::BaselineChanBin, Vector<Double>> exposures;
      75         123 :     IPosition blc(3, 0);
      76         123 :     auto trc = blc;
      77         123 :     auto initChanSelTemplate = True;
      78         123 :     Cube<Bool> chanSelFlagTemplate, chanSelFlags;
      79         123 :     auto firstTime = True;
      80             :     // we cannot know the spw until we are in the subchunks loop
      81         123 :     Int spw = -1;
      82        1483 :     for (_vii->origin(); _vii->more(); _vii->next()) {
      83        1372 :         if (_checkFirstSubChunk(spw, firstTime, vb)) {
      84          12 :             return;
      85             :         }
      86        1360 :         if (! _mustComputeWtSp) {
      87          20 :             _mustComputeWtSp.reset(
      88             :                 new Bool(
      89          10 :                     vb->existsColumn(VisBufferComponent2::WeightSpectrum)
      90          10 :                 )
      91             :             );
      92             :         }
      93        1360 :         const auto& ant1 = vb->antenna1();
      94        1360 :         const auto& ant2 = vb->antenna2();
      95             :         // [nCorr, nFreq, nRows)
      96        2720 :         const auto& dataCube = _dataCube(vb);
      97        1360 :         const auto& flagCube = vb->flagCube();
      98        2720 :         const auto dataShape = dataCube.shape();
      99        1360 :         const auto& exposureVector = vb->exposure();
     100        1360 :         const auto nrows = vb->nRows();
     101        1360 :         const auto npol = dataCube.nrow();
     102             :         const auto resultantFlags = _getResultantFlags(
     103             :             chanSelFlagTemplate, chanSelFlags, initChanSelTemplate,
     104             :             spw, flagCube
     105        2720 :         );
     106        2720 :         auto bins = _chanBins.find(spw)->second;
     107        1360 :         StatWtTypes::BaselineChanBin blcb;
     108        1360 :         blcb.spw = spw;
     109        2720 :         IPosition dataCubeBLC(3, 0);
     110        2720 :         auto dataCubeTRC = dataCube.shape() - 1;
     111       35956 :         for (rownr_t row=0; row<nrows; ++row) {
     112       34596 :             dataCubeBLC[2] = row;
     113       34596 :             dataCubeTRC[2] = row;
     114       34596 :             blcb.baseline = _baseline(ant1[row], ant2[row]);
     115       34596 :             auto citer = bins.cbegin();
     116       34596 :             auto cend = bins.cend();
     117      113292 :             for (; citer!=cend; ++citer) {
     118       78696 :                 dataCubeBLC[1] = citer->start;
     119       78696 :                 dataCubeTRC[1] = citer->end;
     120       78696 :                 blcb.chanBin.start = citer->start;
     121       78696 :                 blcb.chanBin.end = citer->end;
     122      157392 :                 auto dataSlice = dataCube(dataCubeBLC, dataCubeTRC);
     123      157392 :                 auto flagSlice = resultantFlags(dataCubeBLC, dataCubeTRC);
     124       78696 :                 if (data.find(blcb) == data.end()) {
     125        4758 :                     data[blcb] = dataSlice;
     126        4758 :                     flags[blcb] = flagSlice;
     127        4758 :                     exposures[blcb] = Vector<Double>(1, exposureVector[row]);
     128             :                 }
     129             :                 else {
     130       73938 :                     auto myshape = data[blcb].shape();
     131       73938 :                     auto nplane = myshape[2];
     132       73938 :                     auto nchan = myshape[1];
     133       73938 :                     data[blcb].resize(npol, nchan, nplane+1, True);
     134       73938 :                     flags[blcb].resize(npol, nchan, nplane+1, True);
     135       73938 :                     exposures[blcb].resize(nplane+1, True);
     136       73938 :                     trc = myshape - 1;
     137             :                     // because we've extended the cube by one plane since
     138             :                     // myshape was determined.
     139       73938 :                     ++trc[2];
     140       73938 :                     blc[2] = trc[2];
     141       73938 :                     data[blcb](blc, trc) = dataSlice;
     142       73938 :                     flags[blcb](blc, trc) = flagSlice;
     143       73938 :                     exposures[blcb][trc[2]] = exposureVector[row];
     144             :                 }
     145             :             }
     146             :         }
     147             :     }
     148         111 :     _computeVariances(data, flags, exposures);
     149             : }
     150             : 
     151         328 : void StatWtClassicalDataAggregator::weightSingleChanBin(
     152             :     Matrix<Float>& wtmat, Int nrows
     153             : ) const {
     154         656 :     Vector<Int> ant1, ant2, spws;
     155         656 :     Vector<Double> exposures;
     156         328 :     _vii->antenna1(ant1);
     157         328 :     _vii->antenna2(ant2);
     158         328 :     _vii->spectralWindows(spws);
     159         328 :     _vii->exposure(exposures);
     160             :     // There is only one spw in a chunk
     161         328 :     auto spw = *spws.begin();
     162         328 :     StatWtTypes::BaselineChanBin blcb;
     163         328 :     blcb.spw = spw;
     164        2284 :     for (Int i=0; i<nrows; ++i) {
     165        3912 :         auto bins = _chanBins.find(spw)->second;
     166        1956 :         blcb.baseline = _baseline(ant1[i], ant2[i]);
     167        1956 :         blcb.chanBin = bins[0];
     168        3912 :         auto variances = _variances.find(blcb)->second;
     169        1956 :         if (_combineCorr) {
     170           0 :             wtmat.column(i) = exposures[i]/variances[0];
     171             :         }
     172             :         else {
     173        1956 :             auto corr = 0;
     174        9780 :             for (const auto variance: variances) {
     175        7824 :                 wtmat(corr, i) = exposures[i]/variance;
     176        7824 :                 ++corr;
     177             :             }
     178             :         }
     179             :     }
     180         328 : }
     181             : 
     182         111 : void StatWtClassicalDataAggregator::_computeVariances(
     183             :     const map<StatWtTypes::BaselineChanBin, Cube<Complex>>& data,
     184             :     const map<StatWtTypes::BaselineChanBin, Cube<Bool>>& flags,
     185             :     const map<StatWtTypes::BaselineChanBin, Vector<Double>>& exposures
     186             : ) const {
     187         111 :     auto diter = data.cbegin();
     188         111 :     auto dend = data.cend();
     189         111 :     const auto nActCorr = diter->second.shape()[0];
     190         111 :     const auto ncorr = _combineCorr ? 1 : nActCorr;
     191             :     // spw will be the same for all members
     192         111 :     const auto& spw = data.begin()->first.spw;
     193         111 :     vector<StatWtTypes::BaselineChanBin> keys(data.size());
     194         111 :     auto idx = 0;
     195        4869 :     for (; diter!=dend; ++diter, ++idx) {
     196        4758 :         const auto& blcb = diter->first;
     197        4758 :         keys[idx] = blcb;
     198        4758 :         _variances[blcb].resize(ncorr);
     199             :     }
     200         111 :     auto n = keys.size();
     201             : #ifdef _OPENMP
     202         111 : #pragma omp parallel for
     203             :     // cout << "WARN OMP PARALLEL LOOPING IS OFF FOR DEBUGGING" << endl;
     204             : #endif
     205             :     for (size_t i=0; i<n; ++i) {
     206             :         auto blcb = keys[i];
     207             :         auto dataForBLCB = data.find(blcb)->second;
     208             :         auto flagsForBLCB = flags.find(blcb)->second;
     209             :         auto exposuresForBLCB = exposures.find(blcb)->second;
     210             :         for (ssize_t corr=0; corr<ncorr; ++corr) {
     211             :             IPosition start(3, 0);
     212             :             auto end = dataForBLCB.shape() - 1;
     213             :             if (! _combineCorr) {
     214             :                 start[0] = corr;
     215             :                 end[0] = corr;
     216             :             }
     217             :             Slicer slice(start, end, Slicer::endIsLast);
     218             :             _variances[blcb][corr]
     219             :                 = _varianceComputer->computeVariance(
     220             :                     dataForBLCB(slice), flagsForBLCB(slice),
     221             :                     exposuresForBLCB, spw
     222             :                 );
     223             :         }
     224             :     }
     225         111 : }
     226             : 
     227        1360 : void StatWtClassicalDataAggregator::weightSpectrumFlags(
     228             :     Cube<Float>& wtsp, Cube<Bool>& flagCube, Bool& checkFlags,
     229             :     const Vector<Int>& ant1, const Vector<Int>& ant2, const Vector<Int>& spws,
     230             :     const Vector<Double>& exposures, const Vector<rownr_t>&
     231             : ) const {
     232        2720 :     Slicer slice(IPosition(3, 0), flagCube.shape(), Slicer::endIsLength);
     233        2720 :     auto sliceStart = slice.start();
     234        2720 :     auto sliceEnd = slice.end();
     235        1360 :     auto nrows = ant1.size();
     236       35956 :     for (size_t i=0; i<nrows; ++i) {
     237       34596 :         sliceStart[2] = i;
     238       34596 :         sliceEnd[2] = i;
     239       34596 :         StatWtTypes::BaselineChanBin blcb;
     240       34596 :         blcb.baseline = _baseline(ant1[i], ant2[i]);
     241       34596 :         auto spw = spws[i];
     242       34596 :         blcb.spw = spw;
     243       69192 :         auto bins = _chanBins.find(spw)->second;
     244      113292 :         for (const auto& bin: bins) {
     245       78696 :             sliceStart[1] = bin.start;
     246       78696 :             sliceEnd[1] = bin.end;
     247       78696 :             blcb.chanBin = bin;
     248      157392 :             auto variances = _variances.find(blcb)->second;
     249       78696 :             auto ncorr = variances.size();
     250      157392 :             Vector<Double> weights(ncorr);
     251      304380 :             for (size_t corr=0; corr<ncorr; ++corr) {
     252      225684 :                 if (! _combineCorr) {
     253      195984 :                     sliceStart[0] = corr;
     254      195984 :                     sliceEnd[0] = corr;
     255             :                 }
     256      451368 :                 weights[corr] = variances[corr] == 0
     257      225684 :                     ? 0 : exposures[i]/variances[corr];
     258      225684 :                 slice.setStart(sliceStart);
     259      225684 :                 slice.setEnd(sliceEnd);
     260      451368 :                 _updateWtSpFlags(
     261      225684 :                     wtsp, flagCube, checkFlags, slice, weights[corr]
     262             :                 );
     263             :             }
     264             :         }
     265             :     }
     266        1360 : }
     267             : 
     268             : }
     269             : 
     270             : }

Generated by: LCOV version 1.16