casa  $Rev:20696$
 All Classes Namespaces Files Functions Variables
task_sdfit.py
Go to the documentation of this file.
00001 import os
00002 from taskinit import *
00003 
00004 import sdutil
00005 import asap as sd
00006 import pylab as pl
00007 from numpy import ma, array, logical_not, logical_and
00008 import sdutil
00009 
def sdfit(infile, antenna, fluxunit, telescopeparm, specunit, restfreq, frame, doppler, scanlist, field, iflist, pollist, fitfunc, fitmode, maskline, invertmask, nfit, thresh, min_nchan, avg_limit, box_size, edge, outfile, overwrite, plotlevel):
    """Task entry point: fit line components to single-dish spectra.

    All task parameters are handed unchanged to ``sdfit_worker`` (via
    ``locals()``), which reads them back as attributes (``self.infile``,
    ``self.fitmode``, ...).

    Returns:
        ``worker.fitresult`` -- dict with keys 'nfit', 'peak', 'cent', 'fwhm'.

    Raises:
        Whatever the worker raised, after it has been passed to
        ``sdutil.process_exception``; the original traceback is preserved.
    """
    casalog.origin('sdfit')

    try:
        # At this point locals() holds exactly the task parameters.
        worker = sdfit_worker(**locals())
        worker.initialize()
        worker.execute()
        worker.finalize()

        return worker.fitresult

    except Exception as instance:
        sdutil.process_exception(instance)
        # A bare raise re-raises the same exception object with its original
        # traceback; "raise Exception, instance" restarted the traceback here.
        raise
00025 
class sdfit_worker(sdutil.sdtask_template):
    """Worker implementing the sdfit task: Gaussian/Lorentzian line fitting.

    Task parameters arrive as keyword arguments and are read back as
    attributes (self.infile, self.fitmode, self.nfit, ...); storing them is
    presumably done by sdutil.sdtask_template.__init__ (not shown here).
    The sdfit task calls initialize(), execute() and finalize(); only
    execute() is defined in this class.  The base class is assumed to call
    the hooks parameter_check / initialize_scan / save / cleanup defined
    below -- TODO confirm against sdutil.

    After execute():
      self.fitresult -- dict with keys 'nfit', 'peak', 'cent', 'fwhm'.
                        'nfit' gets one entry per row (negative on failure);
                        'peak'/'cent'/'fwhm' get one entry per successfully
                        fitted row, each a list of [value, error] pairs, one
                        pair per component.
      self.fitparams -- one entry per row: the per-component parameter
                        arrays, or [[0, 0, 0]] when the row was not fitted.
    """

    def __init__(self, **kwargs):
        super(sdfit_worker, self).__init__(**kwargs)

    def __del__(self):
        # restore scantable when the instance is deleted
        self.cleanup()

    def parameter_check(self):
        """Decide whether starting guesses can be computed.

        Guesses are derived from line regions.  In 'list' mode with
        invertmask=True no line regions are recorded (see
        _set_linelist_list), so the fitter runs without starting guesses.
        """
        self.doguess = not ((self.fitmode.lower() == 'list') and (self.invertmask))

    def initialize_scan(self):
        """Load infile into self.scan with the requested selection and flux unit."""
        # load the data without averaging
        sorg = sd.scantable(self.infile, average=False, antenna=self.antenna)

        # restorer: presumably records the settings passed here so that
        # cleanup() can revert them on sorg.
        self.restorer = sdutil.scantable_restore_factory(sorg,
                                                         self.infile,
                                                         self.fluxunit,
                                                         self.specunit,
                                                         self.frame,
                                                         self.doppler,
                                                         self.restfreq)

        # Select scan and field
        sorg.set_selection(self.get_selector())

        # This is a bit tricky: set the flux unit here instead of in
        # self.set_to_scan, and remove the fluxunit attribute so that
        # set_to_scan does not call set_fluxunit a second time.
        # execute() restores the attribute from fluxunit_saved.
        self.scan = sdutil.set_fluxunit(sorg, self.fluxunit, self.telescopeparm, False)
        self.fluxunit_saved = self.fluxunit
        del self.fluxunit

        if self.scan:
            # A truthy result is used instead of sorg, so restore the flux
            # unit in the original table now and drop the restorer.
            self.restorer.restore()
            del self.restorer
            self.restorer = None
        else:
            self.scan = sorg

    def execute(self):
        """Apply unit/frame settings, build the line list and run the fit."""
        self.set_to_scan()

        # restore fluxunit (removed in initialize_scan)
        self.fluxunit = self.fluxunit_saved
        del self.fluxunit_saved

        self.__set_linelist()

        self.__fit()

    def save(self):
        # Store fit results only when an output file name was given
        if self.outfile != '':
            self.__store_fit()

    def cleanup(self):
        # Restore the original scantable.  getattr() guards against __del__
        # running on an instance whose initialize_scan never assigned
        # self.restorer (e.g. construction or loading failed).
        restorer = getattr(self, 'restorer', None)
        if restorer is not None:
            restorer.restore()

    def __record_failed_row(self, irow, nfit, message):
        # Bookkeeping for a row that was not fitted: a (negative) 'nfit'
        # marker, a zero placeholder in fitparams and a logged warning.
        self.fitparams.append([[0, 0, 0]])
        self.fitresult['nfit'].append(nfit)
        self.__warn_fit_failed(irow, message)

    def __fit(self):
        """Fit every row of self.scan, filling self.fitresult and self.fitparams.

        A 'nfit' entry of -1 marks a row with no line regions (or nothing to
        fit); -ncomps marks a fit that did not converge.
        """
        # initialize fitter
        self.fitter = sd.fitter()
        # One independent list per key.  dict.fromkeys(keys, []) would make
        # all keys share a single list, so appending a failure marker to
        # 'nfit' would also show up under 'peak', 'cent' and 'fwhm'.
        self.fitresult = dict((key, []) for key in ['nfit', 'peak', 'cent', 'fwhm'])
        self.fitparams = []
        if abs(self.plotlevel) > 0:
            self.__init_plot()

        automode = (self.fitmode.lower() == 'auto')
        dbw = 1.0
        current_unit = self.scan.get_unit()
        kw = {'thescan': self.scan}
        if len(self.defaultmask) > 0:
            kw['mask'] = self.defaultmask
        self.fitter.set_scan(**kw)
        firstplot = True

        for irow in range(self.scan.nrow()):
            casalog.post("start row %d" % (irow))
            # nlines is a per-row list in auto mode, a single int otherwise
            numlines = self.nlines[irow] if isinstance(self.nlines, list) \
                       else self.nlines

            if numlines == 0:
                self.__record_failed_row(irow, -1, 'No lines detected.')
                continue

            if automode:
                # Auto mode - one comp per line region
                # Overwriting user-supplied nfit
                numfit = numlines
                comps = [1] * numlines
            else:
                # Get number of things to fit from nfit list
                comps = self.nfit if isinstance(self.nfit, list) else [self.nfit]
                # Drop extra over numlines
                numfit = min(len(comps), numlines)
            ncomps = sum(comps)

            casalog.post("Will fit %d components in %d regions" % (ncomps, numfit))

            if numfit <= 0:
                self.__record_failed_row(irow, -1, 'Fit failed.')
                continue

            # Fit the line using ncomps gaussians or lorentzians
            # Assume the nfit list matches maskline
            self.fitter.set_function(**{self.fitfunc: ncomps})
            if self.doguess:
                # in auto mode, linelist is determined for each spectrum;
                # otherwise, linelist is the same for all spectra
                if current_unit != 'channel':
                    # channel width in the current abscissa unit
                    xx = self.scan._getabcissa(irow)
                    dbw = abs(xx[1] - xx[0])
                self.__initial_guess(dbw, numfit, comps, irow)
            else:
                # No guesses
                casalog.post("Fitting lines without starting guess")

            # Now fit
            self.fitter.fit(row=irow)
            fstat = self.fitter.get_parameters()

            # An empty error list is treated as "did not converge"
            goodfit = (len(fstat['errors']) > 0)
            if goodfit:
                self.__update_params(ncomps)
            else:
                self.__record_failed_row(irow, -ncomps, 'Fit failed to converge')

            # plot (only the first 16 rows get a panel, see __init_plot)
            if irow < 16 and abs(self.plotlevel) > 0:
                self.__plot(irow, goodfit, firstplot)
                firstplot = False

    def __initial_guess(self, dbw, numfit, comps, irow):
        """Set starting parameters on the fitter for row irow.

        dbw is the channel width in the current abscissa unit (1.0 for
        channels), numfit the number of line regions used and comps the
        number of components per region.  Only single-component regions get
        a guess; multi-component regions are left to the fitter.
        """
        # In auto mode the line list is per row; otherwise shared by all rows.
        if self.fitmode.lower() == 'auto':
            llist = self.linelist[irow]
        else:
            llist = self.linelist
        if len(llist) > 0:
            # guesses: [maxlist, cenlist, fwhmlist], one entry per line region
            guesses = [[], [], []]
            for x in llist:
                x.sort()
                casalog.post("detected line: " + str(x))
                msk = self.scan.create_mask(x, row=irow)
                guess = self.__get_initial_guess(msk, x, dbw, irow)
                for i in range(3):
                    guesses[i].append(guess[i])
        else:
            guess = self.__get_initial_guess(self.defaultmask, [], dbw, irow)
            guesses = [[guess[i]] for i in range(3)]

        # Guesses using max, cen, and fwhm=0.7*eqw
        # NOTE: should there be user options here?
        set_parameters = getattr(self.fitter, 'set_%s_parameters' % (self.fitfunc))
        n = 0
        for i in range(numfit):
            # cannot guess for multiple comps yet
            if comps[i] == 1:
                set_parameters(*[guesses[k][i] for k in range(3)], component=n)
            n += comps[i]

    def __get_initial_guess(self, msk, linerange, dbw, irow):
        """Return (peak, centre, fwhm) starting values for one line region.

        peak is the maximum inside msk; fwhm is 0.7 times the equivalent
        width (sum / peak, scaled by the channel width dbw), or 0.0 when the
        peak is 0.0; centre is the midpoint of linerange, or half the channel
        count when linerange has fewer than two values.
        """
        maxl, suml = [self.scan._math._statsrow(self.scan, msk, st, irow)[0]
                      for st in ['max', 'sum']]
        fwhm = maxl if maxl == 0.0 else 0.7 * abs(suml / maxl * dbw)
        # NOTE(review): the fallback centre is a channel index even when the
        # abscissa unit is not 'channel' -- confirm this is intended.
        if len(linerange) > 1:
            cen = 0.5 * sum(linerange[:2])
        else:
            cen = self.scan.nchan(self.scan.getif(irow)) // 2
        return (maxl, cen, fwhm)

    def __update_params(self, ncomps):
        """Append the converged fit of the current row to fitresult/fitparams."""
        self.fitresult['nfit'].append(ncomps)
        keys = ['peak', 'cent', 'fwhm']
        nkeys = len(keys)
        parameters = self.fitter.get_parameters()
        params = parameters['params']
        errors = parameters['errors']
        # Parameters are laid out component by component:
        # [peak0, cent0, fwhm0, peak1, cent1, fwhm1, ...]
        for j, key in enumerate(keys):
            values = [[params[i * nkeys + j], errors[i * nkeys + j]]
                      for i in range(ncomps)]
            self.fitresult[key].append(values)
        npars = len(params) // ncomps
        self.fitparams.append(list(array(params).reshape((ncomps, npars))))

    def __set_linelist(self):
        """Reset the mask/line list and dispatch to _set_linelist_<fitmode>."""
        self.defaultmask = []
        self.linelist = []
        self.nlines = 1
        getattr(self, '_set_linelist_%s' % (self.fitmode.lower()))()

    def _set_linelist_list(self):
        """'list' mode: line regions come from the user-supplied maskline."""
        # e.g. maskline=[[3900,4300]] for a single line
        if len(self.maskline) > 0:
            # There is a user-supplied channel mask for lines
            if not self.invertmask:
                # Make sure this is a list-of-lists (e.g. [[1,10],[20,30]])
                if isinstance(self.maskline[0], list):
                    self.linelist = self.maskline
                else:
                    self.linelist = to_list_of_list(self.maskline)
                self.nlines = len(self.linelist)
            self.defaultmask = self.scan.create_mask(self.maskline, invert=self.invertmask)
        elif self.invertmask:
            # Whole region requested, but inverting it selects nothing
            msg = 'No channel is selected because invertmask=True. Exit without fitting.'
            raise Exception(msg)

        casalog.post("Identified %d regions for fitting" % (self.nlines))
        if self.invertmask:
            casalog.post("No starting guesses available")
        else:
            casalog.post("Will use these as starting guesses")

    def _set_linelist_interact(self):
        """'interact' mode: the user selects line regions and may change nfit."""
        import ast

        # Interactive masking
        new_mask = sdutil.init_interactive_mask(self.scan,
                                                self.maskline,
                                                self.invertmask)
        try:
            self.defaultmask = sdutil.get_interactive_mask(new_mask,
                                                           purpose='to fit lines')
            self.linelist = self.scan.get_masklist(self.defaultmask)
            self.nlines = len(self.linelist)
            if self.nlines < 1:
                msg = 'No channel is selected. Exit without fitting.'
                raise Exception(msg)
            print('%d region(s) is selected as a linemask' % self.nlines)
            print('The final mask list (' + self.scan._getabcissalabel() + ') =' + str(self.linelist))
            print('Number of line(s) to fit: nfit = ' + str(self.nfit))
            ans = raw_input('Do you want to reassign nfit? [N/y]: ')
            if ans.upper() == 'Y':
                # literal_eval accepts only Python literals such as "2" or
                # "[1, 2]"; input() would evaluate arbitrary expressions.
                try:
                    ans = ast.literal_eval(raw_input('Input nfit = '))
                except (ValueError, SyntaxError):
                    ans = None
                if isinstance(ans, (list, tuple)):
                    self.nfit = list(ans)
                elif type(ans) is int:
                    self.nfit = [ans]
                else:
                    msg = 'Invalid definition of nfit. Setting nfit=[1] and proceed.'
                    casalog.post(msg, priority='WARN')
                    self.nfit = [1]
                casalog.post('List of line number reassigned.\n   nfit = ' + str(self.nfit))
        finally:
            # Finalize the interactive mask on the error path as well.
            sdutil.finalize_interactive_mask(new_mask)

    def _set_linelist_auto(self):
        """'auto' mode: find line regions per row with the ASAP linefinder."""
        casalog.post("Trying AUTO mode - find line channel regions")
        if len(self.maskline) > 0:
            # There is a user-supplied channel mask for lines
            self.defaultmask = self.scan.create_mask(self.maskline,
                                                     invert=self.invertmask)

        # Use linefinder to find lines
        casalog.post("Using linefinder")
        fl = sd.linefinder()
        fl.set_scan(self.scan)
        # This is the tricky part
        # def  fl.set_options(threshold=1.732,min_nchan=3,avg_limit=8,box_size=0.2)
        # e.g. fl.set_options(threshold=5,min_nchan=3,avg_limit=4,box_size=0.1) seem ok?
        fl.set_options(threshold=self.thresh, min_nchan=self.min_nchan,
                       avg_limit=self.avg_limit, box_size=self.box_size)
        # Now find the lines for each row in scantable
        self.nlines = []
        for irow in range(self.scan.nrow()):
            self.nlines.append(fl.find_lines(mask=self.defaultmask, nRow=irow, edge=self.edge))
            # Get ranges
            ptout = "SCAN[%d] IF[%d] POL[%d]: " % (self.scan.getscan(irow),
                                                   self.scan.getif(irow),
                                                   self.scan.getpol(irow))
            if self.nlines[irow] > 0:
                ll = fl.get_ranges()
                casalog.post(ptout + "Found %d lines at %s" % (self.nlines[irow], str(ll)))
            else:
                ll = ()
                casalog.post(ptout + "Nothing found.", priority='WARN')

            # This is a linear list of pairs of values, so turn these into a list of lists
            self.linelist.append(to_list_of_list(ll))

        # Done with linefinder
        casalog.post("Finished linefinder.")

    def __store_fit(self):
        """Write the per-row fit parameters to self.outfile as a text table."""
        numparam = 3     # gaussian fitting is assumed (max, center, fwhm)
        header = "#%-4s %-4s %-4s %-12s " % ("SCAN", "IF", "POL", "Function")
        header += ''.join('%-12s ' % ('P%d' % (i)) for i in range(numparam))

        # 'with' guarantees the file is closed even if a write fails
        with open(sdutil.get_abspath(self.outfile), 'w') as outf:
            outf.write(header + '\n')
            # fitparams holds exactly one entry per scantable row
            for irow, rowparams in enumerate(self.fitparams):
                prefix = " %-4d %-4d %-4d " \
                         % (self.scan.getscan(irow), self.scan.getif(irow), self.scan.getpol(irow))
                for icomp, comp in enumerate(rowparams):
                    # components whose first parameter is 0.0 (e.g. the
                    # [[0,0,0]] placeholder of an unfitted row) are skipped
                    if comp[0] != 0.0:
                        datstr = prefix + '%-12s ' % ('%s%d' % (self.fitfunc, icomp))
                        datstr += ''.join("%3.8f " % (value) for value in comp)
                        outf.write(datstr + '\n')

    def __init_plot(self):
        """Prepare the fitter's plotter with one panel per row (at most 16)."""
        n = self.scan.nrow()
        if n > 16:
            casalog.post('Only first 16 results are plotted.', priority='WARN')
            n = 16

        # initialize plotter
        from matplotlib import rc as rcp
        rcp('lines', linewidth=1)
        if not (self.fitter._p and self.fitter._p._alive()):
            self.fitter._p = sdutil.get_plotter(self.plotlevel)
        self.fitter._p.hold()
        self.fitter._p.clear()
        # set nrow and ncol (maximum 4x4)
        self.fitter._p.set_panels(rows=n, cols=0, ganged=False)
        casalog.post('nrow,ncol= %d,%d' % (self.fitter._p.rows, self.fitter._p.cols))
        self.fitter._p.palette(0, ["#777777", "#dddddd", "red", "orange", "purple", "green", "magenta", "cyan"])

    def __plot(self, irow, fitted, firstplot=False):
        """Draw row irow (spectrum, fitted region, optional residual and fit)."""
        # Full legend labels only on the first plotted panel
        if firstplot:
            labels = ['Spectrum', 'Selected Region', 'Residual', 'Fit']
        else:
            labels = ['spec', 'select', 'res', 'fit']
        myp = self.fitter._p

        myp.subplot(irow)
        # plot spectra
        x = self.fitter.data._getabcissa(irow)
        y = self.fitter.data._getspectrum(irow)
        mr = self.fitter.data._getflagrow(irow)
        if mr:  # a whole spectrum is flagged
            themask = False
        else:
            # combine the row's channel mask with the fitter's mask
            msk = array(self.fitter.data._getmask(irow))
            fmsk = array(self.fitter.mask)
            themask = logical_and(msk, fmsk)
            del msk, fmsk
        # plot masked region if any of channel is not in fit range.
        idx = 0
        if mr or (not all(themask)):
            # dumped region
            plot_line(myp, x, y, themask, label=labels[0], color=1)
            idx = 1
        themask = logical_not(themask)

        # fitted region
        plot_line(myp, x, y, themask, label=labels[idx], color=0, scale=True)

        # plot fitted result
        if fitted:
            # plot residual
            if self.plotlevel == 2:
                plot_line(myp, x, self.fitter.get_residual(), themask, label=labels[2], color=7)
            # plot fit
            plot_line(myp, x, self.fitter.fitter.getfit(), themask, label=labels[3], color=2)

        if irow == 0:
            tlab = self.fitter.data._getsourcename(self.fitter._fittedrow)
            myp.set_axes('title', tlab)
        if irow % myp.rows == 0:
            ylab = self.fitter.data._get_ordinate_label()
            myp.set_axes('ylabel', ylab)
        if irow // myp.rows == myp.cols - 1:
            xlab = self.fitter.data._getabcissalabel(self.fitter._fittedrow)
            myp.set_axes('xlabel', xlab)
        myp.release()

    def __warn_fit_failed(self, irow, message=''):
        """Log which row failed to fit, followed by message as a warning."""
        casalog.post('Fitting:')
        casalog.post('Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]' % (self.scan.getscan(irow), self.scan.getbeam(irow), self.scan.getif(irow), self.scan.getpol(irow), self.scan.getcycle(irow)))
        casalog.post("   %s" % (message), priority='WARN')
00408     def __warn_fit_failed(self,irow,message=''):
00409         casalog.post( 'Fitting:' )
00410         casalog.post( 'Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]' %(self.scan.getscan(irow), self.scan.getbeam(irow), self.scan.getif(irow), self.scan.getpol(irow), self.scan.getcycle(irow)) )
00411         casalog.post( "   %s"%(message), priority = 'WARN' )
00412 
def plot_line(plotter, x, y, msk, label, color, colormap=None, scale=False):
    """Plot y against x on plotter, hiding the entries where msk is True.

    Parameters:
        plotter  -- plotter object; uses set_line(), palette(), axes and plot().
        x, y     -- abscissa values and spectrum values (same length).
        msk      -- mask for numpy.ma.masked_array; True entries are hidden.
        label    -- legend label for the line.
        color, colormap -- forwarded to plotter.palette().
        scale    -- if True, set the x limits to [min(x), max(x)] and the
                    y limits to the unmasked y range padded by 10% on each
                    side (y limits are left alone if every value is masked).
    """
    plotter.set_line(label=label)
    plotter.palette(color, colormap)
    my = ma.masked_array(y, msk)
    if scale:
        plotter.axes.set_xlim([min(x), max(x)])
        # Use the masked-array reductions so hidden entries are ignored;
        # builtin min()/max() also iterate over masked elements, and their
        # comparisons with ma.masked can yield ma.masked as the "extreme".
        if my.count() > 0:
            ymin = my.min()
            ymax = my.max()
            pad = (ymax - ymin) * 0.1
            plotter.axes.set_ylim([ymin - pad, ymax + pad])
    plotter.plot(x, my)
00425 
def to_list_of_list(l):
    """Turn a flat sequence of pairs into a list of two-element lists.

    e.g. [1, 10, 20, 30] -> [[1, 10], [20, 30]]; an empty sequence gives [].
    Raises ValueError (from numpy reshape) when the length is odd.
    """
    # Floor division keeps the row count an int on both Python 2 and 3.
    return array(l).reshape(len(l) // 2, 2).tolist()
00428