LCOV - code coverage report
Current view: top level - mstransform/TVI - PolAverageTVI.cc (source / functions) Hit Total Coverage
Test: ctest_coverage.info Lines: 305 437 69.8 %
Date: 2023-11-06 10:06:49 Functions: 51 69 73.9 %

          Line data    Source code
       1             : //# PolAverageTVI.h: This file contains the implementation of the PolAverageTVI class.
       2             : //#
       3             : //#  CASA - Common Astronomy Software Applications (http://casa.nrao.edu/)
       4             : //#  Copyright (C) Associated Universities, Inc. Washington DC, USA 2011, All rights reserved.
       5             : //#  Copyright (C) European Southern Observatory, 2011, All rights reserved.
       6             : //#
       7             : //#  This library is free software; you can redistribute it and/or
       8             : //#  modify it under the terms of the GNU Lesser General Public
       9             : //#  License as published by the Free software Foundation; either
      10             : //#  version 2.1 of the License, or (at your option) any later version.
      11             : //#
      12             : //#  This library is distributed in the hope that it will be useful,
      13             : //#  but WITHOUT ANY WARRANTY, without even the implied warranty of
      14             : //#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      15             : //#  Lesser General Public License for more details.
      16             : //#
      17             : //#  You should have received a copy of the GNU Lesser General Public
      18             : //#  License along with this library; if not, write to the Free Software
      19             : //#  Foundation, Inc., 59 Temple Place, Suite 330, Boston,
      20             : //#  MA 02111-1307  USA
      21             : //# $Id: $
      22             : #include <mstransform/TVI/PolAverageTVI.h>
      23             : 
      24             : #include <casacore/casa/Arrays/Cube.h>
      25             : #include <casacore/casa/Arrays/Matrix.h>
      26             : #include <casacore/casa/Arrays/Vector.h>
      27             : #include <casacore/casa/BasicSL/String.h>
      28             : #include <casacore/casa/Logging/LogIO.h>
      29             : #include <casacore/casa/Containers/Record.h>
      30             : #include <casacore/casa/Exceptions/Error.h>
      31             : #include <casacore/casa/Arrays/ArrayIter.h>
      32             : #include <casacore/measures/Measures/Stokes.h>
      33             : #include <casacore/ms/MeasurementSets/MSDataDescColumns.h>
      34             : #include <casacore/ms/MeasurementSets/MSPolColumns.h>
      35             : 
      36             : #include <msvis/MSVis/VisBufferComponents2.h>
      37             : #include <msvis/MSVis/VisibilityIteratorImpl2.h>
      38             : 
      39             : #include <mstransform/TVI/UtilsTVI.h>
      40             : 
      41             : using namespace casacore;
      42             : 
      43             : namespace {
      44             : template<class T>
      45     5361088 : inline T replaceFlaggedDataWithZero(T v, Bool b) {
      46     5361088 :   return ((b) ? T(0) : v);
      47             : }
      48             : 
      49             : struct StokesTransformation {
      50             : public:
      51             :   template<class T>
      52        7314 :   inline static void transformData(Cube<T> const &dataIn,
      53             :       Cube<Bool> const &flagIn, Int pid0, Int pid1, Cube<T> &dataOut) {
      54        7314 :     if (dataIn.empty()) {
      55           0 :       dataOut.resize();
      56           0 :       return;
      57             :     }
      58       14628 :     auto const cubeShape = dataIn.shape();
      59       14628 :     IPosition const newShape(3, 1, cubeShape[1], cubeShape[2]);
      60             : //    size_t const npol = cubeShape[0];
      61        7314 :     size_t const nchan = cubeShape[1];
      62        7314 :     size_t const nrow = cubeShape[2];
      63        7314 :     size_t const nelem = cubeShape[1] * cubeShape[2];
      64       14628 :     Cube<T> transformedData(newShape, T(0));
      65       14628 :     Cube<Float> weightSum(newShape, 0.0f);
      66             : 
      67        7314 :     Int pols[] = { pid0, pid1 };
      68       21942 :     for (size_t i = 0; i < 2; ++i) {
      69       14628 :       Int ipol = pols[i];
      70       29256 :       IPosition start(3, ipol, 0, 0);
      71       29256 :       IPosition end(3, ipol, nchan - 1, nrow - 1);
      72       29256 :       auto dslice = dataIn(start, end);
      73       29256 :       auto fslice = flagIn(start, end);
      74       29256 :       Array<Float> weight(dslice.shape());
      75       29256 :       Array<T> weightedData(dslice.shape());
      76       14628 :       arrayContTransform(dslice, fslice, weightedData,
      77             :           replaceFlaggedDataWithZero<T>);
      78     5326564 :       arrayContTransform(fslice, weight, [](Bool b) {
      79     5311936 :         return ((b) ? 0.0f : 1.0f);
      80             :       });
      81       14628 :       transformedData += weightedData;
      82       14628 :       weightSum += weight;
      83             :     }
      84             : 
      85             :     // transformedData, transformedFlag and nAccumulated should be contiguous array
      86        7314 :     auto p_tdata = transformedData.data();
      87        7314 :     auto p_wsum = weightSum.data();
      88             : 
      89     2663282 :     for (size_t i = 0; i < nelem; ++i) {
      90     2655968 :       if (p_wsum[i] > 0.0) {
      91     2655968 :         p_tdata[i] /= T(p_wsum[i]);
      92             :       }
      93             :     }
      94             : 
      95        7314 :     dataOut.reference(transformedData);
      96             :   }
      97             : 
      98     2630888 :   static inline void AccumulateWeight(Float const wt, Double &wtsum) {
      99     2630888 :     wtsum += 1.0 / wt;
     100     2630888 :   }
     101             : 
     102     1315444 :   static inline void NormalizeWeight(Double const wtsum, Float &wt) {
     103     1315444 :     wt = 4.0 / wtsum;
     104     1315444 :   }
     105             : };
     106             : 
     107             : struct GeometricTransformation {
     108             :   template<class T>
     109           0 :   inline static void transformData(Cube<T> const &dataIn,
     110             :       Cube<Bool> const &flagIn, Cube<Float> const &weightIn, Int pid0, Int pid1,
     111             :       Cube<T> &dataOut) {
     112           0 :     if (dataIn.empty()) {
     113           0 :       dataOut.resize();
     114           0 :       return;
     115             :     }
     116           0 :     auto const cubeShape = dataIn.shape();
     117           0 :     IPosition const newShape(3, 1, cubeShape[1], cubeShape[2]);
     118           0 :     Cube<T> transformedData(newShape, T(0));
     119           0 :     Cube<Bool> transformedFlag(newShape, True);
     120           0 :     Cube<Float> weightSum(newShape, 0.0f);
     121             : //    size_t const npol = cubeShape[0];
     122           0 :     size_t const nchan = cubeShape[1];
     123           0 :     size_t const nrow = cubeShape[2];
     124           0 :     size_t const nelem = cubeShape[1] * cubeShape[2];
     125             : 
     126           0 :     Int pols[] = { pid0, pid1 };
     127           0 :     for (size_t i = 0; i < 2; ++i) {
     128           0 :       Int ipol = pols[i];
     129           0 :       IPosition start(3, ipol, 0, 0);
     130           0 :       IPosition end(3, ipol, nchan - 1, nrow - 1);
     131           0 :       auto const dslice = dataIn(start, end);
     132           0 :       auto const wslice = weightIn(start, end);
     133           0 :       auto const fslice = flagIn(start, end);
     134           0 :       Array<Float> weight(dslice.shape());
     135           0 :       Array<T> weightedData(dslice.shape());
     136           0 :       arrayContTransform(dslice * wslice, fslice, weightedData,
     137             :           replaceFlaggedDataWithZero<T>);
     138           0 :       arrayContTransform(wslice, fslice, weight,
     139             :           replaceFlaggedDataWithZero<Float>);
     140           0 :       transformedData += weightedData;
     141           0 :       weightSum += weight;
     142             :     }
     143             : 
     144             :     // transformedData, transformedFlag and nAccumulated should be contiguous array
     145           0 :     T *p_tdata = transformedData.data();
     146           0 :     Float *p_wsum = weightSum.data();
     147           0 :     for (size_t i = 0; i < nelem; ++i) {
     148           0 :       if (p_wsum[i] > 0.0) {
     149           0 :         p_tdata[i] /= T(p_wsum[i]);
     150             :       }
     151             :     }
     152             : 
     153           0 :     dataOut.reference(transformedData);
     154             :   }
     155             : 
     156             :   template<class T>
     157           2 :   inline static void transformData(Cube<T> const &dataIn,
     158             :       Cube<Bool> const &flagIn, Matrix<Float> const &weightIn, Int pid0,
     159             :       Int pid1, Cube<T> &dataOut) {
     160             : //    cout << "start " << __func__ << endl;
     161           2 :     if (dataIn.empty()) {
     162           0 :       dataOut.resize();
     163           0 :       return;
     164             :     }
     165           4 :     auto const cubeShape = dataIn.shape();
     166           4 :     IPosition const newShape(3, 1, cubeShape[1], cubeShape[2]);
     167           4 :     Cube<T> transformedData(newShape, T(0));
     168           4 :     Cube<Float> weightSum(newShape, 0.0f);
     169             : //    auto const npol = dataIn.shape()[0];
     170           2 :     auto const nchan = dataIn.shape()[1];
     171           2 :     auto const nrow = dataIn.shape()[2];
     172           2 :     auto const nelem = nchan * nrow;
     173             : 
     174           2 :     Int pols[] = { pid0, pid1 };
     175           6 :     for (ssize_t i = 0; i < 2; ++i) {
     176           4 :       Int ipol = pols[i];
     177          10 :       for (ssize_t j = 0; j < nrow; ++j) {
     178          12 :         IPosition start(3, ipol, 0, j);
     179          12 :         IPosition end(3, ipol, nchan - 1, j);
     180          12 :         auto dslice = dataIn(start, end);
     181          12 :         auto fslice = flagIn(start, end);
     182           6 :         auto w = weightIn(ipol, j);
     183          12 :         Array<Float> weight(dslice.shape());
     184          12 :         Array<T> weightedData(dslice.shape());
     185           6 :         arrayContTransform(dslice * w, fslice, weightedData,
     186             :             replaceFlaggedDataWithZero<T>);
     187       98310 :         arrayContTransform(fslice, weight, [&w](Bool b) {
     188       49152 :           return ((b) ? 0.0f: w);
     189             :         });
     190          12 :         IPosition tstart(3, 0, 0, j);
     191          12 :         IPosition tend(3, 0, nchan - 1, j);
     192          12 :         Array<T> tdSlice = transformedData(tstart, tend);
     193          12 :         Array<Float> twSlice = weightSum(tstart, tend);
     194           6 :         AlwaysAssert(tdSlice.conform(weightedData), AipsError);
     195           6 :         AlwaysAssert(twSlice.conform(weight), AipsError);
     196           6 :         tdSlice += weightedData;
     197           6 :         twSlice += weight;
     198             :       }
     199             :     }
     200             : 
     201             :     // transformedData, transformedFlag and nAccumulated should be contiguous array
     202           2 :     T *p_tdata = transformedData.data();
     203           2 :     Float *p_wsum = weightSum.data();
     204             : 
     205       24578 :     for (ssize_t i = 0; i < nelem; ++i) {
     206       24576 :       if (p_wsum[i] > 0.0) {
     207       24576 :         p_tdata[i] /= T(p_wsum[i]);
     208             :       }
     209             :     }
     210             : 
     211           2 :     dataOut.reference(transformedData);
     212             : //    cout << "end " << __func__ << endl;
     213             :   }
     214             : 
     215           0 :   static inline void AccumulateWeight(Float const wt, Double &wtsum) {
     216           0 :     wtsum += wt;
     217           0 :   }
     218             : 
     219           0 :   static inline void NormalizeWeight(Double const wtsum, Float &wt) {
     220           0 :     wt = wtsum;
     221           0 :   }
     222             : };
     223             : 
     224           2 : inline Float weight2Sigma(Float x) {
     225           2 :   return 1.0 / sqrt(x);
     226             : }
     227             : 
     228             : template<class WeightHandler>
     229        3658 : inline void transformWeight(Array<Float> const &weightIn, Int pid0, Int pid1,
     230             :     Array<Float> &weightOut) {
     231             : //  cout << "start " << __func__ << endl;
     232        3658 :   if (weightIn.empty()) {
     233             : //    cout << "input weight is empty" << endl;
     234           0 :     weightOut.resize();
     235           0 :     return;
     236             :   }
     237        7316 :   IPosition const shapeIn = weightIn.shape();
     238        7316 :   IPosition shapeOut(shapeIn);
     239             :   // set length of polarization axis to 1
     240        3658 :   shapeOut[0] = 1;
     241             : //  cout << "shapeIn = " << shapeIn << " shapeOut = " << shapeOut << endl;
     242             : 
     243             :   // initialization
     244        3658 :   weightOut.resize(shapeOut);
     245        3658 :   weightOut = 0.0f;
     246             : 
     247        3658 :   ssize_t numPol = shapeIn[0];
     248        3658 :   Int64 numElemPerPol = shapeOut.product();
     249             : //  cout << "numElemPerPol = " << numElemPerPol << endl;
     250             : //  cout << "numPol = " << numPol << endl;
     251             : 
     252             :   Bool b;
     253        3658 :   Float const *p_wIn = weightIn.getStorage(b);
     254        3658 :   Float *p_wOut = weightOut.data();
     255             : 
     256        3658 :   Int pols[] = { pid0, pid1 };
     257     1319102 :   for (Int64 i = 0; i < numElemPerPol; ++i) {
     258     1315444 :     ssize_t offsetIndex = i * numPol;
     259     1315444 :     Double sum = 0.0;
     260     3946332 :     for (ssize_t j = 0; j < 2; ++j) {
     261     2630888 :       Int ipol = pols[j];
     262     2630888 :       WeightHandler::AccumulateWeight(p_wIn[offsetIndex + ipol], sum);
     263             :     }
     264     1315444 :     WeightHandler::NormalizeWeight(sum, p_wOut[i]);
     265             :   }
     266             : 
     267        3658 :   weightIn.freeStorage(p_wIn, b);
     268             : }
     269             : } // anonymous namespace
     270             : 
     271             : namespace casa { //# NAMESPACE CASA - BEGIN
     272             : 
     273             : namespace vi { //# NAMESPACE VI - BEGIN
     274             : //////////
     275             : // Base Class
     276             : // PolAverageTVI
     277             : /////////
     278           7 : PolAverageTVI::PolAverageTVI(ViImplementation2 *inputVII) :
     279           7 :     TransformingVi2(inputVII) {
     280           7 :   configurePolAverage();
     281             : 
     282             :   // Initialize attached VisBuffer
     283           7 :   setVisBuffer(createAttachedVisBuffer(VbRekeyable));
     284           7 : }
     285             : 
     286           7 : PolAverageTVI::~PolAverageTVI() {
     287           7 : }
     288             : 
     289          50 : void PolAverageTVI::origin() {
     290          50 :   TransformingVi2::origin();
     291             : 
     292             :   // Configure the correlations per shape
     293          50 :   configureShapes();
     294             : 
     295             :   // Synchronize own VisBuffer
     296          50 :   configureNewSubchunk();
     297             : 
     298             :   // reconfigure if necessary
     299          50 :   reconfigurePolAverageIfNecessary();
     300             : 
     301             :   // warn if current dd is inappropriate for polarization averaging
     302          50 :   warnIfNoTransform();
     303          50 : }
     304             : 
     305        3662 : void PolAverageTVI::next() {
     306        3662 :   TransformingVi2::next();
     307             : 
     308             :   // Configure the correlations per shape
     309        3662 :   configureShapes();
     310             : 
     311             :   // Synchronize own VisBuffer
     312        3662 :   configureNewSubchunk();
     313             : 
     314             :   // reconfigure if necessary
     315        3662 :   reconfigurePolAverageIfNecessary();
     316             : 
     317             :   // warn if current dd is inappropriate for polarization averaging
     318        3662 :   warnIfNoTransform();
     319        3662 : }
     320             : 
     321        3712 : void PolAverageTVI::configureShapes() {
     322        3712 :     Vector <Int> corrs = getCorrelations ();
     323        3712 :     Int nCorrs = corrs.nelements();
     324             : 
     325        3712 :     nCorrelationsPerShape_ = casacore::Vector<casacore::Int> (1, nCorrs);
     326        3712 : }
     327             : 
     328        3712 : void PolAverageTVI::warnIfNoTransform() {
     329        3712 :   if (!doTransform_[dataDescriptionId()]) {
     330           0 :     auto const vb = getVii()->getVisBuffer();
     331           0 :     LogIO os(LogOrigin("PolAverageTVI", __func__, WHERE));
     332           0 :     String msg("Skip polarization average because");
     333           0 :     if (vb->nCorrelations() == 1) {
     334           0 :       msg += " number of polarizations is 1.";
     335           0 :     } else if (anyEQ(vb->correlationTypes(), (Int) Stokes::I)) {
     336           0 :       msg += " polarization type is Stokes.";
     337             :     } else {
     338           0 :       msg += " no valid polarization components are found.";
     339             :     }
     340           0 :     os << LogIO::WARN << msg << LogIO::POST;
     341             :   }
     342        3712 : }
     343             : 
     344        3654 : void PolAverageTVI::corrType(Vector<Int> & corrTypes) const {
     345        3654 :   if (doTransform_[dataDescriptionId()]) {
     346             :     // Always return (Stokes::I)
     347        7308 :     Vector<Int> myCorrTypes(1, (Int) Stokes::I);
     348        3654 :     corrTypes.reference(myCorrTypes);
     349             :   } else {
     350           0 :     getVii()->corrType(corrTypes);
     351             :   }
     352        3654 : }
     353             : 
     354        3658 : void PolAverageTVI::flagRow(Vector<Bool> & rowflags) const {
     355        3658 :   Cube<Bool> const &flags = getVisBuffer()->flagCube();
     356        3658 :   accumulateFlagCube(flags, rowflags);
     357        3658 : }
     358             : 
     359        3662 : void PolAverageTVI::flag(Cube<Bool> & flags) const {
     360        3662 :   auto const vb = getVii()->getVisBuffer();
     361        7324 :   Cube<Bool> originalFlags = vb->flagCube();
     362        3662 :   Int ddid = dataDescriptionId();
     363             : 
     364        3662 :   if (doTransform_[ddid]) {
     365        7324 :     auto const cubeShape = originalFlags.shape();
     366        7324 :     IPosition const newShape(3, 1, cubeShape[1], cubeShape[2]);
     367        7324 :     Cube<Bool> transformedFlags(newShape, True);
     368             :     // accumulate first polarization component
     369        7324 :     IPosition start(3, polId0_[ddid], 0, 0);
     370        7324 :     IPosition end(3, polId0_[ddid], cubeShape[1] - 1, cubeShape[2] - 1);
     371        3662 :     transformedFlags = originalFlags(start, end);
     372             : 
     373             :     // accumulate second polarization component
     374        3662 :     start[0] = polId1_[ddid];
     375        3662 :     end[0] = polId1_[ddid];
     376        3662 :     transformedFlags &= originalFlags(start, end);
     377        3662 :     flags.reference(transformedFlags);
     378             :   } else {
     379           0 :     flags.reference(originalFlags);
     380             :   }
     381        3662 : }
     382             : 
     383           0 : void PolAverageTVI::flag(Matrix<Bool> & flags) const {
     384           0 :   Cube<Bool> transformedFlags;
     385           0 :   flag(transformedFlags);
     386             : 
     387           0 :   flags.reference(transformedFlags.yzPlane(0));
     388           0 : }
     389             : 
     390           0 : void PolAverageTVI::jonesC(Vector<SquareMatrix<Complex, 2> > &cjones) const {
     391           0 :   if (doTransform_[dataDescriptionId()]) {
     392           0 :     throw AipsError("PolAverageTVI::jonesC should not be called.");
     393             :   } else {
     394           0 :     getVii()->jonesC(cjones);
     395             :   }
     396           0 : }
     397             : 
     398           2 : void PolAverageTVI::sigma(Matrix<Float> & sigmat) const {
     399           2 :   if (weightSpectrumExists()) {
     400           0 :     Cube<Float> const &sigmaSp = getVisBuffer()->sigmaSpectrum();
     401           0 :     Cube<Bool> const &flag = getVisBuffer()->flagCube();
     402           0 :     accumulateWeightCube(sigmaSp, flag, sigmat);
     403             :   } else {
     404           2 :     if (doTransform_[dataDescriptionId()]) {
     405           2 :       weight(sigmat);
     406           2 :       arrayTransformInPlace(sigmat, ::weight2Sigma);
     407             :     } else {
     408           0 :       getVii()->sigma(sigmat);
     409             :     }
     410             :   }
     411           2 : }
     412             : 
     413        3656 : void PolAverageTVI::visibilityCorrected(Cube<Complex> & vis) const {
     414        3656 :   if (getVii()->existsColumn(VisBufferComponent2::VisibilityCorrected)) {
     415        7312 :     Cube<Complex> dataCube;
     416        3656 :     getVii()->visibilityCorrected(dataCube);
     417        3656 :     if (doTransform_[dataDescriptionId()]) {
     418        7312 :       Cube<Bool> flagCube;
     419        3656 :       getVii()->flag(flagCube);
     420        3656 :       transformComplexData(dataCube, flagCube, vis);
     421             :     } else {
     422           0 :       vis.reference(dataCube);
     423             :     }
     424             :   } else {
     425           0 :     vis.resize();
     426             :   }
     427        3656 : }
     428             : 
     429        3654 : void PolAverageTVI::visibilityModel(Cube<Complex> & vis) const {
     430        3654 :   if (getVii()->existsColumn(VisBufferComponent2::VisibilityModel)) {
     431        7308 :     Cube<Complex> dataCube;
     432        3654 :     getVii()->visibilityModel(dataCube);
     433        3654 :     if (doTransform_[dataDescriptionId()]) {
     434        7308 :       Cube<Bool> flagCube;
     435        3654 :       getVii()->flag(flagCube);
     436        3654 :       transformComplexData(dataCube, flagCube, vis);
     437             :     } else {
     438           0 :       vis.reference(dataCube);
     439             :     }
     440             :   } else {
     441           0 :     vis.resize();
     442             :   }
     443        3654 : }
     444             : 
     445           0 : void PolAverageTVI::visibilityObserved(Cube<Complex> & vis) const {
     446           0 :   if (getVii()->existsColumn(VisBufferComponent2::VisibilityObserved)) {
     447           0 :     Cube<Complex> dataCube;
     448           0 :     getVii()->visibilityObserved(dataCube);
     449           0 :     if (doTransform_[dataDescriptionId()]) {
     450           0 :       Cube<Bool> flagCube;
     451           0 :       getVii()->flag(flagCube);
     452           0 :       transformComplexData(dataCube, flagCube, vis);
     453             :     } else {
     454           0 :       vis.reference(dataCube);
     455             :     }
     456             :   } else {
     457           0 :     vis.resize();
     458             :   }
     459           0 : }
     460             : 
     461           6 : void PolAverageTVI::floatData(casacore::Cube<casacore::Float> & fcube) const {
     462           6 :   if (getVii()->existsColumn(VisBufferComponent2::FloatData)) {
     463          12 :     Cube<Float> dataCube;
     464           6 :     getVii()->floatData(dataCube);
     465           6 :     if (doTransform_[dataDescriptionId()]) {
     466          12 :       Cube<Bool> flagCube;
     467           6 :       getVii()->flag(flagCube);
     468           6 :       transformFloatData(dataCube, flagCube, fcube);
     469             :     } else {
     470           0 :       fcube.reference(dataCube);
     471             :     }
     472             :   } else {
     473           0 :     fcube.resize();
     474             :   }
     475           6 : }
     476             : 
     477           0 : IPosition PolAverageTVI::visibilityShape() const {
     478           0 :   IPosition cubeShape = getVii()->visibilityShape();
     479           0 :   if (doTransform_[dataDescriptionId()]) {
     480             :     // Length of polarization (Stokes) axis is always 1 after polarizaton averaging
     481           0 :     cubeShape[0] = 1;
     482             :   }
     483           0 :   return cubeShape;
     484             : }
     485             : 
     486           4 : void PolAverageTVI::weight(Matrix<Float> & wtmat) const {
     487           4 :   if (weightSpectrumExists()) {
     488           0 :     Cube<Float> const &weightSp = getVisBuffer()->weightSpectrum();
     489           0 :     Cube<Bool> const &flag = getVisBuffer()->flagCube();
     490           0 :     accumulateWeightCube(weightSp, flag, wtmat);
     491             :   } else {
     492           8 :     Matrix<Float> wtmatOrg;
     493           4 :     getVii()->weight(wtmatOrg);
     494           4 :     if (doTransform_[dataDescriptionId()]) {
     495           4 :       transformWeight(wtmatOrg, wtmat);
     496             :     } else {
     497           0 :       wtmat.reference(wtmatOrg);
     498             :     }
     499             :   }
     500           4 : }
     501             : 
     502        3654 : void PolAverageTVI::weightSpectrum(Cube<Float> & wtsp) const {
     503        3654 :   if (weightSpectrumExists()) {
     504        7308 :     Cube<Float> wtspOrg;
     505        3654 :     getVii()->weightSpectrum(wtspOrg);
     506        3654 :     if (doTransform_[dataDescriptionId()]) {
     507        3654 :       transformWeight(wtspOrg, wtsp);
     508             :     } else {
     509           0 :       wtsp.reference(wtspOrg);
     510             :     }
     511             :   } else {
     512           0 :     wtsp.resize();
     513             :   }
     514        3654 : }
     515             : 
     516           0 : void PolAverageTVI::sigmaSpectrum(Cube<Float> & wtsp) const {
     517           0 :   if (sigmaSpectrumExists()) {
     518           0 :     if (doTransform_[dataDescriptionId()]) {
     519             :       // sigma = (weight)^-1/2
     520           0 :       weightSpectrum(wtsp);
     521           0 :       arrayTransformInPlace(wtsp, ::weight2Sigma);
     522             :     } else {
     523           0 :       getVii()->sigmaSpectrum(wtsp);
     524             :     }
     525             :   } else {
     526           0 :     wtsp.resize();
     527             :   }
     528           0 : }
     529             : 
     530           0 : const VisImagingWeight & PolAverageTVI::getImagingWeightGenerator() const {
     531           0 :   if (doTransform_[dataDescriptionId()]) {
     532             :     throw AipsError(
     533           0 :         "PolAverageTVI::getImagingWeightGenerator should not be called.");
     534             :   }
     535             : 
     536           0 :   return getVii()->getImagingWeightGenerator();
     537             : }
     538             : 
     539        7424 : Vector<Int> PolAverageTVI::getCorrelations() const {
     540        7424 :   if (doTransform_[dataDescriptionId()]) {
     541             :     // Always return (Stokes::I)
     542        7424 :     return Vector<Int>(1, Stokes::I);
     543             :   } else {
     544           0 :     return getVii()->getCorrelations();
     545             :   }
     546             : }
     547             : 
     548        3712 : Vector<Stokes::StokesTypes> PolAverageTVI::getCorrelationTypesDefined() const {
     549        3712 :   if (doTransform_[dataDescriptionId()]) {
     550             :     // Always return (Stokes::I)
     551        3712 :     return Vector<Stokes::StokesTypes>(1, Stokes::I);
     552             :   } else {
     553           0 :     return getVii()->getCorrelationTypesDefined();
     554             :   }
     555             : }
     556             : 
     557        3712 : Vector<Stokes::StokesTypes> PolAverageTVI::getCorrelationTypesSelected() const {
     558        3712 :   if (doTransform_[dataDescriptionId()]) {
     559             :     // Always return (Stokes::I)
     560        3712 :     return Vector<Stokes::StokesTypes>(1, Stokes::I);
     561             :   } else {
     562           0 :     return getVii()->getCorrelationTypesSelected();
     563             :   }
     564             : }
     565             : 
     566             : const casacore::Vector<casacore::Int>&
     567        3712 : PolAverageTVI::nCorrelationsPerShape() const {
     568        3712 :   return nCorrelationsPerShape_;
     569             : }
     570             : 
     571         107 : void PolAverageTVI::configurePolAverage() {
     572         107 :   MeasurementSet const &ms = getVii()->ms();
     573         107 :   auto const &msdd = ms.dataDescription();
     574         214 :   MSDataDescColumns msddcols(msdd);
     575         107 :   uInt ndd = msddcols.nrow();
     576         214 :   Vector<Int> polIds = msddcols.polarizationId().getColumn();
     577         107 :   doTransform_.resize(ndd);
     578         107 :   polId0_.resize(ndd);
     579         107 :   polId1_.resize(ndd);
     580         107 :   auto const &mspol = ms.polarization();
     581         214 :   MSPolarizationColumns mspolcols(mspol);
     582         107 :   doTransform_ = False;
     583         613 :   for (uInt idd = 0; idd < ndd; ++idd) {
     584         506 :     Vector<Int> corrType = mspolcols.corrType()(polIds[idd]);
     585         506 :     polId0_[idd] = -1;
     586         506 :     polId1_[idd] = -1;
     587        2224 :     for (size_t i = 0; i < corrType.size(); ++i) {
     588        1718 :       auto stokesType = Stokes::type(corrType[i]);
     589        1718 :       if (stokesType == Stokes::XX || stokesType == Stokes::RR) {
     590         506 :         polId0_[idd] = i;
     591        1212 :       } else if (stokesType == Stokes::YY || stokesType == Stokes::LL) {
     592         500 :         polId1_[idd] = i;
     593             :       }
     594             :     }
     595         506 :     doTransform_[idd] = (polId0_[idd] >= 0 && polId1_[idd] >= 0);
     596             :   }
     597         107 : }
     598             : 
     599             : //////////
     600             : // GeometricPolAverageTVI
     601             : /////////
     602           2 : GeometricPolAverageTVI::GeometricPolAverageTVI(ViImplementation2 *inputVII) :
     603           2 :     PolAverageTVI(inputVII) {
     604           2 : }
     605             : 
     606           4 : GeometricPolAverageTVI::~GeometricPolAverageTVI() {
     607           4 : }
     608             : 
     609           0 : void GeometricPolAverageTVI::transformComplexData(Cube<Complex> const &dataIn,
     610             :     Cube<Bool> const &flagIn, Cube<Complex> &dataOut) const {
     611           0 :   Int ddid = dataDescriptionId();
     612           0 :   Int pid0 = polId0_[ddid];
     613           0 :   Int pid1 = polId1_[ddid];
     614           0 :   transformData(dataIn, flagIn, pid0, pid1, dataOut);
     615           0 : }
     616             : 
     617           2 : void GeometricPolAverageTVI::transformFloatData(Cube<Float> const &dataIn,
     618             :     Cube<Bool> const &flagIn, Cube<Float> &dataOut) const {
     619           2 :   Int ddid = dataDescriptionId();
     620           2 :   Int pid0 = polId0_[ddid];
     621           2 :   Int pid1 = polId1_[ddid];
     622           2 :   transformData(dataIn, flagIn, pid0, pid1, dataOut);
     623           2 : }
     624             : 
     625           0 : void GeometricPolAverageTVI::transformWeight(Array<Float> const &weightIn,
     626             :     Array<Float> &weightOut) const {
     627           0 :   Int ddid = dataDescriptionId();
     628           0 :   Int pid0 = polId0_[ddid];
     629           0 :   Int pid1 = polId1_[ddid];
     630           0 :   ::transformWeight<GeometricTransformation>(weightIn, pid0, pid1, weightOut);
     631           0 : }
     632             : 
     633             : template<class T>
     634           2 : void GeometricPolAverageTVI::transformData(Cube<T> const &dataIn,
     635             :     Cube<Bool> const &flagIn, Int pid0, Int pid1, Cube<T> &dataOut) const {
     636           2 :   if (weightSpectrumExists()) {
     637           0 :     Cube<Float> weightSp;
     638           0 :     getVii()->weightSpectrum(weightSp);
     639           0 :     ::GeometricTransformation::transformData<T>(dataIn, flagIn, weightSp, pid0,
     640             :         pid1, dataOut);
     641             :   } else {
     642           4 :     Matrix<Float> weightMat;
     643           2 :     getVii()->weight(weightMat);
     644           2 :     ::GeometricTransformation::transformData<T>(dataIn, flagIn, weightMat, pid0,
     645             :         pid1, dataOut);
     646             :   }
     647           2 : }
     648             : 
     649             : //////////
     650             : // StokesPolAverageTVI
     651             : /////////
     652           5 : StokesPolAverageTVI::StokesPolAverageTVI(ViImplementation2 *inputVII) :
     653           5 :     PolAverageTVI(inputVII) {
     654           5 : }
     655             : 
     656          10 : StokesPolAverageTVI::~StokesPolAverageTVI() {
     657          10 : }
     658             : 
     659        7310 : void StokesPolAverageTVI::transformComplexData(Cube<Complex> const &dataIn,
     660             :     Cube<Bool> const &flagIn, Cube<Complex> &dataOut) const {
     661        7310 :   Int ddid = dataDescriptionId();
     662        7310 :   Int pid0 = polId0_[ddid];
     663        7310 :   Int pid1 = polId1_[ddid];
     664        7310 :   transformData(dataIn, flagIn, pid0, pid1, dataOut);
     665        7310 : }
     666             : 
     667           4 : void StokesPolAverageTVI::transformFloatData(Cube<Float> const &dataIn,
     668             :     Cube<Bool> const &flagIn, Cube<Float> &dataOut) const {
     669           4 :   Int ddid = dataDescriptionId();
     670           4 :   Int pid0 = polId0_[ddid];
     671           4 :   Int pid1 = polId1_[ddid];
     672           4 :   transformData(dataIn, flagIn, pid0, pid1, dataOut);
     673           4 : }
     674             : 
     675        3658 : void StokesPolAverageTVI::transformWeight(Array<Float> const &weightIn,
     676             :     Array<Float> &weightOut) const {
     677        3658 :   Int ddid = dataDescriptionId();
     678        3658 :   Int pid0 = polId0_[ddid];
     679        3658 :   Int pid1 = polId1_[ddid];
     680        3658 :   ::transformWeight<StokesTransformation>(weightIn, pid0, pid1, weightOut);
     681        3658 : }
     682             : 
     683             : template<class T>
     684        7314 : void StokesPolAverageTVI::transformData(Cube<T> const &dataIn,
     685             :     Cube<Bool> const &flagIn, Int pid0, Int pid1, Cube<T> &dataOut) const {
     686        7314 :   ::StokesTransformation::transformData<T>(dataIn, flagIn, pid0, pid1, dataOut);
     687        7314 : }
     688             : 
     689             : //////////
     690             : // PolAverageTVIFactory
     691             : /////////
     692           1 : PolAverageVi2Factory::PolAverageVi2Factory(Record const &configuration,
     693           1 :     ViImplementation2 *inputVII) :
     694           1 :     inputVII_p(inputVII), mode_(AveragingMode::DEFAULT) {
     695           1 :   inputVII_p = inputVII;
     696             : 
     697           1 :   mode_ = PolAverageVi2Factory::GetAverageModeFromConfig(configuration);
     698           1 : }
     699             : 
     700           6 : PolAverageVi2Factory::PolAverageVi2Factory(Record const &configuration,
     701             :     MeasurementSet const *ms, SortColumns const sortColumns,
     702           6 :     Double timeInterval, Bool isWritable) :
     703           6 :     inputVII_p(nullptr), mode_(AveragingMode::DEFAULT) {
     704          12 :   inputVII_p = new VisibilityIteratorImpl2(Block<MeasurementSet const *>(1, ms),
     705           6 :       sortColumns, timeInterval, isWritable);
     706             : 
     707           6 :   mode_ = PolAverageVi2Factory::GetAverageModeFromConfig(configuration);
     708           6 : }
     709             : 
     710           7 : PolAverageVi2Factory::~PolAverageVi2Factory() {
     711           7 : }
     712             : 
     713           7 : ViImplementation2 * PolAverageVi2Factory::createVi() const {
     714           7 :   if (mode_ == AveragingMode::GEOMETRIC) {
     715           2 :     return new GeometricPolAverageTVI(inputVII_p);
     716           5 :   } else if (mode_ == AveragingMode::STOKES) {
     717           5 :     return new StokesPolAverageTVI(inputVII_p);
     718             :   }
     719             : 
     720           0 :   throw AipsError("Invalid Averaging Mode for PolAverageTVI.");
     721             : 
     722             :   return nullptr;
     723             : }
     724             : 
     725           1 : PolAverageTVILayerFactory::PolAverageTVILayerFactory(
     726           1 :     Record const &configuration) :
     727           1 :     ViiLayerFactory() {
     728           1 :   configuration_p = configuration;
     729           1 : }
     730             : 
     731             : ViImplementation2*
     732           1 : PolAverageTVILayerFactory::createInstance(ViImplementation2* vii0) const {
     733             :   // Make the PolAverageTVI, using supplied ViImplementation2, and return it
     734           1 :   PolAverageVi2Factory factory(configuration_p, vii0);
     735           1 :   ViImplementation2 *vii = nullptr;
     736             :   try {
     737           1 :     vii = factory.createVi();
     738           0 :   } catch (...) {
     739           0 :     if (vii0) {
     740           0 :       delete vii0;
     741             :     }
     742           0 :     throw;
     743             :   }
     744           2 :   return vii;
     745             : }
     746             : } // # NAMESPACE VI - END
     747             : } // #NAMESPACE CASA - END

Generated by: LCOV version 1.16