casa  $Rev:20696$
 All Classes Namespaces Files Functions Variables
task_sdflag.py
Go to the documentation of this file.
00001 import os
00002 from taskinit import *
00003 
00004 import sdutil
00005 import asap as sd
00006 from asap.scantable import is_scantable,is_ms
00007 from asap.flagplotter import flagplotter
00008 import pylab as pl
00009 from numpy import ma, array, logical_not, logical_and
00010 
def sdflag(infile, antenna, specunit, restfreq, frame, doppler, scanlist, field, iflist, pollist, maskflag, flagrow, clip, clipminmax, clipoutside, flagmode, interactive, showflagged, outfile, outform, overwrite, plotlevel):
    """Flag or unflag single-dish spectra (CASA task sdflag).

    Every task parameter is forwarded unchanged, as a keyword argument of the
    same name, to sdflag_worker, which is then driven through its
    initialize() / execute() / finalize() cycle.

    Raises:
        Exception: any exception raised by the worker is first passed to
            sdutil.process_exception and then re-raised unchanged.
    """
    casalog.origin('sdflag')

    try:
        # At this point locals() contains exactly the task parameters, so
        # each argument reaches the worker under its own name.
        worker = sdflag_worker(**locals())
        worker.initialize()
        worker.execute()
        worker.finalize()

    except Exception as instance:
        sdutil.process_exception(instance)
        # A bare raise re-raises the same exception object *with* its
        # original traceback; "raise Exception, instance" discarded the
        # traceback and made the error appear to originate here.
        raise
00024 
class sdflag_worker(sdutil.sdtask_template):
    """Worker object implementing the sdflag task.

    Task parameters are passed as keyword arguments to
    sdutil.sdtask_template; the methods below read them back as attributes
    (self.infile, self.flagmode, self.plotlevel, ...).
    NOTE(review): the template (in sdutil, not shown here) presumably calls
    parameter_check(), initialize_scan() and save() from its
    initialize()/finalize() -- confirm there.
    """

    def __init__(self, **kwargs):
        super(sdflag_worker, self).__init__(**kwargs)

        # initialize plotter
        self.__init_plotter()

    def parameter_check(self):
        """Validate task parameters and derive the flags used later.

        Sets self.project (name of the output data), self.rfset,
        self.restore, self.docmdflag, self.unflag and self.anyflag.

        Raises:
            Exception: if input and output are the same existing file but
                with different formats while overwrite=True (CAS-3096), if
                no flag operation is requested at all, or if flagmode is
                neither 'flag' nor 'unflag' (case-insensitive).
        """
        # by default, the task overwrites infile
        if len(self.outfile) == 0:
            self.project = self.infile
        else:
            self.project = self.outfile

        sdutil.assert_outfile_canoverwrite_or_nonexistent(self.project,
                                                          self.outform,
                                                          self.overwrite)

        # check the format of the infile
        filename = sdutil.get_abspath(self.infile)
        if isinstance(self.infile, str):
            if is_scantable(filename):
                informat = 'ASAP'
            elif is_ms(filename):
                informat = 'MS2'
            else:
                informat = 'SDFITS'
        else:
            informat = 'UNDEFINED'

        # Check the formats of infile and outfile are identical when
        # overwrite=True (CAS-3096). If not, raise an error.
        outformat = self.outform.upper()
        if outformat == 'MS':
            outformat = 'MS2'
        if self.overwrite and os.path.exists(self.project) \
               and os.path.samefile(self.project, self.infile) \
               and outformat != informat:
            msg = "The input and output data format must be identical when "
            msg += "their names are identical and overwrite=True. "
            msg += "%s and %s given for input and output, respectively." % (informat, outformat)
            raise Exception(msg)

        # check restfreq
        self.rfset = (self.restfreq != '') and (self.restfreq != [])
        self.restore = (self.specunit == 'km/s') and self.rfset

        # Command-line flagging is requested via flagrow, maskflag or clip;
        # either that or interactive flagging must be asked for.
        self.docmdflag = ((len(self.flagrow) + len(self.maskflag) > 0)
                          or self.clip)
        if (not self.docmdflag) and (not self.interactive):
            raise Exception('No flag operation specified.')

        # check flagmode
        mode = self.flagmode.lower()
        if mode not in ['flag', 'unflag']:
            raise Exception('unexpected flagmode')
        self.unflag = (mode == 'unflag')

        # becomes True once any flag operation has actually been performed
        self.anyflag = False

    def initialize_scan(self):
        """Load infile into a scantable, apply the data selection and keep
        the result in self.scan (a copy when flagging would write back onto
        infile with disk storage, see CAS-3987)."""
        sorg = sd.scantable(self.infile, average=False, antenna=self.antenna)

        if abs(self.plotlevel) > 1:
            casalog.post( "Initial Scantable:" )
            sorg._summary()

        # data selection
        sorg.set_selection(self.get_selector())

        # Copy the original data (CAS-3987).
        # os.path.samefile() raises OSError if either path does not exist,
        # so compare only when the output already exists; a not-yet-created
        # outfile can never be the same file as infile.
        if self.is_disk_storage \
               and os.path.exists(self.project) \
               and os.path.samefile(self.project, self.infile):
            self.scan = sorg.copy()
        else:
            self.scan = sorg

    def execute(self):
        """Run the requested flag operations on self.scan.

        Builds the channel mask (self.masks) and the clip range
        (self.threshold), optionally previews the command-line flags and
        asks for confirmation, applies them, then runs interactive flagging
        if requested and finally plots the result.

        Raises:
            Exception: if no flag operation was performed in the end.
        """
        self.set_to_scan()

        if len(self.maskflag) > 0:
            self.masks = self.scan.create_mask(self.maskflag)
        else:
            self.masks = [False] * self.scan.nchan()

        # Clip range as [min, max]; stays [None, None] unless clipminmax
        # is a two-element list or tuple.
        self.threshold = [None, None]
        if isinstance(self.clipminmax, (list, tuple)) \
               and len(self.clipminmax) == 2:
            # sorted() returns a new list; the caller's value is untouched
            self.threshold = sorted(self.clipminmax)

        if self.docmdflag and (abs(self.plotlevel) > 0):
            # plot flag and update self.docmdflag by the user input
            self.prior_plot()

        if self.docmdflag:
            self.command_flag()

        if self.interactive:
            self.interactive_flag()

        if not self.anyflag:
            raise Exception('No flag operation. Finish without saving')

        if abs(self.plotlevel) > 0:
            self.posterior_plot()

    def command_flag(self):
        """Apply one command-line flag operation and record it in history.

        Priority: clipping if clip is set, otherwise channel flagging when
        flagrow is empty, otherwise row flagging.
        """
        # Actual flag operations
        if self.clip:
            self.do_clip()
        elif len(self.flagrow) == 0:
            self.do_channel_flag()
        else:
            self.do_row_flag()
        self.anyflag = True

        # Add history entry. For an empty pol/if/scan selection, record the
        # numbers reported by the scantable's get<key>nos() instead.
        params = {'mode': self.flagmode, 'maskflag': self.maskflag}
        sel = self.scan.get_selection()
        for key in ['pol', 'if', 'scan']:
            val = getattr(sel, 'get_%ss' % (key))()
            params['%ss' % (key)] = val if len(val) > 0 else list(getattr(self.scan, 'get%snos' % (key))())
        self.scan._add_history( "sdflag", params )

    def do_clip(self):
        """Flag (or unflag) channels by value using the self.threshold range."""
        casalog.post('Number of spectra to be flagged: %d\nApplying clipping...'%(self.scan.nrow()))
        casalog.post('flagrow and maskflag will be ignored',priority='WARN')

        # threshold is sorted ascending: [1] is the upper, [0] the lower limit
        if self.threshold[1] > self.threshold[0]:
            self.scan.clip(self.threshold[1], self.threshold[0], self.clipoutside, self.unflag)
        else:
            casalog.post('clipminmax does not give a valid [min, max] range; no clipping applied', priority='WARN')

    def do_channel_flag(self):
        """Flag (or unflag) the channels selected by self.masks."""
        casalog.post('Number of spectra to be flagged: %d\nApplying channel flagging...'%(self.scan.nrow()))

        self.scan.flag(mask=self.masks, unflag=self.unflag)

    def do_row_flag(self):
        """Flag (or unflag) the whole rows listed in self.flagrow."""
        casalog.post('Number of rows to be flagged: %d\nApplying row flagging...'%(len(self.flagrow)))
        casalog.post('maskflag will be ignored',priority='WARN')

        self.scan.flag_row(self.flagrow, self.unflag)

    def interactive_flag(self):
        """Show self.scan in the GUI flagger and wait for the user.

        self.anyflag becomes True if the flagger reports a modification.
        """
        from matplotlib import rc as rcp
        rcp('lines', linewidth=1)
        guiflagger = flagplotter(visible=True)
        guiflagger.set_showflagged(self.showflagged)
        guiflagger.plot(self.scan)
        # block until the user has finished flagging in the GUI
        raw_input("Press enter to finish interactive flagging:")
        guiflagger._plotter.unmap()
        ismodified = guiflagger._ismodified
        guiflagger._plotter = None
        self.anyflag = self.anyflag or ismodified

    def save(self):
        """Write self.scan to self.project in the requested output format."""
        sdutil.save(self.scan, self.project, self.outform, self.overwrite)

    def prior_plot(self):
        """Preview pending command-line flags and ask whether to apply them.

        Plots at most the first 16 rows, one panel each, with three layers:
        the spectrum, the current flag masks and previously flagged data.
        The prompt is shown only for plotlevel > 0 with a GUI plotter;
        otherwise the flags are applied without asking. self.docmdflag is
        updated with the decision.
        """
        nrow = self.scan.nrow()
        npanel = min(nrow, 16)
        if nrow > 16:
            casalog.post( "Only first 16 spectra is plotted.", priority = 'WARN' )

        self.myp.set_panels(rows=npanel, cols=0, nplots=npanel)
        self.myp.legend(loc=1)
        labels = ['Spectrum','current flag masks','previously flagged data']
        # Masks handed to plot_data: True hides a channel in that layer.
        masklist = [None, None, None]
        idefaultmask = logical_not(array(self.masks))
        for row in xrange(npanel):
            self.myp.subplot(row)
            x = self.scan._getabcissa(row)
            y = self.scan._getspectrum(row)
            nchan = len(y)

            if self.scan._getflagrow(row):
                masklist[2] = array([False]*(nchan))
            else:
                masklist[2] = array(self.scan._getmask(row))

            if self.clip:
                if self.threshold[0] == self.threshold[1]:
                    masklist[1] = array([True]*nchan)
                else:
                    masklist[1] = array(self.scan._getclipmask(row, self.threshold[1], self.threshold[0], (not self.clipoutside), self.unflag))
            elif len(self.flagrow) > 0:
                masklist[1] = array([(row not in self.flagrow) or self.unflag]*nchan)
            else:
                masklist[1] = idefaultmask
            masklist[0] = logical_not(logical_and(masklist[1], masklist[2]))
            for i in xrange(3):
                plot_data(self.myp, x, y, masklist[i], i, labels[i])
            xlim = [min(x), max(x)]
            self.myp.axes.set_xlim(xlim)

            # short legend labels for every panel after the first one
            labels = ['spec','flag','prev']
        self.myp.release()

        #Apply flag
        if self.plotlevel > 0 and sd.rcParams['plotter.gui']:
            ans = raw_input("Apply %s (y/N)?: " % self.flagmode)
        else:
            casalog.post("Applying selected flags")
            ans = 'Y'

        # update self.docmdflag
        self.docmdflag = (ans.upper() == 'Y')

    def posterior_plot(self):
        """Plot the first row of the flagged result; for plotlevel < 0 also
        save the plot to '<project>_flag.eps'."""
        casalog.post( "Showing only the first spectrum..." )
        row = 0

        self.myp.set_panels()
        x = self.scan._getabcissa(row)
        y = self.scan._getspectrum(row)
        allmskarr = array(self.scan._getmask(row))
        plot_data(self.myp, x, y, logical_not(allmskarr), 0, ("Spectrum after %s" % self.flagmode) + 'ging')
        plot_data(self.myp, x, y, allmskarr, 2, "Flagged")
        xlim = [min(x), max(x)]
        self.myp.axes.set_xlim(xlim)
        if self.plotlevel < 0:
            # Hardcopy - currently no way w/o screen display first
            pltfile = self.project+'_flag.eps'
            self.myp.save(pltfile)
        self.myp.release()

    def __init_plotter(self):
        """Create self.myp for this plotlevel and set its colour palette."""
        colormap = ["green","red","#dddddd","#777777"]
        self.myp = sdutil.get_plotter(self.plotlevel)
        casalog.post('Create new plotter')
        self.myp.palette(0, colormap)
        self.myp.hold()
        self.myp.clear()
00262 
00263 
def plot_data(myp,x,y,msk,color=0,label=None):
    """Draw y versus x on plotter myp, hiding every point where msk is True.

    Args:
        myp: plotter providing set_line(label=...), palette() and plot().
        x, y: abscissa values and spectrum values.
        msk: per-point mask; True entries are masked out of the plot.
        color: palette index selected before plotting.
        label: legend label; applied only when truthy.
    """
    has_label = bool(label)
    if has_label:
        myp.set_line(label=label)
    myp.palette(color)
    visible = ma.masked_array(y, mask=msk)
    myp.plot(x, visible)
00270