casa  $Rev:20696$
 All Classes Namespaces Files Functions Variables
testhelper.py
Go to the documentation of this file.
00001 from casac import casac
00002 import os
00003 import commands
00004 import math
00005 import shutil
00006 import string
00007 import time
00008 from taskinit import casalog,tb
00009 import numpy as np
00010 
'''
A set of helper functions for unit tests:
   compTables - compare two CASA tables
   compVarColTables - compare a variable-shape column of two CASA tables
   DictDiffer - a class with methods to take a difference of two
                Python dictionaries
'''
00017 
def _column_mismatches(a, b, tolerance):
    """Return the index tuples of the cells where arrays a and b disagree.

    a and b are numpy arrays of identical shape.  When both are numeric and
    at least one is floating-point or complex, two cells x, y agree if
    x == y, or |x - y| <= tolerance * |x + y|, or both are NaN.  For every
    other dtype (integer, boolean, string, object) cells must be exactly
    equal.
    """
    numeric = (np.issubdtype(a.dtype, np.number) and
               np.issubdtype(b.dtype, np.number))
    inexact = (np.issubdtype(a.dtype, np.inexact) or
               np.issubdtype(b.dtype, np.inexact))
    if numeric and inexact:
        # inf - inf / overflow in a + b would otherwise emit warnings; those
        # cells are still classified correctly by the == and isnan terms.
        with np.errstate(invalid='ignore', over='ignore'):
            close = np.abs(a - b) <= tolerance * np.abs(a + b)
        agree = (a == b) | close | (np.isnan(a) & np.isnan(b))
    else:
        agree = (a == b)
    return [tuple(int(i) for i in idx)
            for idx in np.argwhere(~np.asarray(agree, dtype=bool))]


def compTables(referencetab, testtab, excludecols, tolerance=0.001):
    """
    compTables - compare two CASA tables

       referencetab - the table which is assumed to be correct

       testtab - the table which is to be compared to referencetab

       excludecols - list of column names which are to be ignored
                     (None is treated as an empty list)

       tolerance - permitted fractional difference (default 0.001 = 0.1 percent)
                   for floating-point/complex cells: x and y agree when
                   |x - y| <= tolerance * |x + y|; NaN only matches NaN.
                   Integer, boolean and string cells must match exactly.

    Every column of referencetab that is not in excludecols is read with
    getcol() from both tables and compared cell by cell.  Comparison stops at
    the first column that cannot be read from testtab, has a different shape,
    or differs.  Diagnostics are printed to stdout.  Both tables are closed
    on every exit path.

    Returns True if all compared columns agree, False otherwise.
    """
    if excludecols is None:
        excludecols = []
    rval = True

    tb2 = casac.table()

    tb.open(referencetab)
    try:
        cnames = tb.colnames()
        # Opened inside the try so the reference table is still closed if
        # the test table cannot be opened.
        tb2.open(testtab)
        try:
            for c in cnames:
                if c in excludecols:
                    continue
                print(c)
                # Normalise to numpy arrays: comparisons are then driven by
                # dtype instead of type(x) == float, which never matches
                # numpy scalars such as np.float64 or np.bool_.
                a = np.asarray(tb.getcol(c))
                try:
                    b = np.asarray(tb2.getcol(c))
                except Exception as err:
                    print('Error accessing column %s in table %s' % (c, testtab))
                    print('%s: %s' % (type(err).__name__, err))
                    rval = False
                    break
                if a.shape != b.shape:
                    print('Column %s has different shape in tables %s and %s'
                          % (c, referencetab, testtab))
                    print(a)
                    print(b)
                    rval = False
                    break
                mismatches = _column_mismatches(a, b, tolerance)
                if mismatches:
                    print('Column %s differs in tables %s and %s'
                          % (c, referencetab, testtab))
                    for idx in mismatches:
                        print('%s: %s != %s' % (idx, a[idx], b[idx]))
                    rval = False
                    break
        finally:
            tb2.close()
    finally:
        tb.close()

    if rval:
        print('Tables %s and %s agree.' % (referencetab, testtab))

    return rval
00118 
00119 
def compVarColTables(referencetab, testtab, varcol):
    '''Compare a variable column of two tables.
       referencetab  --> a reference table
       testtab       --> a table to verify
       varcol        --> the name of a variable column (str)
       Returns True or False.

       The column named by varcol is read with getvarcol() from both tables.
       The comparison fails if the column is not a variable column in both
       tables, if the two tables hold a different number of cell entries,
       if a reference cell key is missing from the test table, or if any
       pair of cells is not exactly equal (np.array_equal: same shape, same
       values).  Both tables are closed on every exit path.
    '''
    retval = True
    tb2 = casac.table()

    tb.open(referencetab)
    try:
        # Opened inside the try so the reference table is still closed if
        # the test table cannot be opened.
        tb2.open(testtab)
        try:
            if not (tb.isvarcol(varcol) and tb2.isvarcol(varcol)):
                print('ERROR: Column %s is not a variable column in both %s and %s'
                      % (varcol, referencetab, testtab))
                retval = False
            else:
                # Read the requested column (previously hard-coded to 'DATA').
                rcol = tb.getvarcol(varcol)
                tcol = tb2.getvarcol(varcol)

                # First check
                if len(rcol) != len(tcol):
                    print('Length of %s differ from %s, %s!=%s'
                          % (referencetab, testtab, len(rcol), len(tcol)))
                    retval = False

                for k in rcol:
                    # A missing key or a shape mismatch counts as a
                    # difference instead of raising.
                    if k not in tcol or not np.array_equal(rcol[k], tcol[k]):
                        print('ERROR: Column %s of %s and %s do not agree'
                              % (varcol, referencetab, testtab))
                        retval = False
                        break
        finally:
            tb2.close()
    finally:
        tb.close()

    if retval:
        print('Column %s of %s and %s agree' % (varcol, referencetab, testtab))

    return retval
00166 
00167     
00168         
class DictDiffer(object):
    """
    Compare two dictionaries key by key.

    Given a "current" and a "past" dictionary, the methods report sets of keys:
    (1) added     - keys present only in current_dict
    (2) removed   - keys present only in past_dict
    (3) changed   - keys in both whose values compare unequal (!=)
    (4) unchanged - keys in both whose values compare equal (==)
    Example:
            mydiff = DictDiffer(dict1, dict2)
            mydiff.changed()  # to show what has changed
    """

    def __init__(self, current_dict, past_dict):
        self.current_dict = current_dict
        self.past_dict = past_dict
        # Key sets are captured once, at construction time.
        self.set_current = set(current_dict.keys())
        self.set_past = set(past_dict.keys())
        self.intersect = self.set_current & self.set_past

    def added(self):
        """Return the keys that exist only in current_dict."""
        return self.set_current.difference(self.intersect)

    def removed(self):
        """Return the keys that exist only in past_dict."""
        return self.set_past.difference(self.intersect)

    def changed(self):
        """Return the shared keys whose values differ."""
        differing = set()
        for key in self.intersect:
            if self.past_dict[key] != self.current_dict[key]:
                differing.add(key)
        return differing

    def unchanged(self):
        """Return the shared keys whose values are equal."""
        return {key for key in self.intersect
                if self.past_dict[key] == self.current_dict[key]}
00192