Line data Source code
1 : //# VisCalSolver2.cc: Implementation of generic visibility solving
2 : //# Copyright (C) 1996,1997,1998,1999,2000,2001,2002,2003
3 : //# Associated Universities, Inc. Washington DC, USA.
4 : //#
5 : //# This library is free software; you can redistribute it and/or modify it
6 : //# under the terms of the GNU Library General Public License as published by
7 : //# the Free Software Foundation; either version 2 of the License, or (at your
8 : //# option) any later version.
9 : //#
10 : //# This library is distributed in the hope that it will be useful, but WITHOUT
11 : //# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
12 : //# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public
13 : //# License for more details.
14 : //#
15 : //# You should have received a copy of the GNU Library General Public License
16 : //# along with this library; if not, write to the Free Software Foundation,
17 : //# Inc., 675 Massachusetts Ave, Cambridge, MA 02139, USA.
18 : //#
19 : //# Correspondence concerning AIPS++ should be addressed as follows:
20 : //# Internet email: aips2-request@nrao.edu.
21 : //# Postal address: AIPS++ Project Office
22 : //# National Radio Astronomy Observatory
23 : //# 520 Edgemont Road
24 : //# Charlottesville, VA 22903-2475 USA
25 : //#
26 :
27 : #include <synthesis/MeasurementComponents/VisCalSolver2.h>
28 :
29 : #include <msvis/MSVis/VisBuffer.h>
30 :
31 : #include <casacore/casa/Arrays/ArrayMath.h>
32 : #include <casacore/casa/Arrays/MaskArrMath.h>
33 : #include <casacore/casa/Arrays/ArrayLogical.h>
34 : #include <casacore/casa/Arrays/ArrayIter.h>
35 : //#include <scimath/Mathematics/MatrixMathLA.h>
36 : #include <casacore/casa/BasicSL/String.h>
37 : #include <casacore/casa/BasicMath/Math.h>
38 : #include <casacore/casa/Utilities/Assert.h>
39 : #include <casacore/casa/Exceptions/Error.h>
40 : #include <casacore/casa/OS/Memory.h>
41 : #include <casacore/casa/OS/Path.h>
42 :
43 : #include <sstream>
44 :
45 : #include <casacore/casa/Logging/LogMessage.h>
46 : #include <casacore/casa/Logging/LogSink.h>
47 :
48 : #define VCS2_PRTLEV 0
49 :
50 : namespace casa { //# NAMESPACE CASA - BEGIN
51 :
52 : using namespace casacore;
53 :
54 :
55 : // **********************************************************
56 : // VisCalSolver2 Implementations
57 : //
58 :
59 0 : VisCalSolver2::VisCalSolver2() :
60 : SDBs_(NULL),
61 : ve_(NULL),
62 : svc_(NULL),
63 : nPar_(0),
64 : maxIter_(50),
65 : chiSq_(0.0),
66 : chiSqV_(4,0.0),
67 : lastChiSq_(0.0),dChiSq_(0.0),
68 : sumWt_(0.0),sumWtV_(4,0.0),nWt_(0),
69 : cvrgcount_(0),
70 : par_(), parOK_(), parErr_(), lastPar_(),
71 : dpar_(),
72 : grad_(),hess_(),
73 : lambda_(2.0),
74 : optstep_(True),
75 : doL1_(false),
76 : L1clamp_(0),
77 : doRMSThresh_(false),
78 : RMSThresh_(0),
79 : nRMSThresh_(0),
80 0 : prtlev_(VCS2_PRTLEV)
81 : {
82 0 : if (prtlev()>0) cout << "VCS2::VCS2()" << endl;
83 0 : }
84 :
85 15912 : VisCalSolver2::VisCalSolver2(String solmode, Vector<Float>& rmsthresh) :
86 : SDBs_(NULL),
87 : ve_(NULL),
88 : svc_(NULL),
89 : nPar_(0),
90 : maxIter_(50),
91 : chiSq_(0.0),
92 : chiSqV_(4,0.0),
93 : lastChiSq_(0.0),dChiSq_(0.0),
94 : sumWt_(0.0),sumWtV_(4,0.0),nWt_(0),
95 : cvrgcount_(0),
96 : par_(), parOK_(), parErr_(), lastPar_(),
97 : dpar_(),
98 : grad_(),hess_(),
99 : lambda_(2.0),
100 : optstep_(True),
101 : doL1_(false),
102 31824 : L1clamp_(std::vector<Float>({5e-3, 5e-4, 5e-5})),
103 : doRMSThresh_(false),
104 : RMSThresh_(rmsthresh), //
105 15912 : nRMSThresh_(rmsthresh.nelements()),
106 47736 : prtlev_(VCS2_PRTLEV)
107 : {
108 15912 : if (prtlev()>0) cout << "VCS2::VCS2(solmode)" << endl;
109 :
110 15912 : if (solmode.contains("L1")) doL1_=true;
111 15912 : if (solmode.contains("R")) doRMSThresh_=true;
112 :
113 15912 : if (doRMSThresh_ && nRMSThresh_==0) {
114 0 : doRMSThresh_=false;
115 : //RMSThresh_=Vector<Float>(std::vector<Float>({7.0,5.0,4.0,3.5,3.0,2.8,2.6,2.4,2.2}));
116 : //nRMSThresh_=RMSThresh_.nelements();
117 : }
118 :
119 15912 : }
120 :
121 15912 : VisCalSolver2::~VisCalSolver2()
122 : {
123 15912 : if (prtlev()>0) cout << "VCS2::~VCS2()" << endl;
124 15912 : }
125 :
126 :
127 : // New SDBList version
128 17160 : Bool VisCalSolver2::solve(VisEquation& ve, SolvableVisCal& svc, SDBList& sdbs) {
129 :
130 : // If L1 and/or outlier flagging requested, call specialize method
131 17160 : if (doL1_ || doRMSThresh_)
132 0 : return solveL1R(ve,svc,sdbs);
133 :
134 17160 : if (prtlev()>1) cout << "VCS2::solve(,,SDBs)" << endl;
135 :
136 : /*
137 : LogSink logsink;
138 : {
139 : LogMessage message(LogOrigin("VisCalSolver2", "solve"));
140 : ostringstream o; o<<"Beginning solve...";
141 : message.message(o);
142 : logsink.post(message);
143 : }
144 : */
145 : // Pointers to local ve,svc
146 17160 : ve_=&ve;
147 17160 : svc_=&svc;
148 17160 : SDBs_=&sdbs;
149 :
150 : // Verify that VisEq has the correct svc:
151 : // TBD?
152 :
153 : // Initialize everything
154 17160 : initSolve();
155 :
156 34320 : Vector<Float> steplist(maxIter_+2,0.0);
157 34320 : Vector<Float> rsteplist(maxIter_+2,0.0);
158 :
159 : // Verify Data's validity for solve w.r.t. baselines available
160 : // (this sets parOK() on per-antenna basis (for focusChan)
161 : // based on data weights and baseline participation)
162 17160 : Bool oktosolve = svc_->verifyConstraints(*SDBs_);
163 :
164 17160 : if (oktosolve) {
165 :
166 9390 : if (prtlev()>1) cout << "First guess:" << endl
167 0 : << "amp = " << amplitude(par()) << endl
168 0 : << "pha = " << phase(par())
169 0 : << endl;
170 :
171 : // Iterate solution
172 9390 : Int iter(0);
173 9390 : Bool done(False);
174 76820 : while (!done) {
175 :
176 76820 : if (prtlev()>2) cout << " Beginning iteration " << iter
177 0 : << "---------------------------------" << endl;
178 :
179 : // Differentiate the VB and get current Chi2
180 76820 : differentiate2();
181 76820 : chiSquare2();
182 76820 : if (chiSq()==0.0) {
183 0 : cout << "CHI2 IS SPURIOUSLY ZERO!*************************************" << endl;
184 : //cout << "R() = " << R() << endl;
185 : // cout << "sum(wtmat) = " << sum(wtmat) << endl;
186 0 : return False;
187 : }
188 :
189 76820 : dChiSq() = chiSq()-lastChiSq();
190 :
191 : // cout << "chi2 = " << chiSq() << " " << dChiSq() << " " << dChiSq()/chiSq() << endl;
192 :
193 : // Continuue if we haven't converged
194 76820 : if (!converged()) {
195 :
196 67430 : if (dChiSq()<=0.0) {
197 : // last step was good...
198 66757 : lastChiSq()=chiSq();
199 :
200 : // so accumulate new grad/hess...
201 66757 : accGradHess2();
202 :
203 : //...and adjust lambda downward
204 : // lambda()/=2.0;
205 : // lambda()=0.8;
206 66757 : lambda()=1.0;
207 : }
208 : else {
209 : // cout << "reverting..." << chiSq() << " " << dChiSq() << " (" << iter << ")" << endl;
210 : // last step was bad, revert to previous
211 673 : revert();
212 : //...with a larger lambda
213 : // lambda()*=4.0;
214 673 : lambda()=1.0;
215 : }
216 :
217 : // Solve for the parameter step
218 67430 : solveGradHess();
219 :
220 : // Remember curr pars
221 67430 : lastPar()=par();
222 :
223 : // Refine the step size by exploring chi2 in the
224 : // gradient direction
225 67430 : if (optstep_) // && cvrgcount_>=3)
226 67430 : optStepSize2();
227 :
228 : // Update current parameters (saves a copy of them)
229 67430 : updatePar();
230 :
231 :
232 67430 : steplist(iter)=max(amplitude(dpar()));
233 67430 : rsteplist(iter)=max(amplitude(dpar())/amplitude(par()));
234 :
235 : }
236 : else {
237 : // Convergence means we're done!
238 9390 : done=True;
239 :
240 9390 : if (prtlev()>0) {
241 0 : cout << "par()=" << par() << endl;
242 : }
243 :
244 :
245 : /*
246 : cout << " good pars=" << ntrue(parOK())
247 : << " iterations=" << iter << endl
248 : << " steps=" << steplist(IPosition(1,0),IPosition(1,iter))
249 : << endl
250 : << " rsteps=" << rsteplist(IPosition(1,0),IPosition(1,iter))
251 : << endl;
252 : */
253 :
254 : // Get parameter errors:
255 9390 : accGradHess2();
256 9390 : getErrors();
257 :
258 : // Return, signaling success if at least 1 good solution
259 9390 : return (ntrue(parOK())>0);
260 :
261 : }
262 :
263 : // Escape iteration loop via iteration limit
264 67430 : if (iter==maxIter()) {
265 0 : cout << "Reached iteration limit: " << iter << " iterations. " << endl;
266 : // cout << " good pars = " << ntrue(parOK())
267 : // << " steps = " << steplist
268 : // << endl;
269 0 : done=True;
270 : }
271 :
272 : // Advance iteration counter
273 67430 : iter++;
274 : }
275 :
276 : }
277 : else {
278 7770 : cout << " Insufficient unflagged antennas to proceed with this solve." << endl;
279 : }
280 :
281 7770 : return False;
282 :
283 : }
284 :
285 : // New L1(R)-capable version
286 0 : Bool VisCalSolver2::solveL1R(VisEquation& ve, SolvableVisCal& svc, SDBList& sdbs) {
287 :
288 0 : if (prtlev()>1) cout << "VCS2::solve(,,SDBs)" << endl;
289 :
290 : /*
291 : LogSink logsink;
292 : {
293 : LogMessage message(LogOrigin("VisCalSolver2", "solve"));
294 : ostringstream o; o<<"Beginning solve...";
295 : message.message(o);
296 : logsink.post(message);
297 : }
298 : */
299 : // Pointers to local ve,svc
300 0 : ve_=&ve;
301 0 : svc_=&svc;
302 0 : SDBs_=&sdbs;
303 :
304 : // Verify that VisEq has the correct svc:
305 : // TBD?
306 :
307 : // Initialize everything
308 0 : initSolve();
309 :
310 0 : Vector<Float> steplist(maxIter_+2,0.0);
311 0 : Vector<Float> rsteplist(maxIter_+2,0.0);
312 :
313 : // Verify Data's validity for solve w.r.t. baselines available
314 : // (this sets parOK() on per-antenna basis (for focusChan)
315 : // based on data weights and baseline participation)
316 0 : Bool oktosolve = svc_->verifyConstraints(*SDBs_);
317 :
318 0 : if (oktosolve) {
319 :
320 : // Tweak guess in L1 case, to avoid degeneracy...
321 0 : if (doL1_)
322 0 : par()*=Complex(1.0001,0.0);
323 :
324 0 : if (prtlev()>1) cout << "First guess:" << endl
325 0 : << "amp = " << amplitude(par()) << endl
326 0 : << "pha = " << phase(par())
327 0 : << endl;
328 :
329 : // Iterate solution
330 0 : Int iter(0);
331 0 : Bool done(False);
332 0 : Bool applyWorkingFlags(false);
333 0 : Int L1iter(0), IRiter(0);
334 0 : while (!done) {
335 :
336 0 : if (prtlev()>2) cout << " Beginning iteration " << iter
337 0 : << "---------------------------------" << endl;
338 :
339 : // Differentiate the VB and get current Chi2
340 0 : differentiate2();
341 :
342 0 : if (doRMSThresh_ && applyWorkingFlags) {
343 0 : SDBs_->updateWorkingFlags();
344 0 : applyWorkingFlags=false; // must be explicitly triggered below
345 : }
346 :
347 : // Set up working weights
348 0 : if (doL1_)
349 0 : SDBs_->updateWorkingWeights(doL1_,L1clamp_(L1iter));
350 : else
351 0 : SDBs_->updateWorkingWeights(false);
352 :
353 :
354 0 : chiSquare2();
355 0 : if (chiSq()==0.0) {
356 0 : cout << "CHI2 IS SPURIOUSLY ZERO!*************************************" << endl;
357 : //cout << "R() = " << R() << endl;
358 : // cout << "sum(wtmat) = " << sum(wtmat) << endl;
359 0 : return False;
360 : }
361 :
362 0 : dChiSq() = chiSq()-lastChiSq();
363 :
364 : //cout << "iter=" << iter << " X2=" << chiSq() << " dX2=" << dChiSq() << " dX2/X2=" << dChiSq()/chiSq(); // << endl;
365 :
366 : // Continuue if we haven't converged
367 0 : if (!converged()) {
368 :
369 : //if (dChiSq()<=0.0) {
370 : if (true || dChiSq()<=0.0) {
371 : // last step was good...
372 0 : lastChiSq()=chiSq();
373 :
374 : // so accumulate new grad/hess...
375 0 : accGradHess2();
376 :
377 : //...and adjust lambda downward
378 : // lambda()/=2.0;
379 : // lambda()=0.8;
380 0 : lambda()=1.0;
381 : }
382 : else {
383 : // cout << "reverting..." << chiSq() << " " << dChiSq() << " (" << iter << ")" << endl;
384 : // last step was bad, revert to previous
385 : revert();
386 : //...with a larger lambda
387 : // lambda()*=4.0;
388 : lambda()=1.0;
389 : }
390 :
391 : // Solve for the parameter step
392 0 : solveGradHess();
393 :
394 : // Remember curr pars
395 0 : lastPar()=par();
396 :
397 : // Refine the step size by exploring chi2 in the
398 : // gradient direction
399 0 : if (optstep_ && !doL1_) // && cvrgcount_>=3)
400 0 : optStepSize2();
401 :
402 : // Update current parameters (saves a copy of them)
403 0 : updatePar();
404 :
405 0 : steplist(iter)=max(amplitude(dpar()));
406 0 : rsteplist(iter)=max(amplitude(dpar())/amplitude(par()));
407 :
408 : //cout << " rstep=" << rsteplist(iter) << endl;
409 :
410 : }
411 : else {
412 :
413 : // Convergence means we're done, NOMINALLY
414 0 : done=True;
415 :
416 : // Override convergence if we need to solve again with
417 : // revised weight/flag conditions for robustness
418 0 : if (doL1_ && L1iter<Int(L1clamp_.nelements())-1) {
419 : //cout << "*~*~*~*~*~*~* Converged w/ L1clamp = " << L1clamp_(L1iter) << " *~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*" << endl;
420 0 : done=false;
421 0 : ++L1iter;
422 0 : iter=-1;
423 0 : cvrgcount_=0;
424 0 : lastChiSq()=DBL_MAX;
425 : }
426 0 : else if (doRMSThresh_ && IRiter<nRMSThresh_) {
427 : //cout << "*~*~*~*~*~*~* Applying RMSThresh = " << RMSThresh_(IRiter) << " *~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*" << endl;
428 0 : RMSThresh(IRiter);
429 0 : ++IRiter;
430 0 : applyWorkingFlags=true; // force apply of the RMSThresh'd flags at the top of loop _after_ differentiation
431 0 : done=false;
432 0 : L1iter=0;
433 0 : iter=-1;
434 0 : cvrgcount_=0;
435 0 : lastChiSq()=DBL_MAX;
436 : }
437 :
438 : // If still done (robustness options absent or exhausted), escape solve loop
439 0 : if (done) {
440 :
441 0 : if (prtlev()>0) {
442 0 : cout << "par()=" << par() << endl;
443 : }
444 :
445 : /*
446 : cout << " good pars=" << ntrue(parOK())
447 : << " iterations=" << iter << endl
448 : << " steps=" << steplist(IPosition(1,0),IPosition(1,iter))
449 : << endl
450 : << " rsteps=" << rsteplist(IPosition(1,0),IPosition(1,iter))
451 : << endl;
452 : */
453 :
454 : // Get parameter errors:
455 0 : accGradHess2();
456 0 : getErrors();
457 :
458 : // Return, signaling success if at least 1 good solution
459 0 : return (ntrue(parOK())>0);
460 : }
461 :
462 : } // converged?
463 :
464 : // Escape iteration loop via iteration limit
465 0 : if (iter==maxIter()) {
466 0 : cout << "Reached iteration limit: " << iter << " iterations. " << endl;
467 : // cout << " good pars = " << ntrue(parOK())
468 : // << " steps = " << steplist
469 : // << endl;
470 0 : done=True;
471 : }
472 :
473 : // Advance iteration counter
474 0 : iter++;
475 : }
476 :
477 : }
478 : else {
479 0 : cout << " Insufficient unflagged antennas to proceed with this solve." << endl;
480 : }
481 :
482 0 : return False;
483 :
484 : }
485 :
486 17160 : void VisCalSolver2::initSolve() {
487 :
488 17160 : if (prtlev()>2) cout << " VCS2::initSolve()" << endl;
489 :
490 : // Get total number of cal parameters from svc info
491 17160 : nPar()=svc().nTotalPar();
492 :
493 17160 : if (prtlev()>2)
494 0 : cout << " Total parameters in solve: " << nPar() << endl;
495 :
496 : // Chi2 and weights
497 17160 : chiSq()=0.0;
498 17160 : lastChiSq()=DBL_MAX;
499 17160 : dChiSq()=0.0;
500 :
501 17160 : sumWt()=0.0;
502 17160 : nWt()=0;
503 :
504 : // Link up svc's internal pars with local reference
505 : // (only if shape is correct)
506 :
507 17160 : if (svc().solveCPar().nelements()==uInt(nPar())) {
508 17160 : par().reference(svc().solveCPar().reform(IPosition(1,nPar())));
509 17160 : parOK().reference(svc().solveParOK().reform(IPosition(1,nPar())));
510 17160 : parErr().reference(svc().solveParErr().reform(IPosition(1,nPar())));
511 : }
512 : else
513 0 : throw(AipsError("Solver and SVC cannot synchronize parameters."));
514 :
515 : // Pars
516 :
517 17160 : dpar().resize(nPar());
518 17160 : dpar()=0.0;
519 :
520 17160 : lastPar().resize(nPar());
521 :
522 : // Gradient and Hessian
523 17160 : grad().resize(nPar());
524 17160 : grad()=0.0;
525 :
526 17160 : hess().resize(nPar());
527 17160 : hess()=0.0;
528 :
529 : // Levenberg-Marquardt factor
530 17160 : lambda()=2.0;
531 :
532 : // Convergence anticipation
533 17160 : cvrgcount_=0;
534 :
535 17160 : }
536 :
537 176420 : void VisCalSolver2::residualate2() {
538 :
539 : // if (prtlev()>2) cout << " VCS2::residualate()" << endl;
540 :
541 : // For now, just use ve.diffResid, until we have
542 : // implemented focuschan-aware trial corrupt in SVC
543 : // (this will hurt performance a bit)
544 :
545 414480 : for (Int isdb=0;isdb<sdbs().nSDB();++isdb)
546 238060 : ve().differentiate(sdbs()(isdb));
547 176420 : }
548 :
549 76820 : void VisCalSolver2::differentiate2() {
550 :
551 76820 : if (prtlev()>2) cout << " VCS2::differentiate(SDB version)" << endl;
552 :
553 : // TBD: Should this be packaged in the SolveDataBuffer such
554 : // that is could be called there with a reference to the svc()?
555 : // Eg:
556 : // sdbs().differentiate(svc()); // an aggregate method in SDBList
557 : // ...which then does:
558 : // svc.differentiate(this) (for each SDB)
559 : //
560 : // also consider whether VE is in the loop here?
561 : //
562 : // (don't wind this up in a way that makes it harder to extend....)
563 :
564 : // Delegate to VisEquation
565 167435 : for (Int isdb=0;isdb<sdbs().nSDB();++isdb)
566 90615 : ve().differentiate(sdbs()(isdb));
567 :
568 76820 : }
569 :
570 253240 : void VisCalSolver2::chiSquare2() {
571 :
572 253240 : if (prtlev()>2) cout << " VCS2::chiSquare(SDB version)" << endl;
573 :
574 : // TBD: per-ant/bln chiSq?
575 :
576 253240 : chiSq()=0.0;
577 253240 : chiSqV()=0.0;
578 253240 : sumWt()=0.0;
579 253240 : sumWtV()=0.0;
580 253240 : nWt()=0;
581 :
582 253240 : Cube<Complex> R;
583 :
584 : // Loop over SDBs
585 581915 : for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
586 :
587 : // Current SDB
588 328675 : SolveDataBuffer& sdb(sdbs()(isdb));
589 328675 : R.reference(sdb.residuals());
590 :
591 : // _const_ access to working flags and weights
592 328675 : const Cube<Bool>& wFC(sdb.const_workingFlagCube());
593 328675 : const Cube<Float>& wWS(sdb.const_workingWtSpec());
594 :
595 : // Shapes for iteration
596 657350 : IPosition shR(R.shape());
597 328675 : Int nCorr=shR(0);
598 328675 : Int nChan=shR(1);
599 328675 : Int nRow=shR(2);
600 :
601 : // Simple indexed accumulation of chiSq
602 : // TBD: optimize w.r.t. indexing?
603 328675 : Double chisq0(0.0);
604 21759742 : for (Int irow=0;irow<nRow;++irow) {
605 21431067 : if (!sdb.flagRow()(irow)) {
606 42862134 : for (Int ich=0;ich<nChan;++ich) {
607 91858319 : for (Int icorr=0;icorr<nCorr;++icorr) {
608 : //if (!sdb.residFlagCube()(icorr,ich,irow)) { // OLD: residFlagCube
609 70427252 : const Bool& fl(wFC(icorr,ich,irow)); // NEW: workingFlagCube CORRECT?
610 70427252 : if (!fl) {
611 69812542 : const Float& wt(wWS(icorr,ich,irow));
612 69812542 : if (wt>0.0) {
613 59513734 : Complex& Ri(R(icorr,ich,irow));
614 :
615 : // This element's contribution
616 59513734 : chisq0=Double(wt*real(Ri*conj(Ri))); // cf: square(abs(R))?
617 :
618 : // Accumulate per-corr
619 59513734 : chiSqV()(icorr)+=chisq0;
620 59513734 : sumWtV()(icorr)+=wt;
621 59513734 : nWt()++;
622 : } // wt>0
623 : } // !flag
624 : } // icorr
625 : } // ich
626 : } // !flagRow
627 : } // irow
628 :
629 : } // isdb
630 :
631 : //cout << "chiSqV() = " << chiSqV() << endl;
632 :
633 : // Totals over corrs
634 253240 : chiSq()=sum(chiSqV());
635 253240 : sumWt()=sum(sumWtV());
636 :
637 253240 : }
638 :
639 : // RMS calculation (for thresholding)
640 0 : void VisCalSolver2::RMSThresh(Int RejIter) {
641 :
642 0 : if (prtlev()>2) cout << " VCS2::RMS(SDB version)" << endl;
643 :
644 0 : const Float threshold(RMSThresh_(RejIter));
645 0 : Bool dolog=(RejIter==nRMSThresh_-1);
646 :
647 : // TBD: per-ant/bln chiSq?
648 :
649 0 : Int nCorr=sdbs().nCorrelations();
650 0 : Vector<Double> xxV(nCorr,0.0);
651 0 : Vector<Double> sWtV(nCorr,0.0);
652 :
653 0 : Cube<Complex> R;
654 :
655 : // Loop over SDBs
656 0 : for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
657 :
658 : // Current SDB
659 0 : SolveDataBuffer& sdb(sdbs()(isdb));
660 0 : R.reference(sdb.residuals());
661 :
662 : // Shapes for iteration
663 0 : IPosition shR(R.shape());
664 0 : Int nCorr=shR(0);
665 0 : Int nChan=shR(1);
666 0 : Int nRow=shR(2);
667 :
668 0 : const Cube<Bool>& wFC(sdb.const_workingFlagCube());
669 :
670 : // Simple indexed accumulation of XX
671 0 : Double xx0(0.0);
672 0 : for (Int irow=0;irow<nRow;++irow) {
673 0 : if (!sdb.flagRow()(irow)) {
674 0 : for (Int ich=0;ich<nChan;++ich) {
675 0 : for (Int icorr=0;icorr<nCorr;++icorr) {
676 0 : if (!wFC(icorr,ich,irow)) {
677 0 : Float& wt(sdb.infocusWtSpec()(icorr,ich,irow));
678 0 : if (wt>0.0) {
679 0 : Complex& Ri(R(icorr,ich,irow));
680 :
681 : // This element's contribution
682 0 : xx0=Double(wt*real(Ri*conj(Ri))); // cf: square(abs(R))?
683 :
684 : // Accumulate per-corr
685 0 : xxV(icorr)+=xx0;
686 0 : sWtV(icorr)+=wt;
687 : } // wt>0
688 : } // !flag
689 : } // icorr
690 : } // ich
691 : } // !flagRow
692 : } // irow
693 :
694 : } // isdb
695 :
696 0 : Vector<Float> rmsV(nCorr,0.0);
697 0 : for (Int icorr=0;icorr<nCorr;++icorr) {
698 0 : if (sWtV(icorr)>0.0)
699 0 : rmsV(icorr)=Float(sqrt(xxV(icorr)/sWtV(icorr)));
700 : }
701 :
702 : // Now Apply the threshold
703 :
704 0 : LogIO logsink;
705 :
706 : // Loop over SDBs
707 0 : for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
708 :
709 : // Current SDB
710 0 : SolveDataBuffer& sdb(sdbs()(isdb));
711 0 : R.reference(sdb.residuals());
712 :
713 : // Initialize wFC afresh
714 0 : sdb.workingFlagCube().resize(0,0,0);
715 0 : sdb.workingFlagCube().assign(sdb.residFlagCube());
716 :
717 : // Shapes for iteration
718 0 : IPosition shR(R.shape());
719 0 : Int nCorr=shR(0);
720 0 : Int nChan=shR(1);
721 0 : Int nRow=shR(2);
722 :
723 0 : for (Int irow=0;irow<nRow;++irow) {
724 0 : if (!sdb.flagRow()(irow)) {
725 0 : for (Int ich=0;ich<nChan;++ich) {
726 0 : for (Int icorr=0;icorr<nCorr;++icorr) {
727 0 : if (!sdb.residFlagCube()(icorr,ich,irow)) {
728 0 : Float& wt(sdb.infocusWtSpec()(icorr,ich,irow));
729 0 : if (wt>0.0) {
730 0 : Float Ra(abs(R(icorr,ich,irow)));
731 0 : if (Ra>(threshold*rmsV(icorr))) {
732 0 : sdb.workingFlagCube()(icorr,ich,irow)=true;
733 : //sdb.workingWtSpec()(icorr,ich,irow)=0.0;
734 :
735 0 : if (dolog) // only on last go-round, report what baselines have been flagged
736 0 : logsink << "Rejected outlier at: " << MVTime(sdb.time()(irow)/C::day).string(MVTime::YMD,7)
737 0 : << " spw=" << sdb.spectralWindow()(irow)
738 0 : << " BL=" << sdb.antenna1()(irow) << "-" << sdb.antenna2()(irow)
739 : << " corr=" << icorr
740 0 : << ": residual=" << Ra/rmsV(icorr) << "sigma" << " (threshold=" << threshold << ")" << LogIO::POST;
741 :
742 : }
743 : } // wt>0
744 : } // !flag
745 : } // icorr
746 : } // ich
747 : } // !flagRow
748 : } // irow
749 :
750 : } // isdb
751 :
752 0 : }
753 :
754 :
755 :
756 76820 : Bool VisCalSolver2::converged() {
757 :
758 76820 : if (prtlev()>2) cout << " VCS2::converged()" << endl;
759 :
760 : // Change in chi2
761 76820 : dChiSq() = chiSq()-lastChiSq();
762 76820 : Float fChiSq(dChiSq()/chiSq());
763 :
764 : // Consider convergence if chi2 decreases...
765 : // if (dChiSq()<=0.0) {
766 76820 : if (fChiSq<=0.001) {
767 :
768 : // ...and the change is small:
769 76820 : if (abs(dChiSq()) < 0.1*chiSq()) {
770 56340 : ++cvrgcount_;
771 :
772 : // if (cvrgcount_==2) lambda()=2.0;
773 :
774 : }
775 :
776 76820 : if (prtlev()>0)
777 0 : cout << " Good: chiSq=" << chiSq()
778 0 : << " dChiSq=" << dChiSq()
779 0 : << " fChiSq=" << dChiSq()/chiSq()
780 0 : << " cvrgcnt=" << cvrgcount_
781 0 : << " lambda=" << lambda()
782 0 : << endl;
783 :
784 :
785 : // Five such steps we believe we have converged!
786 76820 : if (cvrgcount_>5)
787 9390 : return True;
788 :
789 : }
790 : else {
791 : // (chi2 increased)
792 :
793 : // If a large increase, don't anticipate yet
794 0 : if (abs(dChiSq()) > 0.1*chiSq())
795 0 : cvrgcount_=0;
796 : else {
797 : // anticipate a little less if upward change is small
798 : // TBD: is this right?
799 0 : --cvrgcount_;
800 0 : cvrgcount_=max(cvrgcount_,0); // never less than zero
801 : }
802 :
803 0 : if (prtlev()>0)
804 0 : cout << " Bad: chiSq=" << chiSq()
805 0 : << " dChiSq=" << dChiSq()
806 0 : << " fChiSq=" << dChiSq()/chiSq()
807 0 : << " cvrgcnt=" << cvrgcount_
808 0 : << " lambda=" << lambda()
809 0 : << endl;
810 :
811 :
812 : }
813 :
814 : // Not yet converged
815 67430 : return False;
816 :
817 : }
818 :
819 76147 : void VisCalSolver2::accGradHess2() {
820 :
821 76147 : if (prtlev()>2) cout << " VCS2::accGradHess(SDB version)" << endl;
822 :
823 76147 : grad()=0.0;
824 76147 : hess()=0.0;
825 :
826 152294 : Cube<Complex> R;
827 152294 : Array<Complex> dR;
828 :
829 : // Loop over SDBs
830 165928 : for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
831 :
832 : // Current SDB
833 89781 : SolveDataBuffer& sdb(sdbs()(isdb));
834 :
835 89781 : R.reference(sdb.residuals());
836 89781 : dR.reference(sdb.diffResiduals());
837 :
838 89781 : const Cube<Float>& wWS(sdb.const_workingWtSpec());
839 89781 : const Cube<Bool>& wFC(sdb.const_workingFlagCube());
840 :
841 179562 : IPosition dRip(dR.shape());
842 :
843 89781 : Int nRow(dRip(3));
844 89781 : Int nChan(dRip(2));
845 89781 : Int nParPerAnt(dRip(1)); // pars per antenna
846 89781 : Int nCorr(dRip(0));
847 :
848 : // Simple indexed accumulation
849 6138586 : for (Int irow=0;irow<nRow;++irow) {
850 6048805 : if (!sdb.flagRow()(irow)) {
851 6048805 : Int a1i= nParPerAnt*sdb.antenna1()(irow);
852 6048805 : Int a2i= nParPerAnt*sdb.antenna2()(irow);
853 12097610 : for (Int ichan=0;ichan<nChan;++ichan) {
854 25616585 : for (int icorr=0;icorr<nCorr;++icorr) {
855 : //if (!sdb.residFlagCube()(icorr,ichan,irow)) { // OLD: residFlagCube
856 19567780 : const Bool& fl(wFC(icorr,ichan,irow)); // NEW: workingFlagCube CORRECT?
857 19567780 : if (!fl) {
858 19382318 : const Float& wt(wWS(icorr,ichan,irow));
859 19382318 : if (wt>0.0) {
860 17152576 : Complex& Ri(R(icorr,ichan,irow));
861 41006568 : for (Int ipar=0;ipar<nParPerAnt;++ipar) {
862 :
863 : // Accumulate grad and hess for this icorr,ichan,irow,ipar
864 : // for a1:
865 23853992 : Complex& dR1(dR(IPosition(5,icorr,ipar,ichan,irow,0)));
866 23853992 : grad()(a1i+ipar)+= DComplex(wt*(Ri*conj(dR1)));
867 23853992 : hess()(a1i+ipar)+= Double(wt*real(dR1*conj(dR1)));
868 : // for a2:
869 23853992 : Complex& dR2(dR(IPosition(5,icorr,ipar,ichan,irow,1)));
870 23853992 : grad()(a2i+ipar)+= DComplex(wt*(dR2*conj(Ri)));
871 23853992 : hess()(a2i+ipar)+= Double(wt*real(dR2*conj(dR2)));
872 :
873 : } // ipar
874 : } // wt>0
875 : } // !flag
876 : } // icorr
877 : } // ichan
878 : } // !flagRow
879 : } // irow
880 :
881 : } // isdb
882 :
883 76147 : if (prtlev()>4) { // grad, hess
884 0 : cout << " grad= " << grad() << endl;
885 0 : cout << " hess= " << hess() << endl;
886 : }
887 :
888 76147 : }
889 :
890 673 : void VisCalSolver2::revert() {
891 :
892 673 : if (prtlev()>2) cout << " VCS2::revert()" << endl;
893 :
894 : // Recall the last decent parameter set
895 : // TBD: the OK flag?
896 673 : par()=lastPar();
897 :
898 673 : }
899 :
900 67430 : void VisCalSolver2::solveGradHess() {
901 :
902 67430 : if (prtlev()>2) cout << " VCS2::solveGradHess()" << endl;
903 :
904 : // TBD: explicit option to avoid lmfact?
905 : // TBD: pointer (or MaskedArray?) optimization?
906 :
907 67430 : Double lmfact(1.0+lambda());
908 :
909 67430 : lmfact=2.0;
910 :
911 67430 : dpar()=Complex(0.0);
912 1114446 : for (Int ipar=0; ipar<nPar(); ipar++) {
913 1047016 : if ( parOK()(ipar) && hess()(ipar)!=0.0) {
914 : // good hess for this par:
915 1027146 : dpar()(ipar) = grad()(ipar)/hess()(ipar);
916 1027146 : dpar()(ipar)/=lmfact;
917 : }
918 : else {
919 19870 : dpar()(ipar)=0.0;
920 19870 : parOK()(ipar)=False;
921 : }
922 : }
923 :
924 : // Negate (so updatePar() can _add_)
925 67430 : dpar()*=Complex(-1.0f);
926 :
927 67430 : }
928 :
929 67430 : void VisCalSolver2::updatePar() {
930 :
931 67430 : if (prtlev()>2) cout << " VCS2::updatePar()" << endl;
932 :
933 : // if (prtlev()>4) cout << " old =" << par() << endl;
934 :
935 : // if (prtlev()>4) cout << " dpar=" << dpar() << endl;
936 :
937 :
938 :
939 : // Tell svc to update the par
940 : // (permits svc() to condition the current solutions)
941 67430 : svc().updatePar(dpar());
942 :
943 67430 : if (prtlev()>4) {
944 0 : cout << " abs(dpar()) = " << amplitude(dpar()) << endl;
945 0 : cout << " new amp = " << amplitude(par()) << endl
946 0 : << " pha = " << phase(par()) << endl;
947 : }
948 :
949 67430 : }
950 :
951 :
952 67430 : void VisCalSolver2::optStepSize2() {
953 :
954 67430 : if (prtlev()>2) cout << " VCS2::optStepSize2(SDB version)" << endl;
955 :
956 134860 : Vector<Double> x2(3,-999.0);
957 67430 : Float step(1.0);
958 :
959 : // Starting point is curr chiSq
960 67430 : x2(0)=chiSq();
961 :
962 : // take nominal step
963 67430 : par()+=dpar();
964 67430 : residualate2();
965 67430 : chiSquare2();
966 67430 : x2(1)=chiSq();
967 :
968 : // If nominal step is an improvement...
969 67430 : if (x2(1)<x2(0)) {
970 :
971 : // ...double step size until x2 starts increasing
972 63671 : par()=dpar(); par()*=Complex(2.0*step); par()+=lastPar();
973 63671 : residualate2();
974 63671 : chiSquare2();
975 63671 : x2(2)=chiSq();
976 63671 : if (prtlev()>4)
977 0 : cout << " down: " << step << " " << x2-x2(0) << LogicalArray(x2>=x2(0)) <<endl;
978 93545 : while (x2(2)<x2(1)) { // && step<4.0) {
979 29874 : step*=2.0;
980 29874 : par()=dpar(); par()*=Complex(2.0*step); par()+=lastPar();
981 29874 : residualate2();
982 29874 : chiSquare2();
983 29874 : x2(1)=x2(2);
984 29874 : x2(2)=chiSq();
985 29874 : if (prtlev()>4)
986 0 : cout << " stretch: " << step << " " << x2-x2(0) << LogicalArray(x2>=x2(0)) <<endl;
987 :
988 : }
989 : }
990 : // else nominal step too big, so...
991 : else {
992 :
993 : // ... contract by halves until we bracket a minimum
994 3759 : step*=0.5;
995 3759 : par()=dpar(); par()*=Complex(step); par()+=lastPar();
996 3759 : residualate2();
997 3759 : chiSquare2();
998 3759 : x2(2)=x2(1);
999 3759 : x2(1)=chiSq();
1000 3759 : if (prtlev()>4)
1001 0 : cout << " up: " << step << " " << x2-x2(0) << LogicalArray(x2>=x2(0)) <<endl;
1002 15445 : while (x2(1)>x2(0)) { // && step>0.125) {
1003 11686 : step*=0.5;
1004 11686 : par()=dpar(); par()*=Complex(step); par()+=lastPar();
1005 11686 : residualate2();
1006 11686 : chiSquare2();
1007 11686 : x2(2)=x2(1);
1008 11686 : x2(1)=chiSq();
1009 11686 : if (prtlev()>4)
1010 0 : cout << " contract: " << step << " " << x2-x2(0) << LogicalArray(x2>=x2(0)) <<endl;
1011 : }
1012 :
1013 : }
1014 :
1015 : // At this point x2(0) > x2(1) < x2(2), so
1016 : // calculate (quadratic) step optimization factor
1017 67430 : Double optfactor(0.0);
1018 67430 : Double optn(x2(2)-x2(1));
1019 67430 : Double optd(x2(0)-2*x2(1)+x2(2));
1020 :
1021 67430 : if (abs(optd)>0.0)
1022 67430 : optfactor=Double(step)*(1.5-optn/optd);
1023 :
1024 : /*
1025 : cout << "Optimization: "
1026 : << step << " "
1027 : << optfactor << " "
1028 : << x2 << " "
1029 : << "(" << min(amplitude(lastPar())) << ") "
1030 : << max(amplitude(dpar())/amplitude(lastPar()))*180.0/C::pi << " ";
1031 : */
1032 :
1033 :
1034 67430 : if (prtlev()>4) cout << " optfactor=" << optfactor << endl;
1035 :
1036 :
1037 67430 : par()=lastPar();
1038 :
1039 : // Adjust step by the optfactor
1040 67430 : if (optfactor>0.0)
1041 67430 : dpar()*=Complex(optfactor);
1042 :
1043 : /*
1044 : cout << max(amplitude(dpar())/amplitude(lastPar()))*180.0/C::pi
1045 : << endl;
1046 : */
1047 67430 : }
1048 :
1049 9390 : void VisCalSolver2::getErrors() {
1050 :
1051 : // Number of *REAL* dof
1052 : // Int nDOF=2*(nWt()-ntrue(parOK())); // !!!! this is zero for 3 antennas!
1053 9390 : Int nDOF=max(2*(nWt()-ntrue(parOK())), 1u);
1054 :
1055 9390 : Double k2=chiSq()/Double(nDOF);
1056 :
1057 9390 : parErr()=0.0;
1058 153838 : for (Int i=0;i<nPar();++i)
1059 144448 : if (hess()(i)>0.0) {
1060 141692 : parErr()(i)=1.0/sqrt(hess()(i)/k2/2.0); // 2 is from def of Hess!
1061 : }
1062 :
1063 :
1064 9390 : if (prtlev()>4) {
1065 :
1066 0 : cout << "ChiSq = " << chiSq() << endl;
1067 0 : cout << "ChiSqV = " << chiSqV() << endl;
1068 0 : cout << "sumWt = " << sumWt() << endl;
1069 0 : cout << "nWt = " << nWt()
1070 0 : << "; nPar() = " << nPar()
1071 0 : << "; nParOK = " << ntrue(parOK())
1072 0 : << "; nDOF = " << nDOF
1073 0 : << endl;
1074 :
1075 0 : cout << "rChiSq = " << k2 << endl;
1076 0 : cout << "max(dpar) = " << max(amplitude(dpar())) << endl;
1077 0 : cout << "Amplitudes = " << amplitude(par()) << endl;
1078 0 : cout << "Errors = " << parErr() << endl;
1079 : // cout << "Errors = " << mean(parErr()(parOK())) << endl;
1080 :
1081 : }
1082 9390 : }
1083 :
1084 :
1085 : } //# NAMESPACE CASA - END
1086 :
|