casa  $Rev:20696$
 All Classes Namespaces Files Functions Variables
task_plotuv.py
Go to the documentation of this file.
00001 from matplotlib.widgets import Button
00002 from taskinit import ms, tbtool, casalog
00003 from update_spw import expand_tilde
00004 
00005 #from taskutil import get_global_namespace
00006 #my_globals = get_global_namespace()
00007 #pl = my_globals['pl']
00008 #del my_globals
00009 
00010 import pylab as pl
00011 
def plotuv(vis=None, field=None, antenna=None, spw=None, observation=None, array=None,
           maxnpts=None, colors=None, symb=None, ncycles=None, figfile=None):
    """
    Plots the uv coverage of vis in klambda.  ncycles of colors will be
    allocated to representative wavelengths.

    colors: a list of matplotlib color codes.
    symb: One of matplotlib's codes for plot symbols: .:,o^v<>s+xDd234hH|_
          default: ',':  The smallest points I could find.
    maxnpts: Save memory and/or screen space by plotting a maximum of maxnpts
             (or all of them if maxnpts < 1).  There is a very sharp
             slowdown if the plotter starts swapping.
    spw: spw selection string (ignores channel specifications for now).
    field: field selection string.  When more than one field is selected,
           the fields are shown one at a time with Prev/Next buttons.
    antenna: antenna/baseline selection string; restricts the plotted rows
             to the selected ANTENNA1 and ANTENNA2 indices.
    observation: observation ID selection.
    array: array ID selection (expanded with update_spw.expand_tilde).
    ncycles: number of passes through colors across the frequency span.
    figfile: if set, each plotted field is also saved as
             <figfile base>_fld<field ID><figfile extension>.

    Returns True if a field was plotted, False on error or when the
    selection leaves nothing to plot.
    """
    casalog.origin('plotuv')
    try:
        uvplotinfo = UVPlotInfo(vis, spw, field, antenna, observation, array,
                                ncycles, colors, symb, figfile, maxnpts)
        fields = uvplotinfo.selindices['field']
        if len(fields) == 0:
            casalog.post("Nothing selected in %s" % vis, 'WARN')
            return False
        if len(fields) > 1:
            # Plot the first field that has data; NavField's buttons page
            # through the rest.
            # NOTE(review): nothing outside this function keeps fldnav alive;
            # confirm the buttons stay responsive after plotuv returns.
            fldnav = NavField(uvplotinfo)
            fldnav.next("dummy")
            # fldnav.fld is left at -1 when no selected field had rows to plot.
            return fldnav.fld >= 0
        return plotfield(fields[0], uvplotinfo)
    except Exception as e:
        casalog.post("Error plotting the UVWs of %s:" % vis, 'SEVERE')
        casalog.post("%s" % e, 'SEVERE')
        return False
00049 
class UVPlotInfo:
    """Gathers and holds info for a uv plot or set of them.

    Everything that is the same for every field is worked out once here:
    the plot symbols, the figfile naming scheme, the TaQL row selection
    shared by all fields (basequery), the field names, the channel
    frequencies of every spectral window, and the DATA_DESCRIPTION ->
    spw mapping.
    """
    def __init__(self, vis, spw, field, antenna, observation, array,
                 ncycles, colors, symb, figfile, maxnpts):
        # ncycles passes through the color list across the frequency span;
        # each plot symbol is a color code followed by the marker code.
        self.ncolors = ncycles * len(colors)
        try:
            self.symbs = [c + symb for c in colors]
        except Exception as e:
            raise ValueError("Exception %s forming the plot symbols out of %s and %s"
                             % (e, colors, symb))
        self.nsymbs = len(self.symbs)
        self.maxnpts = maxnpts
        try:
            self.figfile = figfile
            if figfile:
                # plotfield() saves to self.figfile + str(field ID) + self.ext.
                self.figfile, self.ext = self._split_figfile(figfile)
        except Exception as e:
            raise ValueError("Exception %s parsing figfile" % e)

        self.vis = vis
        self.title = vis
        self.tb = tbtool()

        # Convert '' to '*' for ms.msseltoindex.
        if not spw:
            spw = '*'
        if not field:
            field = '*'
        if not antenna:
            antenna = '*'
        # NOTE(review): observation=None is passed on as the string 'None';
        # callers presumably supply a string such as '' - confirm.
        self.selindices = ms.msseltoindex(vis, field=field, spw=spw,
                                          baseline=antenna,
                                          observation=str(observation))

        # Row selection shared by every field; plotfield() ANDs more onto it.
        clauses = []
        if observation:
            clauses.append('OBSERVATION_ID in '
                           + self._taql_list(self.selindices['obsids']))
        if antenna != '*':
            clauses.append('ANTENNA1 in %s and ANTENNA2 in %s'
                           % (self._taql_list(self.selindices['antenna1']),
                              self._taql_list(self.selindices['antenna2'])))
        if array:
            clauses.append('ARRAY_ID in ' + self._taql_list(expand_tilde(array)))
        self.basequery = ' and '.join(clauses)

        try:
            self.fldnames = {}
            self.listfield = False
            if field != '*' or len(self.selindices['field']) > 1:
                # plotfield() then adds a FIELD_ID clause and names the field
                # in the title, so the names are needed.
                self.listfield = True
                self.fldnames = dict(enumerate(
                    self._read_column(vis + '/FIELD', 'NAME')))
            self.subtitle = ''
            subtitles = []
            if spw != '*':
                subtitles.append("spw='%s'" % spw)
            if antenna != '*':
                subtitles.append("antenna='%s'" % antenna)
            if observation:
                subtitles.append("observation='%s'" % observation)
            if subtitles:
                self.subtitle = '(' + ', '.join(subtitles) + ')'
        except Exception as e:
            raise ValueError("Exception %s parsing the selection." % e)

        try:
            # plotfield() looks spw s up under the key 'r' + str(s + 1).
            self.chfs = self._read_column(vis + '/SPECTRAL_WINDOW', 'CHAN_FREQ',
                                          variable=True)
            self.nspw = len(self.chfs)
            if self.nspw > 1:
                # Bite the bullet now instead of while the main table is open.
                self.dd_to_spw = self._read_column(vis + '/DATA_DESCRIPTION',
                                                   'SPECTRAL_WINDOW_ID')
            else:
                self.dd_to_spw = [0]
        except Exception as e:
            raise ValueError("Exception %s getting the frequencies from %s" % (e, vis))

    def _read_column(self, table, column, variable=False):
        """Return one column of table, closing the table tool even on error.

        variable=True reads it with tb.getvarcol instead of tb.getcol.
        """
        self.tb.open(table)
        try:
            if variable:
                return self.tb.getvarcol(column)
            return self.tb.getcol(column)
        finally:
            self.tb.close()

    @staticmethod
    def _split_figfile(figfile):
        """Split figfile into (base + '_fld', extension).

        The extension keeps its leading '.' and is '' when there is none,
        e.g. 'uv.png' -> ('uv_fld', '.png'), 'out/uv' -> ('out/uv_fld', '').
        """
        import os
        base, ext = os.path.splitext(figfile)
        return base + '_fld', ext

    @staticmethod
    def _taql_list(indices):
        """Format an iterable of indices as a TaQL set literal, e.g. '[0,2,5]'."""
        return '[' + ','.join(map(str, indices)) + ']'
00137 
00138 def plotfield(fld, uvplotinfo, debug=False):
00139     """Plot the selected baselines of fld."""
00140     fldquery = uvplotinfo.basequery
00141     fldtitle = uvplotinfo.title
00142     if uvplotinfo.listfield:
00143         if fldquery:
00144             fldquery += ' and '
00145         fldquery += 'FIELD_ID==' + str(fld)
00146         fldtitle += ', field %d (%s)' % (fld, uvplotinfo.fldnames[fld])
00147         casalog.post('Plotting field %d (%s)' % (fld, uvplotinfo.fldnames[fld]))
00148 
00149     # Figure out how to divvy up the plotting among the frequencies.
00150     # I'm sure nested queries can be done, but I want to avoid temp tables
00151     # on disk.
00152     casalog.post('fldquery: ' + fldquery, 'DEBUG1')
00153     uvplotinfo.tb.open(uvplotinfo.vis)
00154     ftab = uvplotinfo.tb.query(fldquery, columns='DATA_DESC_ID')
00155     nbl = ftab.nrows()
00156     casalog.post("nbl: %d" % nbl, 'DEBUG1')
00157 
00158     if uvplotinfo.nspw > 1:
00159         ddids = ftab.getcol('DATA_DESC_ID')
00160         usedddids = {}
00161         for d in ddids:
00162             if uvplotinfo.dd_to_spw[d] in uvplotinfo.selindices['spw']:
00163                 usedddids[d] = uvplotinfo.dd_to_spw[d]
00164         ddids = list(usedddids.keys())
00165         ddids.sort()
00166         usedspws = list(set(usedddids.values()))
00167         usedspws.sort()
00168     elif uvplotinfo.nspw == 1:
00169         ddids = [0]
00170         usedspws = [0]
00171     else:
00172         ddids = []
00173         usedspws = []
00174     ftab.close()
00175 
00176     if not ddids:
00177         casalog.post('Nothing selected for field %d' % fld)
00178         return False
00179 
00180     maxddid = ddids[-1]
00181     minddid = ddids[0]
00182     minmax = {}
00183     globminf = -1
00184     globmaxf = -1
00185     for s in usedspws:
00186         r = 'r' + str(s + 1)
00187         minf = uvplotinfo.chfs[r][0, 0]
00188         maxf = uvplotinfo.chfs[r][-1, 0]
00189         if maxf < minf:
00190             minf, maxf = maxf, minf
00191         minmax[s] = (minf, maxf)
00192         #print "minmax[s]:", minmax[s]
00193         if minf < globminf or globminf < 0.0:
00194             globminf = minf
00195         if maxf > globmaxf:
00196             globmaxf = maxf
00197     freqspan = globmaxf - globminf
00198 
00199     def colorind(f):
00200         if freqspan > 0:
00201             return min(int(uvplotinfo.ncolors * (f - globminf) / freqspan),
00202                        uvplotinfo.ncolors - 1)
00203         else:
00204             return 0
00205     
00206     if uvplotinfo.maxnpts > 0 and nbl * (1 + colorind(globmaxf) -
00207                               colorind(globminf)) > uvplotinfo.maxnpts:
00208         casalog.post(
00209  "Only plotting %d out of %d (scaled) baselines to conserve memory" % (uvplotinfo.maxnpts,
00210                           nbl * (1 + colorind(globmaxf) - colorind(globminf))),
00211                      'WARN')
00212 
00213     if fldquery:
00214         fldquery += ' and '
00215     pl.ion()                 # Magic incantation to make the plot
00216     pl.clf()                 # window appear (and clear it).
00217     pl.ioff()
00218     for d in ddids:
00219         #print "minmax[%d] = %s" % (d, minmax[d])
00220         s = uvplotinfo.dd_to_spw[d]
00221         maxci = colorind(minmax[s][1])
00222         minci = colorind(minmax[s][0])
00223         ncolsspanned = 1 + maxci - minci
00224 
00225         # Get the subset of UVW that will be plotted for this spw.
00226         st = uvplotinfo.tb.query(fldquery + 'DATA_DESC_ID==' + str(d), columns='UVW')
00227         snbl = st.nrows()
00228         #print "snbl:", snbl
00229         if snbl > 0:
00230             uvw = 0.001 * st.getcol('UVW')
00231             st.close()
00232             if uvplotinfo.maxnpts > 0:
00233                 ntoplot = (uvplotinfo.maxnpts * snbl) / (nbl * ncolsspanned)
00234             else:
00235                 ntoplot = snbl
00236             if ntoplot < snbl:
00237                 uvinds = [((snbl - 1) * uvi) / (ntoplot - 1) for uvi in xrange(ntoplot)]
00238                 casalog.post("(max, min)(uvinds) = %g, %g" % (max(uvinds), min(uvinds)),
00239                              'DEBUG1')
00240                 casalog.post("len(uvw[0]) = %d" % len(uvw[0]), 'DEBUG1')
00241                 u = uvw[0, uvinds]
00242                 v = uvw[1, uvinds]
00243             else:
00244                 u = uvw[0, :]
00245                 v = uvw[1, :]
00246             del uvw
00247             casalog.post('len(u) = %d' % len(u), 'DEBUG1')
00248 
00249             freqspread = minmax[s][1] - minmax[s][0]
00250             # It'd be easier to just divide the frequency range by ncolors, but those
00251             # frequencies might not land on real channels.
00252             chfkey = 'r' + str(s + 1)
00253             nchans = uvplotinfo.chfs[chfkey].shape[0]
00254             casalog.post("nchans: %d" % nchans, 'DEBUG1')
00255             if ncolsspanned > 1:
00256                 cinds = [((nchans - 1) * c) / (ncolsspanned - 1) for c in
00257                          xrange(ncolsspanned)]
00258             else:
00259                 cinds = [nchans / 2]
00260             wvlngths = 2.9978e8 / uvplotinfo.chfs[chfkey].flatten()[cinds]
00261 
00262             # All this fussing with permutations and sieves is to give all the
00263             # frequencies a chance at being seen, at least in the case of a
00264             # single spw.  This way a channel will only blot out the plot where
00265             # either it really does have a much higher density than the others
00266             # or all the channels overlap.
00267             perm = pl.array(range(ncolsspanned - 1, -1, -1))
00268             if debug:
00269                 print "****s:", s
00270                 print "perm:", perm
00271                 print "ntoplot:", ntoplot
00272                 print "ncolsspanned:", ncolsspanned
00273                 print "minci:", minci
00274             for si in perm:
00275                 if debug:
00276                     print '(perm + si) % ncolsspanned =', (perm + si) % ncolsspanned
00277                 for ci in (perm + si) % ncolsspanned:
00278                     symb = uvplotinfo.symbs[(ci + minci) % uvplotinfo.nsymbs]
00279                     wvlngth = wvlngths[ci]
00280                     #casalog.post("spw %d, lambda: %g" % (s, wvlngth), 'DEBUG1')
00281                     casalog.post(
00282                         "d %d, spw %d, si %d, ntoplot %d, ncolsspanned %d, symb %s" %
00283                                  (d, s, si, ntoplot, ncolsspanned, symb), 'DEBUG1')
00284                     pl.plot( u[si:ntoplot:ncolsspanned] / wvlngth,
00285                              v[si:ntoplot:ncolsspanned] / wvlngth, symb)
00286                     pl.plot(-u[si:ntoplot:ncolsspanned] / wvlngth,
00287                             -v[si:ntoplot:ncolsspanned] / wvlngth, symb)
00288                     casalog.post('plotted baselines both ways', 'DEBUG1')
00289         else:
00290             st.close()
00291     uvplotinfo.tb.close()
00292     casalog.post('Scaling axes', 'DEBUG1')
00293     pl.axis('equal')
00294     pl.axis('scaled')
00295     pl.xlabel('u (k$\lambda$)')
00296     pl.ylabel('v (k$\lambda$)')
00297     if uvplotinfo.subtitle:
00298         pl.suptitle(fldtitle, fontsize=14)
00299         pl.title(uvplotinfo.subtitle, fontsize=10)
00300     else:
00301         pl.title(fldtitle)                
00302     if uvplotinfo.figfile:
00303         pl.savefig(uvplotinfo.figfile + str(fld) + uvplotinfo.ext)
00304     pl.draw()
00305     pl.ion()
00306     return True
00307 
class NavField:
    """Show the selected fields one at a time, with Prev/Next buttons.

    pltinfo is a UVPlotInfo.  Fields are visited in the order of
    pltinfo.selindices['field']; fields for which plotfield() plots
    nothing are skipped over.
    """
    def __init__(self, pltinfo):
        # Position in pltinfo.selindices['field']; -1 until a field is plotted.
        self.fld = -1
        self.nflds = len(pltinfo.selindices['field'])
        self.pltinfo = pltinfo
        # Button rectangles, passed to pl.axes() as [left, bottom, width, height].
        prevwidth = 0.13
        nextwidth = 0.12
        butheight = 0.05
        butleft = 0.7
        butbot = 0.025
        butgap = 0.5 * butbot
        self.nextloc = [butleft + prevwidth + butgap,
                        butbot, nextwidth, butheight]
        self.prevloc = [butleft, butbot, prevwidth, butheight]
        self.inactivecolor = '#99aaff'
        self.activecolor = '#aaffcc'

    def _draw_buts(self):
        """Add the Next/Prev buttons that make sense for the current field.

        The Button objects are stored on self so they stay referenced.
        """
        if self.fld < self.nflds - 1:
            axnext = pl.axes(self.nextloc)
            self.bnext = Button(axnext, 'Next fld >',
                                color=self.inactivecolor,
                                hovercolor=self.activecolor)
            self.bnext.on_clicked(self.next)
        if self.fld > 0:
            axprev = pl.axes(self.prevloc)
            self.bprev = Button(axprev, '< Prev fld',
                                color=self.inactivecolor,
                                hovercolor=self.activecolor)
            self.bprev.on_clicked(self.prev)

    def _do_plot(self):
        """Plot the current field; add the buttons if anything was plotted."""
        didplot = plotfield(self.pltinfo.selindices['field'][self.fld],
                            self.pltinfo)
        if didplot:
            self._draw_buts()
        return didplot

    def _step(self, delta, endmsg):
        """Move by delta (+1 or -1) until a field plots or the list runs out.

        If no field in that direction plots anything, prints endmsg and
        restores the starting position.  Returns whether a field was plotted.
        """
        didplot = False
        startfld = self.fld
        while not didplot and 0 <= self.fld + delta < self.nflds:
            self.fld += delta
            didplot = self._do_plot()
        if not didplot:
            print(endmsg)
            self.fld = startfld
        return didplot

    def next(self, event):
        """Button callback: plot the next field that has selected data."""
        return self._step(1, "You are on the last field with any selected baselines.")

    def prev(self, event):
        """Button callback: plot the previous field that has selected data."""
        return self._step(-1, "You are on the first field with any selected baselines.")
00366             print "You are on the first field with any selected baselines."
00367             self.fld = startfld
00368             #plotfield(self.pltinfo.selindices['field'][self.fld],
00369             #          self.pltinfo, self.mytb)