Line data Source code
1 : #ifndef SYNTHESIS_OBJFUNCALGLIB_H
2 : #define SYNTHESIS_OBJFUNCALGLIB_H
3 :
4 : #include <casacore/ms/MeasurementSets/MeasurementSet.h>
5 : #include <casacore/casa/Arrays/Matrix.h>
6 : #include <casacore/casa/Arrays/IPosition.h>
7 : #include <casacore/images/Images/ImageInterface.h>
8 : #include <casacore/images/Images/PagedImage.h>
9 : #include <casacore/images/Images/TempImage.h>
10 :
11 : #include <casacore/scimath/Mathematics/FFTServer.h>
12 : #include <casacore/scimath/Functionals/Gaussian2D.h>
13 :
14 : #include "lbfgs/optimization.h"
15 :
16 : #ifndef isnan
17 : #define isnan(x) std::isnan(x)
18 : #endif
19 :
20 : namespace casa { //# NAMESPACE CASA - BEGIN
21 :
22 : class ParamAlglibObj
23 : {
24 : private:
25 : int nX;
26 : int nY;
27 : unsigned int AspLen;
28 : casacore::Matrix<casacore::Float> itsMatDirty;
29 : casacore::Matrix<casacore::Complex> itsPsfFT;
30 : std::vector<casacore::IPosition> center;
31 : casacore::Matrix<casacore::Float> newResidual;
32 : casacore::Matrix<casacore::Float> AspConvPsf;
33 : casacore::Matrix<casacore::Float> dAspConvPsf;
34 : casacore::Matrix<casacore::Float> Asp;
35 : casacore::Matrix<casacore::Float> dAsp;
36 :
37 : public:
38 : casacore::FFTServer<casacore::Float,casacore::Complex> fft;
39 :
40 0 : ParamAlglibObj(const casacore::Matrix<casacore::Float>& dirty,
41 : const casacore::Matrix<casacore::Complex>& psf,
42 : const std::vector<casacore::IPosition>& positionOptimum,
43 0 : const casacore::FFTServer<casacore::Float,casacore::Complex>& fftin) :
44 : itsMatDirty(dirty),
45 : itsPsfFT(psf),
46 : center(positionOptimum),
47 0 : fft(fftin)
48 : {
49 0 : nX = itsMatDirty.shape()(0);
50 0 : nY = itsMatDirty.shape()(1);
51 0 : AspLen = center.size();
52 0 : newResidual.resize(nX, nY);
53 0 : AspConvPsf.resize(nX, nY);
54 0 : dAspConvPsf.resize(nX, nY);
55 0 : Asp.resize(nX, nY);
56 0 : dAsp.resize(nX, nY);
57 0 : }
58 :
59 0 : ~ParamAlglibObj() = default;
60 :
61 0 : casacore::Matrix<casacore::Float> getterDirty() { return itsMatDirty; }
62 0 : casacore::Matrix<casacore::Complex> getterPsfFT() { return itsPsfFT; }
63 0 : std::vector<casacore::IPosition> getterCenter() {return center;}
64 0 : unsigned int getterAspLen() { return AspLen; }
65 0 : int getterNX() { return nX; }
66 0 : int getterNY() { return nY; }
67 0 : casacore::Matrix<casacore::Float> getterRes() { return newResidual; }
68 : void setterRes(const casacore::Matrix<casacore::Float>& res) { newResidual = res; }
69 0 : casacore::Matrix<casacore::Float> getterAspConvPsf() { return AspConvPsf; }
70 : void setterAspConvPsf(const casacore::Matrix<casacore::Float>& m) { AspConvPsf = m; }
71 0 : casacore::Matrix<casacore::Float> getterDAspConvPsf() { return dAspConvPsf; }
72 0 : casacore::Matrix<casacore::Float> getterAsp() { return Asp; }
73 : void setterAsp(const casacore::Matrix<casacore::Float>& m) { Asp = m; }
74 0 : casacore::Matrix<casacore::Float> getterDAsp() { return dAsp; }
75 : };
76 :
77 0 : void objfunc_alglib(const alglib::real_1d_array &x, double &func, alglib::real_1d_array &grad, void *ptr)
78 : {
79 : // retrieve params for GSL bfgs optimization
80 0 : casa::ParamAlglibObj *MyP = (casa::ParamAlglibObj *) ptr; //re-cast back to ParamAlglibObj to retrieve images
81 :
82 0 : casacore::Matrix<casacore::Float> itsMatDirty(MyP->getterDirty());
83 0 : casacore::Matrix<casacore::Complex> itsPsfFT(MyP->getterPsfFT());
84 0 : std::vector<casacore::IPosition> center = MyP->getterCenter();
85 0 : const unsigned int AspLen = MyP->getterAspLen();
86 0 : const int nX = MyP->getterNX();
87 0 : const int nY = MyP->getterNY();
88 0 : casacore::Matrix<casacore::Float> newResidual(MyP->getterRes());
89 0 : casacore::Matrix<casacore::Float> AspConvPsf(MyP->getterAspConvPsf());
90 0 : casacore::Matrix<casacore::Float> Asp(MyP->getterAsp());
91 0 : casacore::Matrix<casacore::Float> dAspConvPsf(MyP->getterDAspConvPsf());
92 0 : casacore::Matrix<casacore::Float> dAsp(MyP->getterDAsp());
93 :
94 0 : func = 0;
95 0 : double amp = 1;
96 :
97 0 : const int refi = nX/2;
98 0 : const int refj = nY/2;
99 :
100 0 : int minX = nX - 1;
101 0 : int maxX = 0;
102 0 : int minY = nY - 1;
103 0 : int maxY = 0;
104 :
105 : // First, get the amp * AspenConvPsf for each Aspen to update the residual
106 0 : for (unsigned int k = 0; k < AspLen; k ++)
107 : {
108 0 : amp = x[2*k];
109 0 : double scale = x[2*k+1];
110 : //std::cout << "f: amp " << amp << " scale " << scale << std::endl;
111 :
112 0 : if (isnan(amp) || scale < 0.4) // GSL scale < 0
113 : {
114 : //std::cout << "nan? " << amp << " neg scale? " << scale << std::endl;
115 : // If scale is small (<0.4), make it 0 scale to utilize Hogbom and save time
116 0 : scale = (scale = fabs(scale)) < 0.4 ? 0 : scale;
117 : //std::cout << "reset neg scale to " << scale << std::endl;
118 :
119 0 : if (scale <= 0)
120 0 : return;
121 : }
122 :
123 : // generate a gaussian for each Asp in the Aspen set
124 : // x[0]: Amplitude0, x[1]: scale0
125 : // x[2]: Amplitude1, x[3]: scale1
126 : // x[2k]: Amplitude(k), x[2k+1]: scale(k+1)
127 : //casacore::Matrix<casacore::Float> Asp(nX, nY);
128 0 : Asp = 0.0;
129 0 : dAsp = 0.0;
130 :
131 0 : const double sigma5 = 5 * scale / 2;
132 0 : const int minI = std::max(0, (int)(center[k][0] - sigma5));
133 0 : const int maxI = std::min(nX-1, (int)(center[k][0] + sigma5));
134 0 : const int minJ = std::max(0, (int)(center[k][1] - sigma5));
135 0 : const int maxJ = std::min(nY-1, (int)(center[k][1] + sigma5));
136 :
137 0 : if (minI < minX)
138 0 : minX = minI;
139 0 : if (maxI > maxX)
140 0 : maxX = maxI;
141 0 : if (minJ < minY)
142 0 : minY = minJ;
143 0 : if (maxJ > maxY)
144 0 : maxY = maxJ;
145 :
146 0 : for (int j = minJ; j <= maxJ; j++)
147 : {
148 0 : for (int i = minI; i <= maxI; i++)
149 : {
150 0 : const int px = i;
151 0 : const int py = j;
152 :
153 0 : Asp(i,j) = (1.0/(sqrt(2*M_PI)*fabs(scale)))*exp(-(pow(i-center[k][0],2) + pow(j-center[k][1],2))*0.5/pow(scale,2));
154 0 : dAsp(i,j)= Asp(i,j) * (((pow(i-center[k][0],2) + pow(j-center[k][1],2)) / pow(scale,2) - 1) / fabs(scale)); // verified by python
155 : }
156 : }
157 :
158 0 : casacore::Matrix<casacore::Complex> AspFT;
159 0 : MyP->fft.fft0(AspFT, Asp);
160 0 : casacore::Matrix<casacore::Complex> cWork;
161 0 : cWork = AspFT * itsPsfFT;
162 0 : MyP->fft.fft0(AspConvPsf, cWork, false);
163 0 : MyP->fft.flip(AspConvPsf, false, false); //need this
164 :
165 : // gradient. 0: amplitude; 1: scale
166 : // returns the gradient evaluated on x
167 0 : casacore::Matrix<casacore::Complex> dAspFT;
168 :
169 : //auto start = std::chrono::high_resolution_clock::now();
170 0 : MyP->fft.fft0(dAspFT, dAsp);
171 : //auto stop = std::chrono::high_resolution_clock::now();
172 : //auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start) ;
173 : //std::cout << "BFGS fft0 runtime " << duration.count() << " us" << std::endl;
174 :
175 0 : casacore::Matrix<casacore::Complex> dcWork;
176 0 : dcWork = dAspFT * itsPsfFT;
177 0 : MyP->fft.fft0(dAspConvPsf, dcWork, false);
178 0 : MyP->fft.flip(dAspConvPsf, false, false); //need this
179 : } // end get amp * AspenConvPsf
180 :
181 : // reset grad to 0. This is important to get the correct optimization.
182 0 : double dA = 0.0;
183 0 : double dS = 0.0;
184 :
185 : // Update the residual using the current residual image and the latest Aspen.
186 : // Sanjay used, Res = OrigDirty - active-set aspen * Psf, in 2004, instead.
187 : // Both works but the current approach is simpler and performs well too.
188 0 : for (int j = minY; j < maxY; ++j)
189 : {
190 0 : for(int i = minX; i < maxX; ++i)
191 : {
192 0 : newResidual(i, j) = itsMatDirty(i, j) - amp * AspConvPsf(i, j);
193 0 : func = func + double(pow(newResidual(i, j), 2));
194 :
195 : // derivatives of amplitude
196 0 : dA += double((-2) * newResidual(i,j) * AspConvPsf(i,j));
197 : // derivative of scale
198 0 : dS += double((-2) * amp * newResidual(i,j) * dAspConvPsf(i,j));
199 : }
200 : }
201 : //std::cout << "after f " << func << std::endl;
202 :
203 0 : grad[0] = dA;
204 0 : grad[1] = dS;
205 : }
206 :
207 :
208 :
209 : } // end namespace casa
210 :
211 : #endif // SYNTHESIS_OBJFUNCALGLIB_H
|