Go to the documentation of this file.00001 import os
00002 import sys
00003 import shutil
00004 import commands
00005 import numpy
00006 import numpy.ma as ma
00007 import random
00008 from __main__ import default
00009 from tasks import *
00010 from taskinit import *
00011 from cleanhelper import *
00012 import unittest
00013
00014 '''
00015 Unit tests for statwt
00016 '''
00017
00018
00019
00020
00021
00022
00023
class statwt_test(unittest.TestCase):
    """Unit test for the statwt task.

    Every test runs on a private copy of ``msfile`` that setUp copies from
    ``$CASAPATH/data/regression/unittest/visstat`` into the current working
    directory and that tearDown removes again.
    """

    # Name of the measurement set; also the name of the working copy.
    msfile = 'ngc5921.ms'
    # Return value of the most recent statwt call.
    res = False

    def _remove_msfile(self):
        """Delete the working copy of the measurement set if it exists."""
        if os.path.isdir(self.msfile) and not os.path.islink(self.msfile):
            shutil.rmtree(self.msfile)
        elif os.path.lexists(self.msfile):
            os.remove(self.msfile)

    def setUp(self):
        """Create a fresh working copy of the measurement set.

        The test is skipped (instead of crashing with an AttributeError on
        None) when CASAPATH is not set, since the data cannot be located.
        """
        casapath = (os.environ.get('CASAPATH') or '').split()
        if not casapath:
            self.skipTest('CASAPATH is not set; cannot locate ' + self.msfile)

        self._remove_msfile()
        datapath = os.path.join(casapath[0], 'data', 'regression',
                                'unittest', 'visstat')
        shutil.copytree(os.path.join(datapath, self.msfile), self.msfile)

    def tearDown(self):
        """Remove the working copy so that no data is left behind."""
        self._remove_msfile()

    def calcVariance(self, specData):
        """
        calculate variance of a single row of complex vis data
        input: specData - numpy masked array
        returns: sample variance (normalised by N-1) of the unmasked
                 elements, as a real number
        raises: ValueError if fewer than two elements are unmasked
        """
        data = ma.asanyarray(specData)
        # Only unmasked samples contribute; len() would also count the
        # flagged ones and bias the normalisation.
        nvalid = data.count()
        if nvalid < 2:
            raise ValueError('need at least two unflagged samples to '
                             'compute a variance, got %d' % nvalid)
        centered = data - data.mean()
        # |z|^2 = z * conj(z); the product is real up to rounding.
        sumsq = (centered * centered.conjugate()).real.sum()
        return sumsq / (nvalid - 1.0)

    def calcwt(self, selrow, selcorr, datcol, flagcol):
        """
        calc weight,sigma from the data

        input: selrow  - row index
               selcorr - correlation index
               datcol  - data column, assumed to be indexed
                         [corr][chan][row]
               flagcol - flag column with the same layout as datcol
        returns: (var, sig, rms) where var is the variance of the
                 unflagged channels, sig = sqrt(var) and
                 rms = sqrt(|mean|^2 + var)
        """
        spectrum = datcol[selcorr].transpose()[selrow]
        flags = flagcol[selcorr].transpose()[selrow]
        mdspec = ma.masked_array(spectrum, flags)
        dmean = mdspec.mean()
        var = self.calcVariance(mdspec)
        sig = numpy.sqrt(var)
        rms = numpy.sqrt(dmean * dmean.conjugate() + var)
        return (var, sig, rms)

    def test_default(self):
        """
        test default case

        Runs statwt with default parameters and checks, for a random
        sample of rows, that WEIGHT equals 1/variance and SIGMA equals
        sqrt(variance) of the row's spectrum to a relative tolerance.
        """
        tol = 1.e-5
        self.res = statwt(vis=self.msfile)
        self.assertTrue(self.res)

        tb.open(self.msfile)
        try:
            datc = tb.getcol('DATA')
            wt = tb.getcol('WEIGHT')
            sg = tb.getcol('SIGMA')
            fg = tb.getcol('FLAG')
            nr = tb.nrows()
        finally:
            # Always release the table, even if reading a column fails.
            tb.close()

        random.seed()
        # Never ask for more rows than the table has.
        randomRowList = random.sample(range(nr), min(10, nr))
        for i in randomRowList:
            icorr = random.randint(0, 1)
            (v, s, r) = self.calcwt(i, icorr, datc, fg)
            diffwt = 1 / v - wt[icorr][i]
            diffsg = s - sg[icorr][i]

            # Rows are chosen at random, so report which one failed.
            where = 'row %d, corr %d' % (i, icorr)
            self.assertTrue(abs(diffwt) / (1 / v) < tol,
                            'WEIGHT mismatch at ' + where)
            self.assertTrue(abs(diffsg) / s < tol,
                            'SIGMA mismatch at ' + where)
def suite():
    """Return the list of test-case classes provided by this file."""
    test_classes = [statwt_test]
    return test_classes
00106
if __name__ == '__main__':
    # Script entry point: run every test case returned by suite() with
    # verbose output.  TestLoader.loadTestsFromTestCase replaces the
    # deprecated unittest.makeSuite; both select methods whose names
    # start with 'test'.
    loader = unittest.TestLoader()
    testSuite = [loader.loadTestsFromTestCase(testClass)
                 for testClass in suite()]
    allTests = unittest.TestSuite(testSuite)
    unittest.TextTestRunner(verbosity=2).run(allTests)