interactive-mining/interactive-mining-3rdparty.../madis/src/functions/__init__.py

613 lines
20 KiB
Python
Executable File

"""functions
"""
VERSION = "1.9"
import setpath
import os.path
import os
import apsw
import sqltransform
import traceback
import logging
import re
import sys
import copy
try:
from collections import OrderedDict
except ImportError:
# Python 2.6
from lib.collections26 import OrderedDict
try:
    from inspect import isgeneratorfunction
except ImportError:
    # Python < 2.6: inspect.isgeneratorfunction is not available, so
    # re-implement it by testing the CO_GENERATOR flag of the code object.
    import inspect
    _CO_GENERATOR = getattr(inspect, 'CO_GENERATOR', 0x20)

    def isgeneratorfunction(obj):
        """Return True if *obj* is a (bound or unbound) generator function."""
        return bool((inspect.isfunction(obj) or inspect.ismethod(obj)) and
                    obj.func_code.co_flags & _CO_GENERATOR)
# Interpreter tuning (Python 2 API): how many bytecode instructions run
# between the interpreter's periodic checks (thread switches, signals).
sys.setcheckinterval(1000)
sqlite_version = apsw.sqlitelibversion()
apsw_version = apsw.apswversion()

# Prefix of the statement that creates the automatic virtual tables
# introduced by sqltransform; SQLite >= 3.7.11 gets "if not exists".
VTCREATE = 'create virtual table temp.'
SQLITEAFTER3711 = False
SQLITEAFTER380 = False
sqlite_version_split = []
try:
    # Parsing is the step that can fail, so it lives inside the try; an
    # unexpected version string then falls back to the defaults below
    # instead of breaking the import.
    sqlite_version_split = [int(x) for x in sqlite_version.split('.')]
    if sqlite_version_split[0:3] >= [3, 8, 0]:
        SQLITEAFTER380 = True
    if sqlite_version_split[0:3] >= [3, 7, 11]:
        VTCREATE = 'create virtual table if not exists temp.'
        SQLITEAFTER3711 = True
except ValueError:
    # Unparseable version: assume SQLite >= 3.7.11 (SQLITEAFTER380 stays False).
    VTCREATE = 'create virtual table if not exists temp.'
    SQLITEAFTER3711 = True

# True until the first register() call finishes; register_ops() validates
# operator names and checks for name collisions only while it is True.
firstimport = True
# Connection used by the sql()/table() test helpers.
test_connection = None
settings = {
    'tracing': False,  # echo decorated calls and print tracebacks
    'vtdebug': False,
    'logging': False,  # additionally log traced calls via logging
    'syspath': str(os.path.abspath(os.path.expandvars(os.path.expanduser(os.path.normcase(sys.path[0]))))),
}
# Registry of registered operators: kind -> {name: implementation}.
functions = {'row': {}, 'aggregate': {}, 'vtable': {}}
# Names of operators flagged ``multiset``; passed to sqltransform.transform().
multiset_functions = {}
# Prefix of the keys under which iterwrapper() stores iterators in
# ``connection.openiters``; chr(30) is the ASCII record separator.
iterheader = 'ITER' + chr(30)
# Function objects used purely as attribute namespaces.
variables = lambda _: _  # query variables; default bindings of Cursor.execute()
variables.flowname = ''
variables.execdb = None
variables.filename = ''
privatevars = lambda _: _
rowfuncs = lambda _: _  # register_ops() attaches row functions here
# variables.execdb seen by the last register() call (-1: none yet).
oldexecdb = -1
ExecutionCompleteError = apsw.ExecutionCompleteError
def getvar(name):
    """Return the query variable *name*; raises KeyError when it is unset."""
    namespace = variables.__dict__
    return namespace[name]
def setvar(name, value):
    """Set the query variable *name* to *value*."""
    namespace = variables.__dict__
    namespace[name] = value
def mstr(s):
    """Convert *s* to text as leniently as possible.

    Returns None for None and text (unicode) values unchanged. Byte strings
    are decoded as UTF-8 with invalid sequences replaced. Other objects go
    through str(); if even that fails, repr() is used with the surrounding
    quotes stripped and escaped newlines/tabs turned back into real ones.
    """
    if s is None:
        return None
    try:
        # unicode() with an encoding rejects values that are already
        # unicode; without this check non-ASCII text would fall through to
        # the repr() branch below and come back escaped.
        if isinstance(s, unicode):
            return s
        return unicode(s, 'utf-8', errors='replace')
    except Exception:
        pass
    # Objects that cannot be decoded above (numbers, exceptions, ...).
    try:
        return str(s)
    except Exception:
        pass
    o = repr(s)
    if (o[0:2] == "u'" and o[-1] == "'") or (o[0:2] == 'u"' and o[-1] == '"'):
        o = o[2:-1]
    elif (o[0] == "'" and o[-1] == "'") or (o[0] == '"' and o[-1] == '"'):
        o = o[1:-1]
    o = o.replace('''\\n''', '\n')
    o = o.replace('''\\t''', '\t')
    return o
class MadisError(Exception):
    """Base class of madis errors.

    The message is normalized with mstr(); str() prefixes it with a
    "Madis SQLError" header unless it already starts with one.
    """
    def __init__(self, msg):
        # Keep the raw constructor argument in ``args`` like a regular
        # exception (repr(), pickling); ``msg`` holds the normalized text.
        Exception.__init__(self, msg)
        self.msg = mstr(msg)

    def __str__(self):
        merrormsg = "Madis SQLError: \n"
        # mstr(None) returns None; render a missing message as empty text
        # instead of failing inside __str__.
        msg = self.msg if self.msg is not None else ''
        if msg.startswith(merrormsg):
            return msg
        else:
            return merrormsg + msg
class OperatorError(MadisError):
    """Error raised by a named operator; the message starts with its name."""
    def __init__(self, opname, msg):
        # Record the constructor arguments in ``args`` (repr(), pickling).
        # MadisError.__init__ is not used because it would run mstr() again
        # on the already formatted message.
        Exception.__init__(self, opname, msg)
        self.msg = "Operator %s: %s" % (mstr(opname.upper()), mstr(msg))
class DynamicSchemaWithEmptyResultError(MadisError):
    """Raised when a dynamic-schema virtual table produced no data to infer its schema from."""
    def __init__(self, opname):
        # Record the constructor argument in ``args`` (repr(), pickling).
        Exception.__init__(self, opname)
        self.msg = "Operator %s: Cannot initialize dynamic schema virtual table without data" % (mstr(opname.upper()))
def echofunctionmember(func):
    """Decorate a method so that its calls are echoed while tracing is on.

    When ``settings['tracing']`` is true, every call is printed as
    ``name(arg,...,key=value)``. The first positional argument (normally
    ``self``) is omitted, and each positional argument's repr() is cut to
    200 characters. If ``settings['logging']`` is true as well, the call is
    also logged at INFO level with full reprs through a LoggerAdapter that
    carries ``variables.flowname``. This happens only when the parent
    logger's first handler has a ``baseFilename``, and any logging error is
    ignored.

    The returned wrapper keeps the name and docstring of *func*.
    """
    import functools

    def shortrepr(value):
        text = repr(value)
        return text[:200] + ('' if len(text) <= 200 else '...')

    @functools.wraps(func)
    def wrapper(*args, **kw):
        if settings['tracing']:
            keywords = ["%s=%s" % (k, repr(v)) for k, v in kw.items()]
            if settings['logging']:
                try:
                    lg = logging.LoggerAdapter(logging.getLogger(__name__), {"flowname": variables.flowname})
                    if hasattr(lg.logger.parent.handlers[0], 'baseFilename'):
                        lg.info("%s(%s)" % (func.__name__, ','.join([repr(el) for el in args[1:]] + keywords)))
                except Exception:
                    pass
            print("%s(%s)" % (func.__name__, ','.join([shortrepr(el) for el in args[1:]] + keywords)))
        return func(*args, **kw)
    return wrapper
def iterwrapper(con, func, *args):
    """Call ``func(*args)`` and park the result on *con*.

    The produced object (a generator when reached through register_ops)
    is stored in ``con.openiters`` under ``iterheader + str(obj)``, and
    that key is returned wrapped in a buffer as the SQL-visible value.
    """
    produced = func(*args)
    key = iterheader + str(produced)
    con.openiters[key] = produced
    return buffer(key)
def iterwrapperaggr(con, func, self):
    """Aggregate counterpart of iterwrapper(): run ``func(self)`` (a
    generator ``final`` step) and park the result in ``con.openiters``,
    returning its key as a buffer."""
    return iterwrapper(con, func, self)
class Cursor(object):
def __init__(self,w):
self.__wrapped=w
self.__vtables=[]
self.__permanentvtables=OrderedDict()
self.__query = ''
self.__initialised=True #this should be last in init
def __getattr__(self, attr):
if self.__dict__.has_key(attr):
return self.__dict__[attr]
return getattr(self.__wrapped, attr)
def __setattr__(self, attr, value):
if self.__dict__.has_key(attr):
return object.__setattr__(self, attr, value)
if not self.__dict__.has_key('_Cursor__initialised'): # this test allows attributes to be set in the __init__ method
return object.__setattr__(self, attr, value)
return setattr(self.__wrapped, attr, value)
@echofunctionmember
def executetrace(self,statements,bindings=None):
try:
return self.__wrapped.execute(statements,bindings)
except Exception, e:
try: # avoid masking exception in recover statements
raise e, None, sys.exc_info()[2]
finally:
try:
self.cleanupvts()
except:
pass
def execute(self,statements,bindings=None,parse=True, localbindings=None): # overload execute statement
if localbindings!=None:
bindings=localbindings
else:
if bindings==None:
bindings=variables.__dict__
else:
if type(bindings) is dict:
bindings.update(variables.__dict__)
if not parse:
self.__query = statements
return self.executetrace(statements,bindings)
svts=sqltransform.transform(statements, multiset_functions.keys(), functions['vtable'], functions['row'].keys(), substitute=functions['row']['subst'])
s=svts[0]
try:
if self.__vtables != []:
self.executetrace(''.join(['drop table ' + 'temp.'+x +';' for x in reversed(self.__vtables)]))
self.__vtables = []
for i in svts[1]:
createvirtualsql=None
if re.match(r'\s*$', i[2]) is None:
sep=','
else:
sep=''
createvirtualsql = VTCREATE+i[0]+ ' using ' + i[1] + "(" + i[2] + sep + "'automatic_vtable:1'" +")"
try:
self.executetrace(createvirtualsql)
except Exception, e:
strex = mstr(e)
if SQLITEAFTER3711 or type(e) != apsw.SQLError or strex.find('already exists')==-1 or strex.find(i[0])==-1:
raise e, None, sys.exc_info()[2]
else:
self.__permanentvtables[i[0]]=createvirtualsql
if len(i)==4:
self.__permanentvtables[i[0]]=createvirtualsql
else:
self.__vtables.append(i[0])
self.__query = s
return self.executetrace(s, bindings)
except Exception, e:
if settings['tracing']:
traceback.print_exc(limit=sys.getrecursionlimit())
try: # avoid masking exception in recover statements
raise e, None, sys.exc_info()[2]
finally:
try:
self.cleanupvts()
except:
pass
def getdescriptionsafe(self):
try:
# Try to get the schema the normal way
schema = self.__wrapped.getdescription()
except apsw.ExecutionCompleteError:
# Else create a tempview and query the view
if not self.__query.strip().lower().startswith('select'):
raise apsw.ExecutionCompleteError
try:
list(self.executetrace('create temp view temp.___schemaview as '+ self.__query + ';'))
schema = [(x[1], x[2]) for x in list(self.executetrace('pragma table_info(___schemaview);'))]
list(self.executetrace('drop view temp.___schemaview;'))
except Exception, e:
raise apsw.ExecutionCompleteError
return schema
def close(self, force=False):
self.cleanupvts()
return self.__wrapped.close(force)
def cleanupvts(self):
if self.__vtables!=[]:
for t in reversed(self.__vtables):
self.executetrace('drop table if exists ' + 'temp.'+t)
self.__vtables=[]
class Connection(apsw.Connection):
    """apsw connection whose cursors understand madis SQL.

    The madis operators are registered on the connection when the first
    cursor is requested.
    """
    def cursor(self):
        """Return a madis Cursor over a new apsw cursor of this connection."""
        if 'registered' not in self.__dict__:
            # Mark first so that a cursor() call made during registration
            # does not re-enter register().
            self.registered = True
            register(self)
            self.openiters = {}
        return Cursor(apsw.Connection.cursor(self))

    def queryplan(self, statements, bindings=None, parse=True, localbindings=None):
        """Generator over the operations SQLite authorizes for *statements*.

        First yields a header of (column name, type) pairs. Then it yields
        one [operation, paramone, paramtwo, databasename, triggerorview]
        list per authorizer callback, skipping entries whose paramone is
        sqlite_master or sqlite_temp_master. *bindings*, *parse* and
        *localbindings* are accepted but not used.
        """
        plan = []

        def authorizer(operation, paramone, paramtwo, databasename, triggerorview):
            """Called when each operation is prepared. We can return SQLITE_OK, SQLITE_DENY or SQLITE_IGNORE"""
            # find the operation name
            plan.append([apsw.mapping_authorizer_function[operation], paramone, paramtwo, databasename, triggerorview])
            return apsw.SQLITE_OK

        def buststatementcache():
            # Run many distinct statements so that *statements* is prepared
            # afresh (and therefore authorized) rather than served from the
            # statement cache -- presumably 110 exceeds the cache size.
            c = self.cursor()
            try:
                for i in xrange(110):
                    list(c.execute("select " + str(i)))
            finally:
                c.close()

        buststatementcache()
        cursor = self.cursor()
        # NOTE(review): apsw aborts a statement only when the exec tracer
        # returns a false value; SQLITE_DENY is non-zero, so confirm whether
        # the statements are meant to actually run here.
        cursor.setexectrace(lambda v1, v2, v3: apsw.SQLITE_DENY)
        self.setauthorizer(authorizer)
        try:
            cursor.execute(statements)
        finally:
            # Always uninstall the authorizer and release the cursor, even
            # when the statements fail to prepare.
            self.setauthorizer(None)
            cursor.close()
        yield (('operation', 'text'), ('paramone', 'text'), ('paramtwo', 'text'), ('databasename', 'text'), ('triggerorview', 'text'))
        for r in plan:
            if r[1] not in ('sqlite_temp_master', 'sqlite_master'):
                yield r

    @echofunctionmember
    def close(self):
        """Close the underlying apsw connection."""
        apsw.Connection.close(self)
def register(connection=None):
    """Register all madis operators on *connection*.

    When *connection* is None, a new in-memory Connection is created, marked
    as registered and given an in-memory database attached as ``mem``. That
    connection is not returned. Operators are then loaded from:

    1. the bundled ``functions.row``, ``functions.aggregate`` and
       ``functions.vtable`` packages;
    2. the ``functionslocal`` packages located next to this package;
    3. a ``functions`` directory beside ``variables.execdb`` (or, when that
       does not exist, in the current working directory). This step runs
       only when ``variables.execdb`` differs from the value seen by the
       previous call.

    Finally sets the global ``firstimport`` to False, which turns off the
    name checks in register_ops() for later calls.
    """
    global firstimport, oldexecdb
    if connection == None:
        if 'SQLITE_OPEN_URI' in apsw.__dict__:
            connection = Connection(':memory:', flags=apsw.SQLITE_OPEN_READWRITE | apsw.SQLITE_OPEN_CREATE | apsw.SQLITE_OPEN_URI)
        else:
            connection = Connection(':memory:')
        connection.openiters = {}
        # Set before calling connection.cursor(): Connection.cursor() calls
        # register() for connections not yet marked as registered.
        connection.registered = True
        connection.cursor().execute("attach database ':memory:' as mem;", parse=False)
    variables.filename = connection.filename
    # To avoid db corruption set connection to fullfsync mode when MacOS is detected
    if sys.platform == 'darwin':
        c = connection.cursor().execute('pragma fullfsync=1;', parse=False)
    functionspath=os.path.abspath(__path__[0])
    def findmodules(abspath, relativepath):
        """Return the names (without extension) of the .py files directly in
        abspath/relativepath that do not start with an underscore; os.listdir
        raises if the directory is missing."""
        return [ os.path.splitext(file)[0] for file
            in os.listdir(os.path.join(abspath , relativepath))
            if file.endswith(".py") and not file.startswith("_") ]
    ## Register main functions of madis (functions)
    rowfiles = findmodules(functionspath, 'row')
    aggrfiles = findmodules(functionspath, 'aggregate')
    vtabfiles = findmodules(functionspath, 'vtable')
    # Imported for their side effect: importing functions.row.<name> binds
    # the subpackage as attribute ``row`` of the ``functions`` package (this
    # package module), which is what the globals row/aggregate/vtable used
    # below refer to.
    [__import__("functions.row" + "." + module) for module in rowfiles]
    [__import__("functions.aggregate" + "." + module) for module in aggrfiles]
    [__import__("functions.vtable" + "." + module) for module in vtabfiles]
    # Register aggregate functions
    for module in aggrfiles:
        moddict = aggregate.__dict__[module]
        register_ops(moddict,connection)
    # Register row functions
    for module in rowfiles:
        moddict = row.__dict__[module]
        register_ops(moddict,connection)
    # The vtable package itself serves as the namespace: its entries are the
    # imported vtable modules.
    register_ops(vtable,connection)
    ## Register madis local functions (functionslocal)
    functionslocalpath=os.path.abspath(os.path.join(functionspath,'..','functionslocal'))
    flrowfiles = findmodules(functionslocalpath, 'row')
    flaggrfiles = findmodules(functionslocalpath, 'aggregate')
    flvtabfiles = findmodules(functionslocalpath, 'vtable')
    # __import__ without fromlist returns the top-level package, hence tmp.row / tmp.aggregate.
    for module in flrowfiles:
        tmp=__import__("functionslocal.row." + module)
        register_ops(tmp.row.__dict__[module], connection)
    for module in flaggrfiles:
        tmp=__import__("functionslocal.aggregate." + module)
        register_ops(tmp.aggregate.__dict__[module], connection)
    # A function object used as an ad-hoc namespace holding the vtable modules
    # (with a fromlist, __import__ returns the leaf module itself).
    localvtable=lambda x:x
    for module in flvtabfiles:
        localvtable.__dict__[module]=__import__("functionslocal.vtable." + module, fromlist=['functionslocal.vtable'])
    if len(flvtabfiles)!=0:
        register_ops(localvtable,connection)
    ## Register db local functions (functions in db path)
    # Only rescanned when variables.execdb changed since the previous call.
    if variables.execdb!=oldexecdb:
        oldexecdb=variables.execdb
        dbpath=None
        if variables.execdb!=None:
            dbpath=os.path.join(os.path.abspath(os.path.dirname(variables.execdb)),'functions')
        # Fall back to ./functions when there is no db-side directory.
        if dbpath==None or not os.path.exists(dbpath):
            currentpath=os.path.abspath(os.path.join(os.path.abspath('.'), 'functions'))
            if os.path.exists(currentpath):
                dbpath=currentpath
        if dbpath!=None and os.path.exists(dbpath):
            # Skip this package's own directory (already registered above).
            if os.path.abspath(dbpath)!=os.path.abspath(functionspath):
                sys.path.append(dbpath)
                # NOTE(review): the modules below are imported by bare name
                # from directories appended to sys.path. They are cached in
                # sys.modules under that name, may clash with other modules,
                # and are not reloaded when execdb changes.
                if os.path.exists(os.path.join(dbpath, 'row')):
                    lrowfiles = findmodules(dbpath, 'row')
                    sys.path.append((os.path.abspath(os.path.join(os.path.join(dbpath),'row'))))
                    for module in lrowfiles:
                        tmp=__import__(module)
                        register_ops(tmp, connection)
                if os.path.exists(os.path.join(dbpath, 'aggregate')):
                    sys.path.append((os.path.abspath(os.path.join(os.path.join(dbpath),'aggregate'))))
                    laggrfiles = findmodules(dbpath, 'aggregate')
                    for module in laggrfiles:
                        tmp=__import__(module)
                        register_ops(tmp, connection)
                if os.path.exists(os.path.join(dbpath, 'vtable')):
                    sys.path.append((os.path.abspath(os.path.join(os.path.join(dbpath),'vtable'))))
                    lvtabfiles = findmodules(dbpath, 'vtable')
                    tmp=lambda x:x
                    for module in lvtabfiles:
                        tmp.__dict__[module]=__import__(module)
                    # NOTE(review): localvtable is always a function object
                    # here, so this test is always true; presumably
                    # ``len(lvtabfiles)!=0`` was intended. Harmless, because
                    # tmp has no entries when there are no modules.
                    if localvtable!=None:
                        register_ops(tmp,connection)
    firstimport=False
def register_ops(module, connection):
    """Register every operator found in *module* on *connection*.

    *module* is any object whose ``__dict__`` maps names to candidate
    operators. Only entries whose ``registered`` attribute is the boolean
    True are considered. By the entry's type it becomes:

    * module   -> virtual table module (an instance of its ``Source``);
    * function -> scalar function. Generator functions are wrapped with
      iterwrapper() and flagged ``multiset``;
    * old-style class (``classobj``) -> aggregate function. A generator
      ``final`` is wrapped with iterwrapperaggr() and the class is flagged
      ``multiset``.

    Each operator is recorded in ``functions``, and every operator with a
    true ``multiset`` attribute also in ``multiset_functions``. During the
    first import, MadisError is raised for names that are not lowercase,
    are SQL keywords, or collide with an already registered operator.
    """
    def opexists(op):
        # Name collisions are only checked during the first import.
        if firstimport:
            return op in functions['vtable'] or op in functions['row'] or op in functions['aggregate']
        return False

    def wrapfunction(con, opfun):
        # Scalar wrapper that parks the generator and returns its key.
        return lambda *args: iterwrapper(con, opfun, *args)

    def wrapaggr(con, opfun):
        # ``final`` replacement that parks the generator returned by opfun.
        return lambda self: iterwrapperaggr(con, opfun, self)

    def wrapaggregatefactory(wlambda):
        return lambda cls: (cls(), cls.step, wlambda)

    # Iterate over a snapshot so that registration side effects cannot
    # invalidate the iteration.
    for f, fobject in list(module.__dict__.items()):
        # Only objects explicitly flagged ``registered = True`` are operators.
        if getattr(fobject, 'registered', None) is not True:
            continue
        opname = f.lower()
        if firstimport:
            if opname != f:
                raise MadisError("Extended SQLERROR: Function '"+module.__name__+'.'+f+"' uses uppercase characters. Functions should be lowercase")
            if opname.upper() in sqltransform.sqlparse.keywords.KEYWORDS:
                raise MadisError("Extended SQLERROR: Function '"+module.__name__+'.'+opname+"' is a reserved SQL function")
        kind = type(fobject).__name__
        if kind == 'module':
            if opexists(opname):
                raise MadisError("Extended SQLERROR: Vtable '"+opname+"' name collision with other operator")
            functions['vtable'][opname] = fobject
            modinstance = fobject.Source()
            modinstance._madisVT = True
            connection.createmodule(opname, modinstance)
        elif kind == 'function':
            if opexists(opname):
                raise MadisError("Extended SQLERROR: Row operator '"+module.__name__+'.'+opname+"' name collision with other operator")
            functions['row'][opname] = fobject
            if isgeneratorfunction(fobject):
                fobject = wrapfunction(connection, fobject)
                fobject.multiset = True
            setattr(rowfuncs, opname, fobject)
            connection.createscalarfunction(opname, fobject)
        elif kind == 'classobj':
            if opexists(opname):
                raise MadisError("Extended SQLERROR: Aggregate operator '"+module.__name__+'.'+opname+"' name collision with other operator")
            functions['aggregate'][opname] = fobject
            if isgeneratorfunction(fobject.final):
                wlambda = wrapaggr(connection, fobject.final)
                fobject.multiset = True
                setattr(fobject, 'factory', classmethod(wrapaggregatefactory(wlambda)))
                connection.createaggregatefunction(opname, fobject.factory)
            else:
                setattr(fobject, 'factory', classmethod(lambda cls: (cls(), cls.step, cls.final)))
                connection.createaggregatefunction(opname, fobject.factory)
        if getattr(fobject, 'multiset', False):
            multiset_functions[opname] = True
def testfunction():
    """Create a fresh in-memory test connection and register all operators on it."""
    global test_connection
    connection = Connection(':memory:')
    test_connection = connection
    register(connection)
    variables.execdb = ':memory:'
def settestdb(testdb):
    """Open *testdb* as the test connection and register all operators on it.

    The path is case-normalized, has ``~`` and environment variables
    expanded and is made absolute; it also becomes ``variables.execdb``.
    """
    global test_connection
    expanded = os.path.expandvars(os.path.expanduser(os.path.normcase(testdb)))
    abstestdb = str(os.path.abspath(expanded))
    test_connection = Connection(abstestdb)
    register(test_connection)
    variables.execdb = abstestdb
def sql(sqlquery):
import locale
from lib import pptable
global test_connection
language, output_encoding = locale.getdefaultlocale()
if output_encoding==None:
output_encoding="UTF8"
test_cursor=test_connection.cursor()
e=test_cursor.execute(sqlquery.decode(output_encoding))
try:
desc=test_cursor.getdescription()
print pptable.indent([[x[0] for x in desc]]+[x for x in e], hasHeader=True),
except apsw.ExecutionCompleteError:
print '',
test_cursor.close()
def table(tab, num=''):
    """
    Creates a test table named "table0". Its columns are fitted to the data
    given to it and are automatically named a, b, c, ... (up to 26 columns).
    Each line is split shell-style (values may be quoted) and the value NULL
    becomes an SQL NULL. If the table cannot be created (e.g. it already
    exists) it is dropped and created again.
    'num' parameter:
    If a 'num' parameter is given then the table will be named for example
    table1 when num=1, table2 when num=2 ...
    Example:
    table('''
    1 2 3
    4 5 6
    ''')
    will create a table named 'table0' having the following data:
    a b c
    ---------
    1 2 3
    4 5 6
    """
    import shlex
    colnames = "abcdefghijklmnopqrstuvwxyz"
    # Collapse whitespace runs, drop blank lines, split shell-style and
    # convert NULL to None.
    lines = [re.sub(r'\s+', ' ', x.strip()) for x in tab.splitlines()]
    rows = [[(y if y != 'NULL' else None) for y in shlex.split(x)] for x in lines if x != '']
    numberofcols = len(rows[0])
    if num == '':
        num = '0'
    tablename = 'table' + str(num)
    createsql = 'create table ' + tablename + '(' + ','.join([colnames[i] + ' str' for i in range(numberofcols)]) + ')'
    insertsql = 'insert into ' + tablename + ' values(' + ','.join(['?'] * numberofcols) + ')'
    test_cursor = test_connection.cursor()
    try:
        test_cursor.execute(createsql)
    except Exception:
        # Most likely left over from an earlier call: replace it.
        test_cursor.execute("drop table " + tablename)
        test_cursor.execute(createsql)
    test_cursor.executemany(insertsql, rows)
def table1(tab):
    """Shorthand for ``table(tab, num=1)``: create test table ``table1``."""
    table(tab, 1)
def table2(tab):
    """Shorthand for ``table(tab, num=2)``: create test table ``table2``."""
    table(tab, 2)
def table3(tab):
    """Shorthand for ``table(tab, num=3)``: create test table ``table3``."""
    table(tab, 3)
def table4(tab):
    """Shorthand for ``table(tab, num=4)``: create test table ``table4``."""
    table(tab, 4)
def table5(tab):
    """Shorthand for ``table(tab, num=5)``: create test table ``table5``."""
    table(tab, 5)
def table6(tab):
    """Shorthand for ``table(tab, num=6)``: create test table ``table6``."""
    table(tab, 6)
def setlogfile(file):
    """Currently a no-op: the *file* argument is accepted and ignored."""