casa
$Rev:20696$
|
import os
from taskinit import *

import sdutil
import asap as sd
import pylab as pl
from numpy import ma, array, logical_not, logical_and
import sdutil

# Only the first rows of the scantable are plotted (one panel per row).
_MAX_PLOT_ROWS = 16


def sdfit(infile, antenna, fluxunit, telescopeparm, specunit, restfreq,
          frame, doppler, scanlist, field, iflist, pollist, fitfunc,
          fitmode, maskline, invertmask, nfit, thresh, min_nchan,
          avg_limit, box_size, edge, outfile, overwrite, plotlevel):
    """Task entry point: fit line profiles to the spectra of `infile`.

    All task parameters are handed to an `sdfit_worker`, which is then
    initialized, executed and finalized.

    Returns:
        The worker's `fitresult` dictionary (keys 'nfit', 'peak', 'cent',
        'fwhm').

    Raises:
        Any exception raised by the worker, after it has been passed to
        `sdutil.process_exception`.
    """
    casalog.origin('sdfit')

    try:
        # locals() holds exactly the task parameters at this point; no other
        # local name may be bound before this call.
        worker = sdfit_worker(**locals())
        worker.initialize()
        worker.execute()
        worker.finalize()

        return worker.fitresult

    except Exception as instance:
        sdutil.process_exception(instance)
        # Re-raise the same exception and keep its original traceback.
        raise


class sdfit_worker(sdutil.sdtask_template):
    """Worker that implements the steps of the sdfit task.

    The task parameters passed to the constructor are read back as instance
    attributes (self.infile, self.fitmode, ...) by the methods below.
    """

    def __init__(self, **kwargs):
        super(sdfit_worker, self).__init__(**kwargs)

    def __del__(self):
        # restore scantable when the instance is deleted
        self.cleanup()

    def parameter_check(self):
        """Decide whether starting guesses are computed before fitting.

        Guesses are skipped only for fitmode 'list' combined with
        invertmask=True.
        """
        self.doguess = not ((self.fitmode.lower() == 'list')
                            and (self.invertmask))

    def initialize_scan(self):
        """Load the input scantable, apply the selection and the flux unit."""
        # load the data without averaging
        sorg = sd.scantable(self.infile, average=False, antenna=self.antenna)

        # restorer
        self.restorer = sdutil.scantable_restore_factory(sorg,
                                                         self.infile,
                                                         self.fluxunit,
                                                         self.specunit,
                                                         self.frame,
                                                         self.doppler,
                                                         self.restfreq)

        # Select scan and field
        sorg.set_selection(self.get_selector())

        # this is bit tricky
        # set fluxunit here instead of self.set_to_scan
        # and remove fluxunit attribute to disable additional
        # call of set_fluxunit in self.set_to_scan
        self.scan = sdutil.set_fluxunit(sorg, self.fluxunit,
                                        self.telescopeparm, False)
        self.fluxunit_saved = self.fluxunit
        del self.fluxunit

        if self.scan:
            # Restore flux unit in original table before deleting
            self.restorer.restore()
            del self.restorer
            self.restorer = None
        else:
            self.scan = sorg

    def execute(self):
        """Configure the scantable, build the line list and run the fit."""
        self.set_to_scan()

        # restore fluxunit
        self.fluxunit = self.fluxunit_saved
        del self.fluxunit_saved

        self.__set_linelist()

        self.__fit()

    def save(self):
        """Write the fit parameters to `outfile` unless it is empty."""
        # Store fit
        if self.outfile != '':
            self.__store_fit()

    def cleanup(self):
        """Restore the original scantable if a restorer is still pending."""
        # __del__ calls this even when initialize_scan never ran, so the
        # attribute may be missing altogether.
        restorer = getattr(self, 'restorer', None)
        if restorer is not None:
            restorer.restore()

    def __fit(self):
        """Fit every row of the scantable and collect the results.

        Fills self.fitresult and self.fitparams, and plots the first
        _MAX_PLOT_ROWS rows when plotlevel is non-zero.
        """
        # initialize fitter
        self.fitter = sd.fitter()
        # Each key needs its OWN list: dict.fromkeys(keys, []) would share a
        # single list object, so an in-place update of 'nfit' would leak into
        # 'peak', 'cent' and 'fwhm'.
        self.fitresult = dict((key, [])
                              for key in ['nfit', 'peak', 'cent', 'fwhm'])
        self.fitparams = []
        do_plot = abs(self.plotlevel) > 0
        if do_plot:
            self.__init_plot()

        dbw = 1.0
        current_unit = self.scan.get_unit()
        kw = {'thescan': self.scan}
        if len(self.defaultmask) > 0:
            kw['mask'] = self.defaultmask
        self.fitter.set_scan(**kw)
        firstplot = True
        # Same case-insensitive comparison as parameter_check and
        # __set_linelist.
        auto_mode = (self.fitmode.lower() == 'auto')

        for irow in range(self.scan.nrow()):
            casalog.post("start row %d" % (irow))
            if isinstance(self.nlines, list):
                numlines = self.nlines[irow]
            else:
                numlines = self.nlines

            if numlines == 0:
                self.__record_failure(irow, -1, 'No lines detected.')
                continue

            if auto_mode:
                # Auto mode - one comp per line region
                # Overwriting user-supplied nfit
                numfit = numlines
                comps = [1] * numlines
            else:
                # Get number of things to fit from nfit list
                if isinstance(self.nfit, list):
                    comps = self.nfit
                else:
                    comps = [self.nfit]
                # Drop extra over numlines
                numfit = min(len(comps), numlines)
            ncomps = sum(comps)

            casalog.post("Will fit %d components in %d regions"
                         % (ncomps, numfit))

            if numfit <= 0:
                self.__record_failure(irow, -1, 'Fit failed.')
                continue

            # Fit the line using numfit gaussians or lorentzians
            # Assume the nfit list matches maskline
            self.fitter.set_function(**{self.fitfunc: ncomps})
            if self.doguess:
                # in auto mode, linelist will be determined for each spectra
                # otherwise, linelist will be the same for all spectra
                if current_unit != 'channel':
                    # Channel width in the current unit, used to scale the
                    # width guess.
                    xx = self.scan._getabcissa(irow)
                    dbw = abs(xx[1] - xx[0])
                self.__initial_guess(dbw, numfit, comps, irow)
            else:
                # No guesses
                casalog.post("Fitting lines without starting guess")

            # Now fit
            self.fitter.fit(row=irow)
            fstat = self.fitter.get_parameters()

            # Check for convergence
            goodfit = (len(fstat['errors']) > 0)
            if goodfit:
                self.__update_params(ncomps)
            else:
                # Did not converge
                self.__record_failure(irow, -ncomps,
                                      'Fit failed to converge')

            # plot
            if irow < _MAX_PLOT_ROWS and do_plot:
                self.__plot(irow, goodfit, firstplot)
                firstplot = False

    def __record_failure(self, irow, nfit_value, message):
        """Store placeholder results for a row that could not be fitted."""
        self.fitparams.append([[0, 0, 0]])
        self.fitresult['nfit'].append(nfit_value)
        self.__warn_fit_failed(irow, message)

    def __initial_guess(self, dbw, numfit, comps, irow):
        """Set starting parameters on the fitter for single-component regions.

        dbw    -- abscissa increment used to scale the width guess
        numfit -- number of regions to fit
        comps  -- number of components per region
        irow   -- row index in the scantable
        """
        if self.fitmode.lower() == 'auto':
            llist = self.linelist[irow]
        else:
            llist = self.linelist

        # guesses: [maxlist, cenlist, fwhmlist]
        guesses = [[], [], []]
        if len(llist) > 0:
            for x in llist:
                x.sort()
                casalog.post("detected line: " + str(x))
                msk = self.scan.create_mask(x, row=irow)
                guess = self.__get_initial_guess(msk, x, dbw, irow)
                for values, value in zip(guesses, guess):
                    values.append(value)
        else:
            guess = self.__get_initial_guess(self.defaultmask, [], dbw, irow)
            for values, value in zip(guesses, guess):
                values.append(value)

        # Guesses using max, cen, and fwhm=0.7*eqw
        # NOTE: should there be user options here?
        setter = getattr(self.fitter, 'set_%s_parameters' % (self.fitfunc))
        n = 0
        for i in range(numfit):
            # cannot guess for multiple comps yet
            if comps[i] == 1:
                # use guess
                guess = [values[i] for values in guesses]
                setter(*guess, component=n)
            # n is the index of the first component of the next region
            n += comps[i]

    def __get_initial_guess(self, msk, linerange, dbw, irow):
        """Return the starting guess (max, center, fwhm) for one region.

        The maximum and the sum come from the row statistics under `msk`.
        The width is 0.7 * |sum / max * dbw| (or the maximum itself when the
        maximum is zero). The center is the midpoint of the first two
        entries of `linerange`, or half the number of channels when
        `linerange` has fewer than two entries.
        """
        maxl, suml = [self.scan._math._statsrow(self.scan, msk, st, irow)[0]
                      for st in ['max', 'sum']]
        if maxl == 0.0:
            fwhm = maxl
        else:
            fwhm = 0.7 * abs(suml / maxl * dbw)
        if len(linerange) > 1:
            cen = 0.5 * sum(linerange[:2])
        else:
            cen = self.scan.nchan(self.scan.getif(irow)) // 2
        return (maxl, cen, fwhm)

    def __update_params(self, ncomps):
        """Append the parameters of a converged fit to the result containers.

        For each of 'peak', 'cent' and 'fwhm' one list of [value, error]
        pairs (one pair per component) is appended to self.fitresult, and
        the parameters reshaped to (ncomps, npars) go to self.fitparams.
        """
        # Retrieve fit parameters
        keys = ['peak', 'cent', 'fwhm']
        nkeys = len(keys)
        parameters = self.fitter.get_parameters()
        params = parameters['params']
        errors = parameters['errors']

        self.fitresult['nfit'].append(ncomps)
        for j, key in enumerate(keys):
            # Parameters are stored component after component, nkeys values
            # per component.
            per_component = [[params[i * nkeys + j], errors[i * nkeys + j]]
                             for i in range(ncomps)]
            self.fitresult[key].append(per_component)

        npars = len(params) // ncomps
        self.fitparams.append(list(array(params).reshape((ncomps, npars))))

    def __set_linelist(self):
        """Reset the line bookkeeping and dispatch on the fit mode."""
        self.defaultmask = []
        self.linelist = []
        self.nlines = 1
        getattr(self, '_set_linelist_%s' % (self.fitmode.lower()))()

    def _set_linelist_list(self):
        """Line list for fitmode 'list': take the regions from maskline."""
        # Assume the user has given a list of lines
        # e.g. maskline=[[3900,4300]] for a single line
        if len(self.maskline) > 0:
            # There is a user-supplied channel mask for lines
            if not self.invertmask:
                # Make sure this is a list-of-lists (e.g. [[1,10],[20,30]])
                if isinstance(self.maskline[0], list):
                    self.linelist = self.maskline
                else:
                    self.linelist = to_list_of_list(self.maskline)
                self.nlines = len(self.linelist)
            self.defaultmask = self.scan.create_mask(self.maskline,
                                                     invert=self.invertmask)
        else:
            # Use whole region
            if self.invertmask:
                msg = 'No channel is selected because invertmask=True. Exit without fittinging.'
                raise Exception(msg)

        casalog.post("Identified %d regions for fitting" % (self.nlines))
        if self.invertmask:
            casalog.post("No starting guesses available")
        else:
            casalog.post("Will use these as starting guesses")

    def _set_linelist_interact(self):
        """Line list for fitmode 'interact': ask the user for the mask."""
        # Interactive masking
        new_mask = sdutil.init_interactive_mask(self.scan,
                                                self.maskline,
                                                self.invertmask)
        self.defaultmask = sdutil.get_interactive_mask(new_mask,
                                                       purpose='to fit lines')
        self.linelist = self.scan.get_masklist(self.defaultmask)
        self.nlines = len(self.linelist)
        if self.nlines < 1:
            msg = 'No channel is selected. Exit without fittinging.'
            raise Exception(msg)
        print('%d region(s) is selected as a linemask' % self.nlines)
        print('The final mask list (' + self.scan._getabcissalabel()
              + ') =' + str(self.linelist))
        print('Number of line(s) to fit: nfit = %s' % (self.nfit,))
        ans = raw_input('Do you want to reassign nfit? [N/y]: ')
        if ans.upper() == 'Y':
            # NOTE(review): Python 2 input() evaluates whatever is typed at
            # the console; acceptable only for a trusted interactive user.
            ans = input('Input nfit = ')
            if isinstance(ans, list):
                self.nfit = ans
            elif isinstance(ans, int) and not isinstance(ans, bool):
                self.nfit = [ans]
            else:
                msg = 'Invalid definition of nfit. Setting nfit=[1] and proceed.'
                casalog.post(msg, priority='WARN')
                self.nfit = [1]
            casalog.post('List of line number reassigned.\n nfit = '
                         + str(self.nfit))
        sdutil.finalize_interactive_mask(new_mask)

    def _set_linelist_auto(self):
        """Line list for fitmode 'auto': run the line finder on every row.

        self.nlines becomes a per-row list of line counts and self.linelist
        a per-row list of [start, end] regions.
        """
        # Fit mode AUTO and in channel mode
        casalog.post("Trying AUTO mode - find line channel regions")
        if len(self.maskline) > 0:
            # There is a user-supplied channel mask for lines
            self.defaultmask = self.scan.create_mask(self.maskline,
                                                     invert=self.invertmask)

        # Use linefinder to find lines
        casalog.post("Using linefinder")
        fl = sd.linefinder()
        fl.set_scan(self.scan)
        # This is the tricky part
        # def  fl.set_options(threshold=1.732,min_nchan=3,avg_limit=8,box_size=0.2)
        # e.g. fl.set_options(threshold=5,min_nchan=3,avg_limit=4,box_size=0.1) seem ok?
        fl.set_options(threshold=self.thresh, min_nchan=self.min_nchan,
                       avg_limit=self.avg_limit, box_size=self.box_size)
        # Now find the lines for each row in scantable
        self.nlines = []
        for irow in range(self.scan.nrow()):
            nfound = fl.find_lines(mask=self.defaultmask, nRow=irow,
                                   edge=self.edge)
            self.nlines.append(nfound)
            # Get ranges
            ptout = "SCAN[%d] IF[%d] POL[%d]: " % (self.scan.getscan(irow),
                                                   self.scan.getif(irow),
                                                   self.scan.getpol(irow))
            if nfound > 0:
                ll = fl.get_ranges()
                casalog.post(ptout + "Found %d lines at %s"
                             % (nfound, str(ll)))
            else:
                ll = ()
                casalog.post(ptout + "Nothing found.", priority='WARN')

            # This is a linear list of pairs of values, so turn these into a list of lists
            self.linelist.append(to_list_of_list(ll))

        # Done with linefinder
        casalog.post("Finished linefinder.")

    def __store_fit(self):
        """Write one text line per fitted component to `outfile`.

        Components whose first parameter is 0 (the placeholder stored for
        failed rows) are skipped.
        """
        numparam = 3     # gaussian fitting is assumed (max, center, fwhm)

        # The context manager closes the file even if a write fails.
        with open(sdutil.get_abspath(self.outfile), 'w') as outf:
            # header
            header = "#%-4s %-4s %-4s %-12s " % ("SCAN", "IF", "POL", "Function")
            for i in range(numparam):
                header += '%-12s ' % ('P%d' % (i))
            outf.write(header + '\n')

            # data
            for i, rowparams in enumerate(self.fitparams):
                dattmp = " %-4d %-4d %-4d " \
                         % (self.scan.getscan(i), self.scan.getif(i),
                            self.scan.getpol(i))
                for j, comp in enumerate(rowparams):
                    if comp[0] != 0.0:
                        datstr = dattmp + '%-12s ' % ('%s%d' % (self.fitfunc, j))
                        for value in comp:
                            datstr += "%3.8f " % (value)
                        outf.write(datstr + '\n')

    def __init_plot(self):
        """Prepare the fitter's plotter with one panel per row to plot."""
        n = self.scan.nrow()
        if n > _MAX_PLOT_ROWS:
            casalog.post('Only first 16 results are plotted.',
                         priority='WARN')
            n = _MAX_PLOT_ROWS

        # initialize plotter
        from matplotlib import rc as rcp
        rcp('lines', linewidth=1)
        if not (self.fitter._p and self.fitter._p._alive()):
            self.fitter._p = sdutil.get_plotter(self.plotlevel)
        plotter = self.fitter._p
        plotter.hold()
        plotter.clear()
        # set nrow and ncol (maximum 4x4)
        plotter.set_panels(rows=n, cols=0, ganged=False)
        casalog.post('nrow,ncol= %d,%d' % (plotter.rows, plotter.cols))
        plotter.palette(0, ["#777777", "#dddddd", "red", "orange",
                            "purple", "green", "magenta", "cyan"])

    def __plot(self, irow, fitted, firstplot=False):
        """Plot the spectrum of row `irow` and, if `fitted`, the fit result.

        The residual is drawn only when plotlevel equals 2. Long labels are
        used for the first plotted row only.
        """
        if firstplot:
            labels = ['Spectrum', 'Selected Region', 'Residual', 'Fit']
        else:
            labels = ['spec', 'select', 'res', 'fit']
        fitter = self.fitter
        myp = fitter._p

        myp.subplot(irow)
        # plot spectra
        x = fitter.data._getabcissa(irow)
        y = fitter.data._getspectrum(irow)
        mr = fitter.data._getflagrow(irow)
        if mr:
            # a whole spectrum is flagged
            themask = False
        else:
            msk = array(fitter.data._getmask(irow))
            fmsk = array(fitter.mask)
            themask = logical_and(msk, fmsk)
            del msk, fmsk
        # plot masked region if any of channel is not in fit range.
        idx = 0
        if mr or (not all(themask)):
            # dumped region
            plot_line(myp, x, y, themask, label=labels[0], color=1)
            idx = 1
        # plot_line hides the channels where the mask is True, so invert the
        # mask to show the selected channels next.
        themask = logical_not(themask)

        # fitted region
        plot_line(myp, x, y, themask, label=labels[idx], color=0, scale=True)

        # plot fitted result
        if fitted:
            # plot residual
            if self.plotlevel == 2:
                plot_line(myp, x, fitter.get_residual(), themask,
                          label=labels[2], color=7)
            # plot fit
            plot_line(myp, x, fitter.fitter.getfit(), themask,
                      label=labels[3], color=2)

        if irow == 0:
            tlab = fitter.data._getsourcename(fitter._fittedrow)
            myp.set_axes('title', tlab)
        if irow % myp.rows == 0:
            ylab = fitter.data._get_ordinate_label()
            myp.set_axes('ylabel', ylab)
        if irow // myp.rows == myp.cols - 1:
            xlab = fitter.data._getabcissalabel(fitter._fittedrow)
            myp.set_axes('xlabel', xlab)
        myp.release()

    def __warn_fit_failed(self, irow, message=''):
        """Log which row failed, followed by `message` as a warning."""
        casalog.post('Fitting:')
        casalog.post('Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]'
                     % (self.scan.getscan(irow), self.scan.getbeam(irow),
                        self.scan.getif(irow), self.scan.getpol(irow),
                        self.scan.getcycle(irow)))
        casalog.post(" %s" % (message), priority='WARN')


def plot_line(plotter, x, y, msk, label, color, colormap=None, scale=False):
    """Draw y against x on `plotter`, hiding points where `msk` is True.

    With scale=True the x limits are set to the range of x, and the y limits
    to the range of the visible data widened by 10% on each side.
    """
    plotter.set_line(label=label)
    plotter.palette(color, colormap)
    my = ma.masked_array(y, msk)
    if scale:
        xlim = [min(x), max(x)]
        ylim = [min(my), max(my)]
        wy = ylim[1] - ylim[0]
        ylim = [ylim[0] - wy * 0.1, ylim[1] + wy * 0.1]
        plotter.axes.set_xlim(xlim)
        plotter.axes.set_ylim(ylim)
    plotter.plot(x, my)

def to_list_of_list(l):
    """Turn a flat sequence of values into a list of [first, second] pairs.

    For example (1, 2, 3, 4) becomes [[1, 2], [3, 4]], and an empty sequence
    becomes []. A sequence with an odd number of elements makes the reshape
    fail with ValueError.
    """
    # Floor division keeps the row count an integer even where '/' is true
    # division.
    return array(l).reshape(len(l) // 2, 2).tolist()