LCOV - code coverage report
Current view: top level - synthesis/MeasurementEquations - objfunc_alglib.h (source / functions) Hit Total Coverage
Test: casa_coverage.info Lines: 95 95 100.0 %
Date: 2023-10-25 08:47:59 Functions: 14 14 100.0 %

          Line data    Source code
       1             : #ifndef SYNTHESIS_OBJFUNCALGLIB_H
       2             : #define SYNTHESIS_OBJFUNCALGLIB_H
       3             : 
       4             : #include <casacore/ms/MeasurementSets/MeasurementSet.h>
       5             : #include <casacore/casa/Arrays/Matrix.h>
       6             : #include <casacore/casa/Arrays/IPosition.h>
       7             : #include <casacore/images/Images/ImageInterface.h>
       8             : #include <casacore/images/Images/PagedImage.h>
       9             : #include <casacore/images/Images/TempImage.h>
      10             : 
      11             : #include <casacore/scimath/Mathematics/FFTServer.h>
      12             : #include <casacore/scimath/Functionals/Gaussian2D.h>
      13             : 
      14             : #include "lbfgs/optimization.h"
      15             : 
      16             : #ifndef isnan
      17             : #define isnan(x) std::isnan(x)
      18             : #endif
      19             : 
      20             : namespace casa { //# NAMESPACE CASA - BEGIN
      21             : 
      22             : class ParamAlglibObj
      23             : {
      24             : private:
      25             :   int nX;
      26             :   int nY;
      27             :   unsigned int AspLen;
      28             :   casacore::Matrix<casacore::Float> itsMatDirty;
      29             :   casacore::Matrix<casacore::Complex> itsPsfFT;
      30             :   std::vector<casacore::IPosition> center;
      31             :   casacore::Matrix<casacore::Float> newResidual;
      32             :   casacore::Matrix<casacore::Float> AspConvPsf;
      33             :   casacore::Matrix<casacore::Float> dAspConvPsf;
      34             :   casacore::Matrix<casacore::Float> Asp;
      35             :   casacore::Matrix<casacore::Float> dAsp;
      36             : 
      37             : public:
      38             :   casacore::FFTServer<casacore::Float,casacore::Complex> fft;
      39             : 
      40         498 :   ParamAlglibObj(const casacore::Matrix<casacore::Float>& dirty,
      41             :     const casacore::Matrix<casacore::Complex>& psf,
      42             :     const std::vector<casacore::IPosition>& positionOptimum,
      43         498 :     const casacore::FFTServer<casacore::Float,casacore::Complex>& fftin) :
      44             :     itsMatDirty(dirty),
      45             :     itsPsfFT(psf),
      46             :     center(positionOptimum),
      47         498 :     fft(fftin)
      48             :   {
      49         498 :     nX = itsMatDirty.shape()(0);
      50         498 :     nY = itsMatDirty.shape()(1);
      51         498 :     AspLen = center.size();
      52         498 :     newResidual.resize(nX, nY);
      53         498 :     AspConvPsf.resize(nX, nY);
      54         498 :     dAspConvPsf.resize(nX, nY);
      55         498 :     Asp.resize(nX, nY);
      56         498 :     dAsp.resize(nX, nY);
      57         498 :   }
      58             : 
      59         498 :   ~ParamAlglibObj() = default;
      60             : 
      61        4889 :   casacore::Matrix<casacore::Float>  getterDirty() { return itsMatDirty; }
      62        4889 :   casacore::Matrix<casacore::Complex> getterPsfFT() { return itsPsfFT; }
      63        4889 :   std::vector<casacore::IPosition> getterCenter() {return center;}
      64        4889 :   unsigned int getterAspLen() { return AspLen; }
      65        4889 :   int getterNX() { return nX; }
      66        4889 :   int getterNY() { return nY; }
      67        4889 :   casacore::Matrix<casacore::Float>  getterRes() { return newResidual; }
      68             :   void setterRes(const casacore::Matrix<casacore::Float>& res) { newResidual = res; }
      69        4889 :   casacore::Matrix<casacore::Float>  getterAspConvPsf() { return AspConvPsf; }
      70             :   void setterAspConvPsf(const casacore::Matrix<casacore::Float>& m) { AspConvPsf = m; }
      71        4889 :   casacore::Matrix<casacore::Float>  getterDAspConvPsf() { return dAspConvPsf; }
      72        4889 :   casacore::Matrix<casacore::Float>  getterAsp() { return Asp; }
      73             :   void setterAsp(const casacore::Matrix<casacore::Float>& m) { Asp = m; }
      74        4889 :   casacore::Matrix<casacore::Float>  getterDAsp() { return dAsp; }
      75             : };
      76             : 
      77        4889 : void objfunc_alglib(const alglib::real_1d_array &x, double &func, alglib::real_1d_array &grad, void *ptr) 
      78             : {
      79             :     // retrieve params for GSL bfgs optimization
      80        4889 :     casa::ParamAlglibObj *MyP = (casa::ParamAlglibObj *) ptr; //re-cast back to ParamAlglibObj to retrieve images
      81             : 
      82        4889 :     casacore::Matrix<casacore::Float> itsMatDirty(MyP->getterDirty());
      83        4889 :     casacore::Matrix<casacore::Complex> itsPsfFT(MyP->getterPsfFT());
      84        4889 :     std::vector<casacore::IPosition> center = MyP->getterCenter();
      85        4889 :     const unsigned int AspLen = MyP->getterAspLen();
      86        4889 :     const int nX = MyP->getterNX();
      87        4889 :     const int nY = MyP->getterNY();
      88        4889 :     casacore::Matrix<casacore::Float> newResidual(MyP->getterRes());
      89        4889 :     casacore::Matrix<casacore::Float> AspConvPsf(MyP->getterAspConvPsf());
      90        4889 :     casacore::Matrix<casacore::Float> Asp(MyP->getterAsp());
      91        4889 :     casacore::Matrix<casacore::Float> dAspConvPsf(MyP->getterDAspConvPsf());
      92        4889 :     casacore::Matrix<casacore::Float> dAsp(MyP->getterDAsp());
      93             : 
      94        4889 :     func = 0;
      95        4889 :     double amp = 1;
      96             : 
      97        4889 :     const int refi = nX/2;
      98        4889 :     const int refj = nY/2;
      99             : 
     100        4889 :     int minX = nX - 1;
     101        4889 :     int maxX = 0;
     102        4889 :     int minY = nY - 1;
     103        4889 :     int maxY = 0;
     104             : 
     105             :     // First, get the amp * AspenConvPsf for each Aspen to update the residual
     106        9776 :     for (unsigned int k = 0; k < AspLen; k ++)
     107             :     {
     108        4889 :         amp = x[2*k];
     109        4889 :         double scale = x[2*k+1];
     110             :         //std::cout << "f: amp " << amp << " scale " << scale << std::endl;
     111             : 
     112        4889 :       if (isnan(amp) || scale < 0.4) // GSL scale < 0
     113             :       {
     114             :         //std::cout << "nan? " << amp << " neg scale? " << scale << std::endl;
     115             :         // If scale is small (<0.4), make it 0 scale to utilize Hogbom and save time
     116           6 :         scale = (scale = fabs(scale)) < 0.4 ? 0 : scale;
     117             :         //std::cout << "reset neg scale to " << scale << std::endl;
     118             : 
     119           6 :         if (scale <= 0)
     120           2 :           return;
     121             :       }
     122             : 
     123             :       // generate a gaussian for each Asp in the Aspen set
     124             :       // x[0]: Amplitude0,       x[1]: scale0
     125             :       // x[2]: Amplitude1,       x[3]: scale1
     126             :       // x[2k]: Amplitude(k), x[2k+1]: scale(k+1)
     127             :       //casacore::Matrix<casacore::Float> Asp(nX, nY);
     128        4887 :       Asp = 0.0;
     129        4887 :       dAsp = 0.0;
     130             : 
     131        4887 :       const double sigma5 = 5 * scale / 2;
     132        4887 :       const int minI = std::max(0, (int)(center[k][0] - sigma5));
     133        4887 :       const int maxI = std::min(nX-1, (int)(center[k][0] + sigma5));
     134        4887 :       const int minJ = std::max(0, (int)(center[k][1] - sigma5));
     135        4887 :       const int maxJ = std::min(nY-1, (int)(center[k][1] + sigma5));
     136             : 
     137        4887 :       if (minI < minX)
     138        4887 :         minX = minI;
     139        4887 :       if (maxI > maxX)
     140        4887 :         maxX = maxI;
     141        4887 :       if (minJ < minY)
     142        4887 :         minY = minJ;
     143        4887 :       if (maxJ > maxY)
     144        4887 :         maxY = maxJ;
     145             : 
     146      452143 :       for (int j = minJ; j <= maxJ; j++)
     147             :       {
     148    76945897 :         for (int i = minI; i <= maxI; i++)
     149             :         {
     150    76498641 :           const int px = i;
     151    76498641 :           const int py = j;
     152             : 
     153    76498641 :           Asp(i,j) = (1.0/(sqrt(2*M_PI)*fabs(scale)))*exp(-(pow(i-center[k][0],2) + pow(j-center[k][1],2))*0.5/pow(scale,2));
     154    76498641 :           dAsp(i,j)= Asp(i,j) * (((pow(i-center[k][0],2) + pow(j-center[k][1],2)) / pow(scale,2) - 1) / fabs(scale)); // verified by python
     155             :         }
     156             :       }
     157             : 
     158        9774 :       casacore::Matrix<casacore::Complex> AspFT;
     159        4887 :       MyP->fft.fft0(AspFT, Asp);
     160        9774 :       casacore::Matrix<casacore::Complex> cWork;
     161        4887 :       cWork = AspFT * itsPsfFT;
     162        4887 :       MyP->fft.fft0(AspConvPsf, cWork, false);
     163        4887 :       MyP->fft.flip(AspConvPsf, false, false); //need this
     164             : 
     165             :       // gradient. 0: amplitude; 1: scale
     166             :       // returns the gradient evaluated on x
     167        9774 :       casacore::Matrix<casacore::Complex> dAspFT;
     168             : 
     169             :       //auto start = std::chrono::high_resolution_clock::now();
     170        4887 :       MyP->fft.fft0(dAspFT, dAsp);
     171             :       //auto stop = std::chrono::high_resolution_clock::now();
     172             :       //auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start) ;
     173             :       //std::cout << "BFGS fft0 runtime " << duration.count() << " us" << std::endl;
     174             : 
     175        9774 :       casacore::Matrix<casacore::Complex> dcWork;
     176        4887 :       dcWork = dAspFT * itsPsfFT;
     177        4887 :       MyP->fft.fft0(dAspConvPsf, dcWork, false);
     178        4887 :       MyP->fft.flip(dAspConvPsf, false, false); //need this
     179             :     } // end get amp * AspenConvPsf
     180             : 
     181             :     // reset grad to 0. This is important to get the correct optimization.
     182        4887 :     double dA = 0.0;
     183        4887 :     double dS = 0.0;
     184             : 
     185             :     // Update the residual using the current residual image and the latest Aspen.
     186             :     // Sanjay used, Res = OrigDirty - active-set aspen * Psf, in 2004, instead.
     187             :     // Both works but the current approach is simpler and performs well too.
     188      447256 :     for (int j = minY; j < maxY; ++j)
     189             :     {
     190    76030223 :       for(int i = minX; i < maxX; ++i)
     191             :       {
     192    75587854 :         newResidual(i, j) = itsMatDirty(i, j) - amp * AspConvPsf(i, j);
     193    75587854 :         func = func + double(pow(newResidual(i, j), 2));
     194             : 
     195             :         // derivatives of amplitude
     196    75587854 :         dA += double((-2) * newResidual(i,j) * AspConvPsf(i,j));
     197             :         // derivative of scale
     198    75587854 :         dS += double((-2) * amp * newResidual(i,j) * dAspConvPsf(i,j));
     199             :       }
     200             :     }
     201             :     //std::cout << "after f " << func << std::endl;
     202             : 
     203        4887 :     grad[0] = dA;
     204        4887 :     grad[1] = dS; 
     205             : }
     206             : 
     207             : 
     208             : 
     209             : } // end namespace casa
     210             : 
     211             : #endif // SYNTHESIS_OBJFUNCALGLIB_H

Generated by: LCOV version 1.16