• Facebook
  • Twitter
  • Reddit
  • StumbleUpon
  • Digg
  • email

#! /usr/bin/env python
 
"""
usage: %(progname)s [args]
"""
 
 
import os, sys, string, time, getopt
from log import *
 
import odb
import sqlite
 
import re
 
# --- these are used for escaping NUL bytes (plus single quotes and
# --- percent signs) in strings, because sqlite can't handle embedded nulls
 
def escape_string(str):
    """Percent-encode the characters that must not reach sqlite verbatim.

    Every single quote, NUL byte and percent sign in ``str`` is replaced by
    ``%XX``, where XX is the character's ordinal as two upper-case hex
    digits.  All other characters are copied through unchanged.

    str -- the text to escape (the name shadows the builtin; it is kept so
           existing keyword callers keep working)
    Returns the escaped string.
    """
    special_chars = re.compile("('|\0|%)")

    def _to_hex_escape(match):
        # The pattern matches exactly one character, so group(0) is that char.
        return "%%%02X" % ord(match.group(0))

    return special_chars.sub(_to_hex_escape, str)
 
def unescape_string(str):
    """Reverse escape_string(): turn every ``%XX`` hex escape back into its character.

    Only a percent sign followed by exactly two hexadecimal digits is treated
    as an escape.  Any other sequence (for example ``%zz`` or ``% 1``) is
    copied through unchanged; previously such input either raised ValueError
    or was silently decoded to the wrong character because int() tolerates
    leading whitespace and signs.

    str -- the escaped text (the name shadows the builtin; it is kept so
           existing keyword callers keep working)
    Returns the decoded string.
    """
    def decode(m):
        # group(1) is guaranteed to be two hex digits by the pattern below.
        return "%c" % int(m.group(1), 16)

    return re.sub("%([0-9A-Fa-f]{2})", decode, str)
 
class Database(odb.Database):
  """odb.Database subclass carrying the SQLite-specific pieces.

  Provides quoting of literal values, listing of tables, indices and columns
  through ``sqlite_master`` / ``pragma table_info``, and a rebuild-based
  replacement for "alter table" (see alterTableToMatch).
  """

  def __init__(self,db, debug=0):
    """Initialise the odb base class and record the driver's error class.

    db    -- passed through unchanged to odb.Database.__init__
    debug -- passed through as the ``debug`` keyword of odb.Database.__init__
    """
    odb.Database.__init__(self, db, debug=debug)
    # NOTE(review): presumably exposed so generic code can catch driver
    # errors without importing sqlite itself -- confirm against odb.
    self.SQLError = sqlite.Error

  def escape(self,str):
    """Prepare a value for embedding inside a single-quoted SQL literal.

    None is returned as None, strings are returned with every single quote
    doubled, and integers are returned unchanged.

    Raises TypeError for any other type.  (This used to be a string
    exception, which is not a valid exception in Python 2.6 and later.)
    """
    if str is None:
      return None
    if isinstance(str, type("")):
      return str.replace("'", "''")
    # Exact type comparison on purpose, as before: subclasses of int are
    # still rejected rather than interpolated into SQL.
    if type(str) == type(1):
      return str
    raise TypeError("unknown column data type: %s" % type(str))

  def _listMasterNames(self, kind, cursor=None):
    """Return the names of all sqlite_master entries whose type is ``kind``.

    kind   -- value compared against the ``type`` column; only called with
              the constants 'table' and 'index'
    cursor -- cursor to run the query on; defaults to self.defaultCursor()
    """
    if cursor is None: cursor = self.defaultCursor()
    cursor.execute("select name from sqlite_master where type='%s'" % kind)
    return [row[0] for row in cursor.fetchall()]

  def listTables(self, cursor=None):
    """Return a list with the name of every table recorded in sqlite_master.

    cursor -- cursor to run the query on; defaults to self.defaultCursor()
    """
    return self._listMasterNames("table", cursor)

  def listIndices(self, cursor=None):
    """Return a list with the name of every index recorded in sqlite_master.

    cursor -- cursor to run the query on; defaults to self.defaultCursor()
    """
    return self._listMasterNames("index", cursor)

  def listFieldsDict(self, table_name, cursor=None):
    """Return a dict describing the columns of ``table_name``.

    Runs ``pragma table_info(table_name)`` and maps each row's second field
    (used here as the column name) to the complete result row.

    table_name -- interpolated into the pragma statement as-is
    cursor     -- cursor to run the query on; defaults to self.defaultCursor()
    """
    if cursor is None: cursor = self.defaultCursor()
    cursor.execute("pragma table_info(%s)" % table_name)

    columns = {}
    for row in cursor.fetchall():
      columns[row[1]] = row
    return columns

  def _tableCreateStatement(self, table_name, cursor=None):
    """Return the ``sql`` text stored in sqlite_master for ``table_name``.

    Returns None when sqlite_master has no table of that name (previously
    this case failed with a TypeError on the missing row).  The leftover
    debugging print of the query has been removed.

    table_name -- interpolated into the query inside single quotes, as-is
    cursor     -- cursor to run the query on; defaults to self.defaultCursor()
    """
    if cursor is None: cursor = self.defaultCursor()
    sql = "select sql from sqlite_master where type='table' and name='%s'" % table_name
    cursor.execute(sql)
    row = cursor.fetchone()
    if row is None:
      return None
    return row[0]

  def alterTableToMatch(self, table):
    """Rebuild ``table``'s database table so it matches the application schema.

    Does nothing when ``table.checkTable(warnflag=0)`` reports no invalid
    columns on either side.  Otherwise, inside one transaction:

      1. copy the columns present in both the database table and
         ``table.getAppColumnList()`` into a temporary table,
      2. drop the original table,
      3. recreate it with ``table._createTableSQL()``,
      4. copy the saved columns back and drop the temporary table.

    If any statement after "begin transaction" fails, a rollback is attempted
    before the original error propagates, so a half-finished rebuild (for
    instance a dropped table) cannot be committed later by accident.

    table -- object providing getTableName(), checkTable(),
             getAppColumnList() and _createTableSQL()
    """
    invalidAppCols, invalidDBCols = table.checkTable(warnflag=0)
    if not invalidAppCols and not invalidDBCols:
      return

    tableName = table.getTableName()
    # The process id keeps the scratch table name distinct between processes.
    tmpTableName = tableName + "_" + str(os.getpid())

    # Only columns that exist on both sides can be carried over.
    oldcols = self.listFieldsDict(tableName)
    tmpcols = []
    for colname, coltype, options in table.getAppColumnList():
      if colname in oldcols: tmpcols.append(colname)
    tmpcolnames = ",".join(tmpcols)

    statements = [
      "create temporary table %s (%s)" % (tmpTableName, tmpcolnames),
      "insert into %s select %s from %s" % (tmpTableName, tmpcolnames, tableName),
      "drop table %s" % tableName,
      table._createTableSQL(),
      "insert into %s(%s) select %s from %s" % (tableName, tmpcolnames, tmpcolnames, tmpTableName),
      "drop table %s" % tmpTableName,
      "commit",
    ]

    cur = self.defaultCursor()
    # "begin" runs outside the guarded section: if it fails, no transaction
    # of ours exists and a rollback could discard somebody else's work.
    cur.execute("begin transaction")
    finished = 0
    try:
      for statement in statements:
        cur.execute(statement)
      finished = 1
    finally:
      if not finished:
        # Best effort only; the error that got us here is the one that
        # should reach the caller, so a failing rollback is swallowed.
        try:
          cur.execute("rollback")
        except self.SQLError:
          pass
 
 
def test():
  """Self-test hook invoked by main() for ``--test``; currently does nothing."""
  return None
 
def usage(progname):
  """Print the module docstring as the usage message.

  The docstring is %-formatted with this function's local variables, so its
  ``%(progname)s`` placeholder is replaced by ``progname``.
  """
  # vars() is evaluated before ``message`` exists, so it holds only progname.
  message = __doc__ % vars()
  print(message)
 
def main(argv, stdout, environ):
  """Command-line entry point.

  argv    -- argument vector; argv[0] is the program name shown by usage()
  stdout  -- accepted but not used in this function
  environ -- accepted but not used in this function

  Recognised long options: --help (print usage and stop), --debug (call
  debugfull()) and --test (run test() after all options are processed).
  """
  progname = argv[0]
  long_options = ["help", "test", "debug"]
  optlist, args = getopt.getopt(argv[1:], "", long_options)

  # NOTE(review): with no positional arguments the usage text is printed and
  # the options below are never looked at, so a bare ``--test`` does not run
  # the tests.  Kept as-is; confirm whether that is intended.
  if not args:
    usage(progname)
    return

  run_tests = False
  for option, _value in optlist:
    if option == "--help":
      usage(progname)
      return
    if option == "--debug":
      debugfull()
    elif option == "--test":
      run_tests = True

  if run_tests:
    test()
 
 
if __name__ == "__main__":
  # Run as a script: pass the process arguments, stdout and environment to main().
  main(argv=sys.argv, stdout=sys.stdout, environ=os.environ)