diff --git a/tools/testClient/.project b/tools/testClient/.project
new file mode 100644
index 00000000000..c6e95d9c21b
--- /dev/null
+++ b/tools/testClient/.project
@@ -0,0 +1,17 @@
+
+
+ testClient
+
+
+
+
+
+ org.python.pydev.PyDevBuilder
+
+
+
+
+
+ org.python.pydev.pythonNature
+
+
diff --git a/tools/testClient/.pydevproject b/tools/testClient/.pydevproject
new file mode 100644
index 00000000000..505117b6626
--- /dev/null
+++ b/tools/testClient/.pydevproject
@@ -0,0 +1,10 @@
+
+
+
+
+Default
+python 2.7
+
+/testClient/
+
+
diff --git a/tools/testClient/README b/tools/testClient/README
new file mode 100644
index 00000000000..f8e18e74bc0
--- /dev/null
+++ b/tools/testClient/README
@@ -0,0 +1,11 @@
+CloudStack Test Client
+1. generate Cloudstack API python code from an API XML spec file
+ python codegenerator.py -o where-to-put-the-cloudstack-api -s where-the-spec-file
+
+2. Facilities it provides:
+ 1. very handy cloudstack API python wrapper
+ 2. support async job executing in parallel
+ 3. remote ssh login/execute command
+ 4. mysql query
+
+3. sample code is under unitTest
diff --git a/tools/testClient/asyncJobMgr.py b/tools/testClient/asyncJobMgr.py
new file mode 100644
index 00000000000..5487ba32775
--- /dev/null
+++ b/tools/testClient/asyncJobMgr.py
@@ -0,0 +1,141 @@
+import threading
+import cloudstackException
+import time
+import Queue
+import copy
+import sys
+
+class job(object):
+ def __init__(self):
+ self.id = None
+ self.cmd = None
+class jobStatus(object):
+ def __init__(self):
+ self.result = None
+ self.status = None
+ self.startTime = None
+ self.endTime = None
+ self.duration = None
+class workThread(threading.Thread):
+ def __init__(self, in_queue, out_dict, apiClient, db=None):
+ threading.Thread.__init__(self)
+ self.inqueue = in_queue
+ self.output = out_dict
+ self.connection = copy.copy(apiClient.connection)
+ self.db = None
+ if db is not None:
+ self.db = copy.copy(db)
+
+ def run(self):
+ while True:
+ job = self.inqueue.get()
+ cmd = job.cmd
+ cmdName = cmd.__class__.__name__
+ responseName = cmdName.replace("Cmd", "Response")
+ responseInstance = self.connection.getclassFromName(cmd, responseName)
+ jobstatus = jobStatus()
+ jobId = None
+ try:
+ if not cmd.isAsync:
+ jobstatus.startTime = time.time()
+ result = self.connection.make_request(cmd, responseInstance)
+ jobstatus.result = result
+ jobstatus.endTime = time.time()
+ else:
+ result = self.connection.make_request(cmd, responseInstance, True)
+ jobId = self.connection.getAsyncJobId(responseInstance, result)
+ result = self.connection.pollAsyncJob(cmd, responseInstance, jobId)
+ jobstatus.result = result
+
+
+ jobstatus.status = True
+ except cloudstackException.cloudstackAPIException, e:
+ jobstatus.result = str(e)
+ jobstatus.status = False
+ except:
+ jobstatus.status = False
+ jobstatus.result = sys.exc_info()
+
+ if self.db is not None and jobId is not None:
+ result = self.db.execute("select created, last_updated from async_job where id=%s"%jobId)
+ if result is not None and len(result) > 0:
+ jobstatus.startTime = result[0][0]
+ jobstatus.endTime = result[0][1]
+ delta = jobstatus.endTime - jobstatus.startTime
+ jobstatus.duration = delta.total_seconds()
+ #print job.id
+ self.output.lock.acquire()
+ self.output.dict[job.id] = jobstatus
+ self.output.lock.release()
+ self.inqueue.task_done()
+
+class jobThread(threading.Thread):
+ def __init__(self, inqueue, interval):
+ threading.Thread.__init__(self)
+ self.inqueue = inqueue
+ self.interval = interval
+ def run(self):
+ while True:
+
+ job = self.inqueue.get()
+ try:
+ job.run()
+ except:
+ pass
+ self.inqueue.task_done()
+ time.sleep(self.interval)
+
+class outputDict(object):
+ def __init__(self):
+ self.lock = threading.Condition()
+ self.dict = {}
+
+class asyncJobMgr(object):
+ def __init__(self, apiClient, db, workers=10):
+ self.inqueue = Queue.Queue()
+ self.output = outputDict()
+ self.apiClient = apiClient
+ self.db = db
+ self.workers = workers
+
+ for i in range(self.workers):
+ worker = workThread(self.inqueue, self.output, self.apiClient, self.db)
+ worker.setDaemon(True)
+ worker.start()
+
+    def submitCmds(self, cmds):
+        if not self.inqueue.empty():
+            return False
+        id = 0
+        ids = []
+        for cmd in cmds:
+            asyncjob = job()
+            asyncjob.id = id
+            asyncjob.cmd = cmd
+            self.inqueue.put(asyncjob)
+            ids.append(id)  # record the id actually assigned to this job (was appended after the increment, off by one)
+            id += 1
+        return ids
+
+ def waitForComplete(self):
+ self.inqueue.join()
+ return self.output.dict
+
+ def submitCmdsAndWait(self, cmds):
+ self.submitCmds(cmds)
+ return self.waitForComplete()
+
+ def submitJobs(self, job, ntimes=1, nums_threads=1, interval=1):
+ inqueue1 = Queue.Queue()
+ lock = threading.Condition()
+ for i in range(ntimes):
+ newjob = copy.copy(job)
+ setattr(newjob, "apiClient", copy.copy(self.apiClient))
+ setattr(newjob, "lock", lock)
+ inqueue1.put(newjob)
+
+ for i in range(nums_threads):
+ work = jobThread(inqueue1, interval)
+ work.setDaemon(True)
+ work.start()
+ inqueue1.join()
\ No newline at end of file
diff --git a/tools/testClient/cloudstackConnection.py b/tools/testClient/cloudstackConnection.py
new file mode 100644
index 00000000000..4ab5bd9dcc4
--- /dev/null
+++ b/tools/testClient/cloudstackConnection.py
@@ -0,0 +1,244 @@
+import httplib
+import urllib
+import base64
+import hmac
+import hashlib
+import json
+import xml.dom.minidom
+import types
+import time
+import inspect
+import cloudstackException
+from cloudstackAPI import *
+
+class cloudConnection(object):
+ def __init__(self, mgtSvr, port=8096, apiKey = None, securityKey = None, asyncTimeout=3600, logging=None):
+ self.apiKey = apiKey
+ self.securityKey = securityKey
+ self.mgtSvr = mgtSvr
+ self.connection = httplib.HTTPConnection("%s:%d"%(mgtSvr,port))
+ self.port = port
+ self.logging = logging
+ if port == 8096:
+ self.auth = False
+ else:
+ self.auth = True
+
+ self.asyncTimeout = asyncTimeout
+
+ def __copy__(self):
+ return cloudConnection(self.mgtSvr, self.port, self.apiKey, self.securityKey, self.asyncTimeout, self.logging)
+
+ def make_request_with_auth(self, command, requests={}):
+ requests["command"] = command
+ requests["apiKey"] = self.apiKey
+ requests["response"] = "xml"
+ requests = zip(requests.keys(), requests.values())
+ requests.sort(key=lambda x: str.lower(x[0]))
+
+ requestUrl = "&".join(["=".join([request[0], urllib.quote_plus(str(request[1]))]) for request in requests])
+ hashStr = "&".join(["=".join([str.lower(request[0]), urllib.quote_plus(str.lower(str(request[1])))]) for request in requests])
+
+ sig = urllib.quote_plus(base64.encodestring(hmac.new(self.securityKey, hashStr, hashlib.sha1).digest()).strip())
+
+ requestUrl += "&signature=%s"%sig
+
+ self.connection.request("GET", "/client/api?%s"%requestUrl)
+ return self.connection.getresponse().read()
+
+ def make_request_without_auth(self, command, requests={}):
+ requests["command"] = command
+ requests["response"] = "xml"
+ requests = zip(requests.keys(), requests.values())
+ requestUrl = "&".join(["=".join([request[0], urllib.quote_plus(str(request[1]))]) for request in requests])
+ self.connection.request("GET", "/&%s"%requestUrl)
+ return self.connection.getresponse().read()
+
+ def getText(self, elements):
+ if len(elements) < 1:
+ return None
+ if not elements[0].hasChildNodes():
+ return None
+ if elements[0].childNodes[0].nodeValue is None:
+ return None
+ return elements[0].childNodes[0].nodeValue.strip()
+
+ def getclassFromName(self, cmd, name):
+ module = inspect.getmodule(cmd)
+ return getattr(module, name)()
+ def parseOneInstance(self, element, instance):
+ ItemsNeedToCheck = {}
+ for attribute in dir(instance):
+ if attribute != "__doc__" and attribute != "__init__" and attribute != "__module__":
+ ItemsNeedToCheck[attribute] = getattr(instance, attribute)
+ for attribute, value in ItemsNeedToCheck.items():
+ if type(value) == types.ListType:
+ subItem = []
+ for subElement in element.getElementsByTagName(attribute):
+ newInstance = self.getclassFromName(instance, attribute)
+ self.parseOneInstance(subElement, newInstance)
+ subItem.append(newInstance)
+ setattr(instance, attribute, subItem)
+ continue
+ else:
+ item = element.getElementsByTagName(attribute)
+ if len(item) == 0:
+ continue
+
+ returnValue = self.getText(item)
+ setattr(instance, attribute, returnValue)
+
+ def hasErrorCode(self, elements, responseName):
+ errorCode = elements[0].getElementsByTagName("errorcode")
+ if len(errorCode) > 0:
+ erroCodeText = self.getText(errorCode)
+ errorText = elements[0].getElementsByTagName("errortext")
+ if len(errorText) > 0:
+ errorText = self.getText(errorText)
+ errMsg = "errorCode: %s, errorText:%s"%(erroCodeText, errorText)
+ raise cloudstackException.cloudstackAPIException(responseName, errMsg)
+
+ def paraseReturnXML(self, result, response):
+ responseName = response.__class__.__name__.lower()
+ cls = response.__class__
+
+ responseLists = []
+ morethanOne = False
+
+ dom = xml.dom.minidom.parseString(result)
+ elements = dom.getElementsByTagName(responseName)
+ if len(elements) == 0:
+ return responseLists
+
+ self.hasErrorCode(elements, responseName)
+
+ count = elements[0].getElementsByTagName("count")
+ if len(count) > 0:
+ morethanOne = True
+ for node in elements[0].childNodes:
+ if node.nodeName == "count":
+ continue
+ newInstance = cls()
+ self.parseOneInstance(node, newInstance)
+ responseLists.append(newInstance)
+
+ else:
+ if elements[0].hasChildNodes():
+ newInstance = cls()
+ self.parseOneInstance(elements[0], newInstance)
+ responseLists.append(newInstance)
+
+ return responseLists, morethanOne
+
+ def paraseResultFromElement(self, elements, response):
+ responseName = response.__class__.__name__.lower()
+ cls = response.__class__
+
+ responseLists = []
+ morethanOne = False
+
+ newInstance = cls()
+ self.parseOneInstance(elements[0], newInstance)
+ responseLists.append(newInstance)
+
+ return responseLists, morethanOne
+ def getAsyncJobId(self, response, resultXml):
+ responseName = response.__class__.__name__.lower()
+ dom = xml.dom.minidom.parseString(resultXml)
+ elements = dom.getElementsByTagName(responseName)
+ if len(elements) == 0:
+ raise cloudstackException.cloudstackAPIException("can't find %s"%responseName)
+
+ self.hasErrorCode(elements, responseName)
+
+ jobIdEle = elements[0].getElementsByTagName("jobid")
+ if len(jobIdEle) == 0:
+ errMsg = 'can not find jobId in the result:%s'%resultXml
+
+ raise cloudstackException.cloudstackAPIException(errMsg)
+
+ jobId = self.getText(jobIdEle)
+ return jobId
+
+ def pollAsyncJob(self, cmd, response, jobId):
+ commandName = cmd.__class__.__name__.replace("Cmd", "")
+ cmd = queryAsyncJobResult.queryAsyncJobResultCmd()
+ cmd.jobid = jobId
+
+ while self.asyncTimeout > 0:
+ asyncResponse = queryAsyncJobResult.queryAsyncJobResultResponse()
+ responseName = asyncResponse.__class__.__name__.lower()
+ asyncResponseXml = self.make_request(cmd, asyncResponse, True)
+ dom = xml.dom.minidom.parseString(asyncResponseXml)
+ elements = dom.getElementsByTagName(responseName)
+ if len(elements) == 0:
+ raise cloudstackException.cloudstackAPIException("can't find %s"%responseName)
+
+ self.hasErrorCode(elements, responseName)
+
+ jobstatus = self.getText(elements[0].getElementsByTagName("jobstatus"))
+
+ if jobstatus == "2":
+ jobResult = self.getText(elements[0].getElementsByTagName("jobresult"))
+ raise cloudstackException.cloudstackAPIException(commandName, jobResult)
+ elif jobstatus == "1":
+ jobResultEle = elements[0].getElementsByTagName("jobresult")
+
+ return self.paraseResultFromElement(jobResultEle, response)
+
+ time.sleep(5)
+
+ raise cloudstackException.cloudstackAPIException(commandName, "Async job timeout")
+ def make_request(self, cmd, response, raw=False):
+ commandName = cmd.__class__.__name__.replace("Cmd", "")
+ isAsync = "false"
+ requests = {}
+ required = []
+ for attribute in dir(cmd):
+ if attribute != "__doc__" and attribute != "__init__" and attribute != "__module__":
+ if attribute == "isAsync":
+ isAsync = getattr(cmd, attribute)
+ elif attribute == "required":
+ required = getattr(cmd, attribute)
+ else:
+ requests[attribute] = getattr(cmd, attribute)
+
+ for requiredPara in required:
+ if requests[requiredPara] is None:
+ raise cloudstackException.cloudstackAPIException(commandName, "%s is required"%requiredPara)
+ '''remove none value'''
+ for param, value in requests.items():
+ if value is None:
+ requests.pop(param)
+ if self.logging is not None:
+ self.logging.debug("sending command: " + str(requests))
+ result = None
+ if self.auth:
+ result = self.make_request_with_auth(commandName, requests)
+ else:
+ result = self.make_request_without_auth(commandName, requests)
+
+ if self.logging is not None:
+ self.logging.debug("got result: " + result)
+ if result is None:
+ return None
+
+ if raw:
+ return result
+ if isAsync == "false":
+ result,num = self.paraseReturnXML(result, response)
+ else:
+ jobId = self.getAsyncJobId(response, result)
+ result,num = self.pollAsyncJob(cmd, response, jobId)
+ if num:
+ return result
+ else:
+ if len(result) != 0:
+ return result[0]
+ return None
+
+if __name__ == '__main__':
+ xml = '407i-1-407-RS3i-1-407-RS3system1ROOT2011-07-30T14:45:19-0700Runningfalse1CA13kvm-50-2054CentOS 5.5(64-bit) no GUI (KVM)CentOS 5.5(64-bit) no GUI (KVM)false1Small Instance15005121120NetworkFilesystem380203255.255.255.065.19.181.165.19.181.110vlan://65vlan://65GuestDirecttrue06:52:da:00:00:08KVM'
+ conn = cloudConnection(None)
+
+ print conn.paraseReturnXML(xml, deployVirtualMachine.deployVirtualMachineResponse())
\ No newline at end of file
diff --git a/tools/testClient/cloudstackException.py b/tools/testClient/cloudstackException.py
new file mode 100644
index 00000000000..f731be383c7
--- /dev/null
+++ b/tools/testClient/cloudstackException.py
@@ -0,0 +1,24 @@
+
+class cloudstackAPIException(Exception):
+ def __init__(self, cmd = "", result = ""):
+ self.errorMsg = "Execute cmd: %s failed, due to: %s"%(cmd, result)
+ def __str__(self):
+ return self.errorMsg
+
+class InvalidParameterException(Exception):
+ def __init__(self, msg=''):
+ self.errorMsg = msg
+ def __str__(self):
+ return self.errorMsg
+
+class dbException(Exception):
+ def __init__(self, msg=''):
+ self.errorMsg = msg
+ def __str__(self):
+ return self.errorMsg
+
+class internalError(Exception):
+ def __init__(self, msg=''):
+ self.errorMsg = msg
+ def __str__(self):
+ return self.errorMsg
\ No newline at end of file
diff --git a/tools/testClient/cloudstackTestClient.py b/tools/testClient/cloudstackTestClient.py
new file mode 100644
index 00000000000..2ac5a1bb89b
--- /dev/null
+++ b/tools/testClient/cloudstackTestClient.py
@@ -0,0 +1,57 @@
+import cloudstackConnection
+import remoteSSHClient
+import asyncJobMgr
+import dbConnection
+from cloudstackAPI import *
+
+class cloudstackTestClient(object):
+ def __init__(self, mgtSvr, port=8096, apiKey = None, securityKey = None, asyncTimeout=3600, defaultWorkerThreads=10, logging=None):
+ self.connection = cloudstackConnection.cloudConnection(mgtSvr, port, apiKey, securityKey, asyncTimeout, logging)
+ self.apiClient = cloudstackAPIClient.CloudStackAPIClient(self.connection)
+ self.dbConnection = None
+ self.asyncJobMgr = None
+ self.ssh = None
+ self.defaultWorkerThreads = defaultWorkerThreads
+
+
+ def dbConfigure(self, host="localhost", port=3306, user='cloud', passwd='cloud', db='cloud'):
+ self.dbConnection = dbConnection.dbConnection(host, port, user, passwd, db)
+ self.asyncJobMgr = asyncJobMgr.asyncJobMgr(self.apiClient, self.dbConnection)
+
+ def getDbConnection(self):
+ return self.dbConnection
+
+ def remoteSSHConfigure(self, host, port, user, passwd):
+ self.ssh = remoteSSHClient.remoteSSHClient(host, port, user, passwd)
+
+ def executeViaSSH(self, command):
+ if self.ssh is None:
+ return None
+ return self.ssh.execute(command)
+
+ def getSSHClient(self):
+ return self.ssh
+
+    def executeSql(self, sql=None):
+        if sql is None or self.dbConnection is None:
+            return None
+        # forward the query; the original called execute() without the sql argument, so it always returned None
+        return self.dbConnection.execute(sql)
+
+ def executeSqlFromFile(self, sqlFile=None):
+ if sqlFile is None or self.dbConnection is None:
+ return None
+ return self.dbConnection.executeSqlFromFile(sqlFile)
+
+ def getApiClient(self):
+ return self.apiClient
+
+ def submitCmdsAndWait(self, cmds):
+ if self.asyncJobMgr is None:
+ return None
+ return self.asyncJobMgr.submitCmdsAndWait(cmds)
+
+ def submitJobs(self, job, ntimes=1, nums_threads=10, interval=1):
+ if self.asyncJobMgr is None:
+ return None
+ self.asyncJobMgr.submitJobs(job, ntimes, nums_threads, interval)
\ No newline at end of file
diff --git a/tools/testClient/codegenerator.py b/tools/testClient/codegenerator.py
new file mode 100644
index 00000000000..98e3eda123c
--- /dev/null
+++ b/tools/testClient/codegenerator.py
@@ -0,0 +1,264 @@
+import xml.dom.minidom
+from optparse import OptionParser
+import os
+import sys
+class cmdParameterProperty(object):
+ def __init__(self):
+ self.name = None
+ self.required = False
+ self.desc = ""
+ self.subProperties = []
+
+class cloudStackCmd:
+ def __init__(self):
+ self.name = ""
+ self.desc = ""
+ self.async = "false"
+ self.request = []
+ self.response = []
+
+class codeGenerator:
+ space = " "
+
+ cmdsName = []
+
+ def __init__(self, outputFolder, apiSpecFile):
+ self.cmd = None
+ self.code = ""
+ self.required = []
+ self.subclass = []
+ self.outputFolder = outputFolder
+ self.apiSpecFile = apiSpecFile
+
+ def addAttribute(self, attr, pro):
+ value = pro.value
+ if pro.required:
+ self.required.append(attr)
+ desc = pro.desc
+ if desc is not None:
+ self.code += self.space
+ self.code += "''' " + pro.desc + " '''"
+ self.code += "\n"
+
+ self.code += self.space
+ self.code += attr + " = " + str(value)
+ self.code += "\n"
+
+ def generateSubClass(self, name, properties):
+ '''generate code for sub list'''
+ subclass = 'class %s:\n'%name
+ subclass += self.space + "def __init__(self):\n"
+ for pro in properties:
+ if pro.desc is not None:
+ subclass += self.space + self.space + '""""%s"""\n'%pro.desc
+ if len (pro.subProperties) > 0:
+ subclass += self.space + self.space + 'self.%s = []\n'%pro.name
+ self.generateSubClass(pro.name, pro.subProperties)
+ else:
+ subclass += self.space + self.space + 'self.%s = None\n'%pro.name
+
+ self.subclass.append(subclass)
+ def generate(self, cmd):
+
+ self.cmd = cmd
+ self.cmdsName.append(self.cmd.name)
+ self.code += "\n"
+ self.code += '"""%s"""\n'%self.cmd.desc
+ self.code += 'from baseCmd import *\n'
+ self.code += 'from baseResponse import *\n'
+ self.code += "class %sCmd (baseCmd):\n"%self.cmd.name
+ self.code += self.space + "def __init__(self):\n"
+
+ self.code += self.space + self.space + 'self.isAsync = "%s"\n' %self.cmd.async
+
+ for req in self.cmd.request:
+ if req.desc is not None:
+ self.code += self.space + self.space + '"""%s"""\n'%req.desc
+ if req.required == "true":
+ self.code += self.space + self.space + '"""Required"""\n'
+ self.code += self.space + self.space + 'self.%s = None\n'%req.name
+ if req.required == "true":
+ self.required.append(req.name)
+
+ self.code += self.space + self.space + "self.required = ["
+ for require in self.required:
+ self.code += '"' + require + '",'
+ self.code += "]\n"
+ self.required = []
+
+
+ """generate response code"""
+ subItems = {}
+ self.code += "\n"
+ self.code += 'class %sResponse (baseResponse):\n'%self.cmd.name
+ self.code += self.space + "def __init__(self):\n"
+ for res in self.cmd.response:
+ if res.desc is not None:
+ self.code += self.space + self.space + '"""%s"""\n'%res.desc
+
+ if len(res.subProperties) > 0:
+ self.code += self.space + self.space + 'self.%s = []\n'%res.name
+ self.generateSubClass(res.name, res.subProperties)
+ else:
+ self.code += self.space + self.space + 'self.%s = None\n'%res.name
+ self.code += '\n'
+
+ for subclass in self.subclass:
+ self.code += subclass + "\n"
+
+ fp = open(self.outputFolder + "/cloudstackAPI/%s.py"%self.cmd.name, "w")
+ fp.write(self.code)
+ fp.close()
+ self.code = ""
+ self.subclass = []
+
+
+ def finalize(self):
+ '''generate an api call'''
+
+ header = '"""Test Client for CloudStack API"""\n'
+ imports = "import copy\n"
+ initCmdsList = '__all__ = ['
+ body = ''
+ body += "class CloudStackAPIClient:\n"
+ body += self.space + 'def __init__(self, connection):\n'
+ body += self.space + self.space + 'self.connection = connection\n'
+ body += "\n"
+
+ body += self.space + 'def __copy__(self):\n'
+ body += self.space + self.space + 'return CloudStackAPIClient(copy.copy(self.connection))\n'
+ body += "\n"
+
+ for cmdName in self.cmdsName:
+ body += self.space + 'def %s(self,command):\n'%cmdName
+ body += self.space + self.space + 'response = %sResponse()\n'%cmdName
+ body += self.space + self.space + 'response = self.connection.make_request(command, response)\n'
+ body += self.space + self.space + 'return response\n'
+ body += '\n'
+
+ imports += 'from %s import %sResponse\n'%(cmdName, cmdName)
+ initCmdsList += '"%s",'%cmdName
+
+ fp = open(self.outputFolder + '/cloudstackAPI/cloudstackAPIClient.py', 'w')
+ for item in [header, imports, body]:
+ fp.write(item)
+ fp.close()
+
+ '''generate __init__.py'''
+ initCmdsList += '"cloudstackAPIClient"]'
+ fp = open(self.outputFolder + '/cloudstackAPI/__init__.py', 'w')
+ fp.write(initCmdsList)
+ fp.close()
+
+ fp = open(self.outputFolder + '/cloudstackAPI/baseCmd.py', 'w')
+ basecmd = '"""Base Command"""\n'
+ basecmd += 'class baseCmd:\n'
+ basecmd += self.space + 'pass\n'
+ fp.write(basecmd)
+ fp.close()
+
+ fp = open(self.outputFolder + '/cloudstackAPI/baseResponse.py', 'w')
+ basecmd = '"""Base class for response"""\n'
+ basecmd += 'class baseResponse:\n'
+ basecmd += self.space + 'pass\n'
+ fp.write(basecmd)
+ fp.close()
+
+
+ def constructResponse(self, response):
+ paramProperty = cmdParameterProperty()
+ paramProperty.name = getText(response.getElementsByTagName('name'))
+ paramProperty.desc = getText(response.getElementsByTagName('description'))
+ if paramProperty.name.find('(*)') != -1:
+ '''This is a list'''
+ paramProperty.name = paramProperty.name.split('(*)')[0]
+ for subresponse in response.getElementsByTagName('arguments')[0].getElementsByTagName('arg'):
+ subProperty = self.constructResponse(subresponse)
+ paramProperty.subProperties.append(subProperty)
+ return paramProperty
+
+ def loadCmdFromXML(self):
+ dom = xml.dom.minidom.parse(self.apiSpecFile)
+ cmds = []
+ for cmd in dom.getElementsByTagName("command"):
+ csCmd = cloudStackCmd()
+ csCmd.name = getText(cmd.getElementsByTagName('name'))
+ assert csCmd.name
+
+ desc = getText(cmd.getElementsByTagName('description'))
+ if desc:
+ csCmd.desc = desc
+
+ async = getText(cmd.getElementsByTagName('isAsync'))
+ if async:
+ csCmd.async = async
+
+ for param in cmd.getElementsByTagName("request")[0].getElementsByTagName("arg"):
+ paramProperty = cmdParameterProperty()
+
+ paramProperty.name = getText(param.getElementsByTagName('name'))
+ assert paramProperty.name
+
+ required = param.getElementsByTagName('required')
+ if required:
+ paramProperty.required = getText(required)
+
+ requestDescription = param.getElementsByTagName('description')
+ if requestDescription:
+ paramProperty.desc = getText(requestDescription)
+
+ csCmd.request.append(paramProperty)
+
+ responseEle = cmd.getElementsByTagName("response")[0]
+ for response in responseEle.getElementsByTagName("arg"):
+ if response.parentNode != responseEle:
+ continue
+
+ paramProperty = self.constructResponse(response)
+ csCmd.response.append(paramProperty)
+
+ cmds.append(csCmd)
+ return cmds
+
+ def generateCode(self):
+ cmds = self.loadCmdFromXML()
+ for cmd in cmds:
+ self.generate(cmd)
+ self.finalize()
+
+def getText(elements):
+ return elements[0].childNodes[0].nodeValue.strip()
+
+if __name__ == "__main__":
+ parser = OptionParser()
+
+ parser.add_option("-o", "--output", dest="output", help="the root path where code genereted, default is .")
+ parser.add_option("-s", "--specfile", dest="spec", help="the path and name of the api spec xml file, default is /etc/cloud/cli/commands.xml")
+
+ (options, args) = parser.parse_args()
+
+ apiSpecFile = "/etc/cloud/cli/commands.xml"
+ if options.spec is not None:
+ apiSpecFile = options.spec
+
+    if not os.path.exists(apiSpecFile):
+        print "the spec file %s does not exist"%apiSpecFile
+        parser.print_help()  # print_help() writes to stdout itself; wrapping it in print emitted an extra "None"
+        exit(1)
+
+
+ folder = "."
+ if options.output is not None:
+ folder = options.output
+ apiModule=folder + "/cloudstackAPI"
+ if not os.path.exists(apiModule):
+ try:
+ os.mkdir(apiModule)
+        except:
+            print "Failed to create folder %s, due to %s"%(apiModule,sys.exc_info())
+            parser.print_help()  # print_help() writes to stdout itself; wrapping it in print emitted an extra "None"
+            exit(2)
+
+ cg = codeGenerator(folder, apiSpecFile)
+ cg.generateCode()
+
diff --git a/tools/testClient/dbConnection.py b/tools/testClient/dbConnection.py
new file mode 100644
index 00000000000..18603ccc8c8
--- /dev/null
+++ b/tools/testClient/dbConnection.py
@@ -0,0 +1,73 @@
+import pymysql
+import cloudstackException
+import sys
+import os
+class dbConnection(object):
+ def __init__(self, host="localhost", port=3306, user='cloud', passwd='cloud', db='cloud'):
+ self.host = host
+ self.port = port
+ self.user = user
+ self.passwd = passwd
+ self.database = db
+
+ try:
+ self.db = pymysql.Connect(host=host, port=port, user=user, passwd=passwd, db=db)
+ except:
+ raise cloudstackException.InvalidParameterException(sys.exc_value)
+
+ def __copy__(self):
+ return dbConnection(self.host, self.port, self.user, self.passwd, self.database)
+
+ def execute(self, sql=None):
+ if sql is None:
+ return None
+
+ resultRow = []
+ cursor = None
+ try:
+ cursor = self.db.cursor()
+ cursor.execute(sql)
+
+ result = cursor.fetchall()
+ if result is not None:
+ for r in result:
+ resultRow.append(r)
+ return resultRow
+ except pymysql.MySQLError, e:
+ if cursor is not None:
+ cursor.close()
+ raise cloudstackException.dbException("db Exception:%s"%e[1])
+ except:
+ if cursor is not None:
+ cursor.close()
+ raise cloudstackException.internalError(sys.exc_value)
+
+
+    def executeSqlFromFile(self, fileName=None):
+        if fileName is None:
+            raise cloudstackException.InvalidParameterException("file name can't be None")
+
+        if not os.path.exists(fileName):
+            raise cloudstackException.InvalidParameterException("%s does not exist"%fileName)
+
+        sqls = open(fileName, "r").read()
+        return self.execute(sqls)
+
+if __name__ == "__main__":
+ db = dbConnection()
+ '''
+ try:
+
+ result = db.executeSqlFromFile("/tmp/server-setup.sql")
+ if result is not None:
+ for r in result:
+ print r[0], r[1]
+ except cloudstackException.dbException, e:
+ print e
+ '''
+ print db.execute("update vm_template set name='fjkd' where id=200")
+ result = db.execute("select created,last_updated from async_job where id=13")
+ print result
+ delta = result[0][1] - result[0][0]
+ print delta.total_seconds()
+ #print db.execute("update vm_template set name='fjkd' where id=200")
\ No newline at end of file
diff --git a/tools/testClient/pymysql/__init__.py b/tools/testClient/pymysql/__init__.py
new file mode 100644
index 00000000000..903107e539a
--- /dev/null
+++ b/tools/testClient/pymysql/__init__.py
@@ -0,0 +1,131 @@
+'''
+PyMySQL: A pure-Python drop-in replacement for MySQLdb.
+
+Copyright (c) 2010 PyMySQL contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+'''
+
+VERSION = (0, 4, None)
+
+from constants import FIELD_TYPE
+from converters import escape_dict, escape_sequence, escape_string
+from err import Warning, Error, InterfaceError, DataError, \
+ DatabaseError, OperationalError, IntegrityError, InternalError, \
+ NotSupportedError, ProgrammingError, MySQLError
+from times import Date, Time, Timestamp, \
+ DateFromTicks, TimeFromTicks, TimestampFromTicks
+
+import sys
+
+try:
+ frozenset
+except NameError:
+ from sets import ImmutableSet as frozenset
+ try:
+ from sets import BaseSet as set
+ except ImportError:
+ from sets import Set as set
+
+threadsafety = 1
+apilevel = "2.0"
+paramstyle = "format"
+
+class DBAPISet(frozenset):
+
+
+    def __ne__(self, other):
+        if isinstance(other, set):
+            return super(DBAPISet, self).__ne__(other)  # bound super method already supplies self; passing it again raised TypeError
+        else:
+            return other not in self
+
+ def __eq__(self, other):
+ if isinstance(other, frozenset):
+ return frozenset.__eq__(self, other)
+ else:
+ return other in self
+
+ def __hash__(self):
+ return frozenset.__hash__(self)
+
+
+STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING,
+ FIELD_TYPE.VAR_STRING])
+BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB,
+ FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB])
+NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT,
+ FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG,
+ FIELD_TYPE.TINY, FIELD_TYPE.YEAR])
+DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE])
+TIME = DBAPISet([FIELD_TYPE.TIME])
+TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME])
+DATETIME = TIMESTAMP
+ROWID = DBAPISet()
+
+def Binary(x):
+ """Return x as a binary type."""
+ return str(x)
+
+def Connect(*args, **kwargs):
+ """
+ Connect to the database; see connections.Connection.__init__() for
+ more information.
+ """
+ from connections import Connection
+ return Connection(*args, **kwargs)
+
+def get_client_info(): # for MySQLdb compatibility
+ return '%s.%s.%s' % VERSION
+
+connect = Connection = Connect
+
+# we include a doctored version_info here for MySQLdb compatibility
+version_info = (1,2,2,"final",0)
+
+NULL = "NULL"
+
+__version__ = get_client_info()
+
+def thread_safe():
+ return True # match MySQLdb.thread_safe()
+
+def install_as_MySQLdb():
+ """
+ After this function is called, any application that imports MySQLdb or
+ _mysql will unwittingly actually use
+ """
+ sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]
+
+__all__ = [
+ 'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
+ 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks',
+ 'DataError', 'DatabaseError', 'Error', 'FIELD_TYPE', 'IntegrityError',
+ 'InterfaceError', 'InternalError', 'MySQLError', 'NULL', 'NUMBER',
+ 'NotSupportedError', 'DBAPISet', 'OperationalError', 'ProgrammingError',
+ 'ROWID', 'STRING', 'TIME', 'TIMESTAMP', 'Warning', 'apilevel', 'connect',
+ 'connections', 'constants', 'converters', 'cursors',
+ 'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
+ 'paramstyle', 'threadsafety', 'version_info',
+
+ "install_as_MySQLdb",
+
+ "NULL","__version__",
+ ]
diff --git a/tools/testClient/pymysql/charset.py b/tools/testClient/pymysql/charset.py
new file mode 100644
index 00000000000..10a91bd19f2
--- /dev/null
+++ b/tools/testClient/pymysql/charset.py
@@ -0,0 +1,174 @@
# Maximum bytes per character for the charset ids this driver special-cases,
# keyed by numeric charset id (per the _charsets table below:
# 8=latin1, 33=utf8, 88=sjis_bin, 91=ujis_bin).
MBLENGTH = {
        8:1,
        33:3,
        88:2,
        91:2
        }
+
class Charset:
    """Metadata record for one MySQL character set / collation pair."""

    def __init__(self, id, name, collation, is_default):
        self.id = id
        self.name = name
        self.collation = collation
        # information_schema reports the literal string 'Yes' for the
        # default collation of a character set; anything else is False.
        self.is_default = (is_default == 'Yes')
+
class Charsets:
    """Registry of Charset objects, indexed by their numeric id."""

    def __init__(self):
        self._by_id = {}

    def add(self, c):
        """Register a Charset under its id (later adds overwrite)."""
        self._by_id[c.id] = c

    def by_id(self, id):
        """Look up a Charset by numeric id; raises KeyError if unknown."""
        return self._by_id[id]

    def by_name(self, name):
        """Return the default-collation Charset for 'name', or None."""
        return next(
            (cs for cs in self._by_id.values()
             if cs.name == name and cs.is_default),
            None,
        )
+
# Module-level registry of every charset/collation pair known to the driver,
# generated from information_schema (see the note below).
_charsets = Charsets()
"""
Generated with:

mysql -N -s -e "select id, character_set_name, collation_name, is_default
from information_schema.collations order by id;" | python -c "import sys
for l in sys.stdin.readlines():
    id, name, collation, is_default = l.split(chr(9))
    print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \
        % (id, name, collation, is_default.strip())
"

"""
_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes'))
_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', ''))
_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes'))
_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes'))
_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', ''))
_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes'))
_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes'))
_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes'))
_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes'))
_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes'))
_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes'))
_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes'))
_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes'))
_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', ''))
_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', ''))
_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes'))
_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes'))
_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes'))
_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', ''))
_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', ''))
_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes'))
_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', ''))
_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes'))
_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes'))
_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes'))
_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', ''))
_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes'))
_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', ''))
_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes'))
_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', ''))
_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes'))
_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes'))
_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', ''))
_charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes'))
_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes'))
_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes'))
_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes'))
_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes'))
_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes'))
_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes'))
_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', ''))
_charsets.add(Charset(43, 'macce', 'macce_bin', ''))
_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', ''))
_charsets.add(Charset(47, 'latin1', 'latin1_bin', ''))
_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', ''))
_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', ''))
_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', ''))
_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes'))
_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', ''))
_charsets.add(Charset(53, 'macroman', 'macroman_bin', ''))
_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes'))
_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', ''))
_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes'))
_charsets.add(Charset(63, 'binary', 'binary', 'Yes'))
_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', ''))
_charsets.add(Charset(65, 'ascii', 'ascii_bin', ''))
_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', ''))
_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', ''))
_charsets.add(Charset(68, 'cp866', 'cp866_bin', ''))
_charsets.add(Charset(69, 'dec8', 'dec8_bin', ''))
_charsets.add(Charset(70, 'greek', 'greek_bin', ''))
_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', ''))
_charsets.add(Charset(72, 'hp8', 'hp8_bin', ''))
_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', ''))
_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', ''))
_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', ''))
_charsets.add(Charset(77, 'latin2', 'latin2_bin', ''))
_charsets.add(Charset(78, 'latin5', 'latin5_bin', ''))
_charsets.add(Charset(79, 'latin7', 'latin7_bin', ''))
_charsets.add(Charset(80, 'cp850', 'cp850_bin', ''))
_charsets.add(Charset(81, 'cp852', 'cp852_bin', ''))
_charsets.add(Charset(82, 'swe7', 'swe7_bin', ''))
_charsets.add(Charset(83, 'utf8', 'utf8_bin', ''))
_charsets.add(Charset(84, 'big5', 'big5_bin', ''))
_charsets.add(Charset(85, 'euckr', 'euckr_bin', ''))
_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', ''))
_charsets.add(Charset(87, 'gbk', 'gbk_bin', ''))
_charsets.add(Charset(88, 'sjis', 'sjis_bin', ''))
_charsets.add(Charset(89, 'tis620', 'tis620_bin', ''))
_charsets.add(Charset(90, 'ucs2', 'ucs2_bin', ''))
_charsets.add(Charset(91, 'ujis', 'ujis_bin', ''))
_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes'))
_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', ''))
_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', ''))
_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes'))
_charsets.add(Charset(96, 'cp932', 'cp932_bin', ''))
_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes'))
_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', ''))
_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', ''))
_charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', ''))
_charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', ''))
_charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', ''))
_charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', ''))
_charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', ''))
_charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', ''))
_charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', ''))
_charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', ''))
_charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', ''))
_charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', ''))
_charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', ''))
_charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', ''))
_charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', ''))
_charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', ''))
_charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', ''))
_charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', ''))
_charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', ''))
_charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', ''))
_charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', ''))
_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', ''))
_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', ''))
_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', ''))
_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', ''))
_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', ''))
_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', ''))
_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', ''))
_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', ''))
_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', ''))
_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', ''))
_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', ''))
_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', ''))
_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', ''))
_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', ''))
_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', ''))
_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', ''))
_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', ''))
_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', ''))
_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', ''))
+
def charset_by_name(name):
    """Return the default-collation Charset named 'name', or None."""
    return _charsets.by_name(name)
+
def charset_by_id(id):
    """Return the Charset with numeric id 'id' (KeyError if unknown)."""
    return _charsets.by_id(id)
+
diff --git a/tools/testClient/pymysql/connections.py b/tools/testClient/pymysql/connections.py
new file mode 100644
index 00000000000..8897644ab09
--- /dev/null
+++ b/tools/testClient/pymysql/connections.py
@@ -0,0 +1,928 @@
+# Python implementation of the MySQL client-server protocol
+# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
+
+try:
+ import hashlib
+ sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs)
+except ImportError:
+ import sha
+ sha_new = sha.new
+
+import socket
+try:
+ import ssl
+ SSL_ENABLED = True
+except ImportError:
+ SSL_ENABLED = False
+
+import struct
+import sys
+import os
+import ConfigParser
+
+try:
+ import cStringIO as StringIO
+except ImportError:
+ import StringIO
+
+from charset import MBLENGTH, charset_by_name, charset_by_id
+from cursors import Cursor
+from constants import FIELD_TYPE, FLAG
+from constants import SERVER_STATUS
+from constants.CLIENT import *
+from constants.COMMAND import *
+from util import join_bytes, byte2int, int2byte
+from converters import escape_item, encoders, decoders
+from err import raise_mysql_exception, Warning, Error, \
+ InterfaceError, DataError, DatabaseError, OperationalError, \
+ IntegrityError, InternalError, NotSupportedError, ProgrammingError
+
+DEBUG = False
+
+NULL_COLUMN = 251
+UNSIGNED_CHAR_COLUMN = 251
+UNSIGNED_SHORT_COLUMN = 252
+UNSIGNED_INT24_COLUMN = 253
+UNSIGNED_INT64_COLUMN = 254
+UNSIGNED_CHAR_LENGTH = 1
+UNSIGNED_SHORT_LENGTH = 2
+UNSIGNED_INT24_LENGTH = 3
+UNSIGNED_INT64_LENGTH = 8
+
+DEFAULT_CHARSET = 'latin1'
+
+
def dump_packet(data):
    """Debug aid: print a hex + ASCII dump of a raw packet and the
    names of the five enclosing stack frames that led here."""

    def is_ascii(data):
        # Render bytes in the 65..122 range as themselves, others as '.'.
        if byte2int(data) >= 65 and byte2int(data) <= 122: #data.isalnum():
            return data
        return '.'
    print "packet length %d" % len(data)
    print "method call[1]: %s" % sys._getframe(1).f_code.co_name
    print "method call[2]: %s" % sys._getframe(2).f_code.co_name
    print "method call[3]: %s" % sys._getframe(3).f_code.co_name
    print "method call[4]: %s" % sys._getframe(4).f_code.co_name
    print "method call[5]: %s" % sys._getframe(5).f_code.co_name
    print "-" * 88
    # Split the payload into 16-byte rows for display.
    dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0]
    for d in dump_data:
        print ' '.join(map(lambda x:"%02X" % byte2int(x), d)) + \
              ' ' * (16 - len(d)) + ' ' * 2 + \
              ' '.join(map(lambda x:"%s" % is_ascii(x), d))
    print "-" * 88
    print ""
+
def _scramble(password, message):
    """MySQL 4.1+ challenge-response scramble (SHA1 based).

    Returns a single NUL byte when no password is set, otherwise
    SHA1(message + SHA1(SHA1(password))) XORed with SHA1(password).
    """
    if password == None or len(password) == 0:
        return int2byte(0)
    if DEBUG: print 'password=' + password
    stage1 = sha_new(password).digest()
    stage2 = sha_new(stage1).digest()
    s = sha_new()
    s.update(message)
    s.update(stage2)
    result = s.digest()
    return _my_crypt(result, stage1)
+
+def _my_crypt(message1, message2):
+ length = len(message1)
+ result = struct.pack('B', length)
+ for i in xrange(length):
+ x = (struct.unpack('B', message1[i:i+1])[0] ^ \
+ struct.unpack('B', message2[i:i+1])[0])
+ result += struct.pack('B', x)
+ return result
+
+# old_passwords support ported from libmysql/password.c
+SCRAMBLE_LENGTH_323 = 8
+
class RandStruct_323(object):
    """Pseudo-random generator for the pre-4.1 password scramble.

    Port of libmysql/password.c's rand_st. The Python-2-only 'L' long
    literals were removed: plain ints auto-promote in Python 2 and the
    class stays importable on Python 3; the numeric results are identical.
    """

    def __init__(self, seed1, seed2):
        self.max_value = 0x3FFFFFFF
        self.seed1 = seed1 % self.max_value
        self.seed2 = seed2 % self.max_value

    def my_rnd(self):
        """Advance both seeds and return a float in [0, 1)."""
        self.seed1 = (self.seed1 * 3 + self.seed2) % self.max_value
        self.seed2 = (self.seed1 + self.seed2 + 33) % self.max_value
        return float(self.seed1) / float(self.max_value)
+
def _scramble_323(password, message):
    """Old (pre-4.1) password scramble, ported from libmysql/password.c.

    Generates SCRAMBLE_LENGTH_323 pseudo-random bytes seeded from the
    password and server challenge, then XORs them with one extra byte.
    """
    hash_pass = _hash_password_323(password)
    hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323])
    hash_pass_n = struct.unpack(">LL", hash_pass)
    hash_message_n = struct.unpack(">LL", hash_message)

    rand_st = RandStruct_323(hash_pass_n[0] ^ hash_message_n[0],
                             hash_pass_n[1] ^ hash_message_n[1])
    outbuf = StringIO.StringIO()
    for _ in xrange(min(SCRAMBLE_LENGTH_323, len(message))):
        outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64))
    extra = int2byte(int(rand_st.my_rnd() * 31))
    out = outbuf.getvalue()
    outbuf = StringIO.StringIO()
    for c in out:
        outbuf.write(int2byte(byte2int(c) ^ byte2int(extra)))
    return outbuf.getvalue()
+
def _hash_password_323(password):
    """Hash a password with the pre-4.1 algorithm (libmysql/password.c).

    Spaces and tabs in the password are skipped, matching the C code.
    Returns the two 31-bit accumulators packed big-endian into 8 bytes.
    """
    nr = 1345345333L
    add = 7L
    nr2 = 0x12345671L

    for c in [byte2int(x) for x in password if x not in (' ', '\t')]:
        nr^= (((nr & 63)+add)*c)+ (nr << 8) & 0xFFFFFFFF
        nr2= (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
        add= (add + c) & 0xFFFFFFFF

    r1 = nr & ((1L << 31) - 1L) # kill sign bits
    r2 = nr2 & ((1L << 31) - 1L)

    # pack
    return struct.pack(">LL", r1, r2)
+
def pack_int24(n):
    """Serialize n as a 3-byte little-endian unsigned integer."""
    # Pack the low 24 bits as a 4-byte little-endian int and drop the
    # high byte -- equivalent to masking out each byte individually.
    return struct.pack('<I', n & 0xFFFFFF)[:3]
+
+def unpack_uint16(n):
+ return struct.unpack(' len(self.__data):
+ raise Exception('Invalid advance amount (%s) for cursor. '
+ 'Position=%s' % (length, new_position))
+ self.__position = new_position
+
+ def rewind(self, position=0):
+ """Set the position of the data buffer cursor to 'position'."""
+ if position < 0 or position > len(self.__data):
+ raise Exception("Invalid position to rewind cursor to: %s." % position)
+ self.__position = position
+
    def peek(self, size):
        """Look at the first 'size' bytes in packet without moving cursor."""
        # NOTE(review): reads relative to the current cursor position,
        # not the packet start, despite the docstring's wording.
        result = self.__data[self.__position:(self.__position+size)]
        if len(result) != size:
            error = ('Result length not requested length:\n'
                     'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
                     % (size, len(result), self.__position, len(self.__data)))
            if DEBUG:
                print error
                self.dump()
            raise AssertionError(error)
        return result
+
    def get_bytes(self, position, length=1):
        """Get 'length' bytes starting at 'position'.

        Position is start of payload (first four packet header bytes are not
        included) starting at index '0'. Does not move the cursor.

        No error checking is done.  If requesting outside end of buffer
        an empty string (or string shorter than 'length') may be returned!
        """
        return self.__data[position:(position+length)]
+
    def read_length_coded_binary(self):
        """Read a 'Length Coded Binary' number from the data buffer.

        Length coded numbers can be anywhere from 1 to 9 bytes depending
        on the value of the first byte.  A leading 251 byte is the NULL
        marker and yields None.
        """
        c = byte2int(self.read(1))
        if c == NULL_COLUMN:
            return None
        if c < UNSIGNED_CHAR_COLUMN:
            # Values 0..250 are stored directly in the first byte.
            return c
        elif c == UNSIGNED_SHORT_COLUMN:
            return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH))
        elif c == UNSIGNED_INT24_COLUMN:
            return unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
        elif c == UNSIGNED_INT64_COLUMN:
            # TODO: what was 'longlong'? confirm it wasn't used?
            return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
        # NOTE(review): a first byte of 255 falls through and returns None
        # implicitly -- 255 marks an error packet, which callers presumably
        # filter out before reaching here; confirm.
+
+ def read_length_coded_string(self):
+ """Read a 'Length Coded String' from the data buffer.
+
+ A 'Length Coded String' consists first of a length coded
+ (unsigned, positive) integer represented in 1-9 bytes followed by
+ that many bytes of binary data. (For example "cat" would be "3cat".)
+ """
+ length = self.read_length_coded_binary()
+ if length is None:
+ return None
+ return self.read(length)
+
    def is_ok_packet(self):
        # OK packets begin with a 0x00 field-count byte.
        return byte2int(self.get_bytes(0)) == 0
+
    def is_eof_packet(self):
        # EOF packets begin with 0xfe (254).
        return byte2int(self.get_bytes(0)) == 254  # 'fe'
+
    def is_resultset_packet(self):
        # Result-set headers start with the column count (1..250);
        # 251+ are reserved markers (NULL/length prefixes, EOF, error).
        field_count = byte2int(self.get_bytes(0))
        return field_count >= 1 and field_count <= 250
+
    def is_error_packet(self):
        # Error packets begin with 0xff (255).
        return byte2int(self.get_bytes(0)) == 255
+
    def check_error(self):
        """Raise the mapped MySQL exception if this is an error packet."""
        if self.is_error_packet():
            self.rewind()
            self.advance(1) # field_count == error (we already know that)
            # errno is only used for the debug print; raise_mysql_exception
            # re-parses the raw payload itself.
            errno = unpack_uint16(self.read(2))
            if DEBUG: print "errno = %d" % errno
            raise_mysql_exception(self.__data)
+
    def dump(self):
        """Hex-dump this packet's payload (debug aid)."""
        dump_packet(self.__data)
+
+
+class FieldDescriptorPacket(MysqlPacket):
+ """A MysqlPacket that represents a specific column's metadata in the result.
+
+ Parsing is automatically done and the results are exported via public
+ attributes on the class such as: db, table_name, name, length, type_code.
+ """
+
    def __init__(self, *args):
        # Parse eagerly so the column metadata (db, table_name, name,
        # type_code, ...) is available as attributes right after construction.
        MysqlPacket.__init__(self, *args)
        self.__parse_field_descriptor()
+
+ def __parse_field_descriptor(self):
+ """Parse the 'Field Descriptor' (Metadata) packet.
+
+ This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
+ """
+ self.catalog = self.read_length_coded_string()
+ self.db = self.read_length_coded_string()
+ self.table_name = self.read_length_coded_string()
+ self.org_table = self.read_length_coded_string()
+ self.name = self.read_length_coded_string().decode(self.connection.charset)
+ self.org_name = self.read_length_coded_string()
+ self.advance(1) # non-null filler
+ self.charsetnr = struct.unpack(' 2:
+ use_unicode = True
+
+ if compress or named_pipe:
+ raise NotImplementedError, "compress and named_pipe arguments are not supported"
+
+ if ssl and (ssl.has_key('capath') or ssl.has_key('cipher')):
+ raise NotImplementedError, 'ssl options capath and cipher are not supported'
+
+ self.ssl = False
+ if ssl:
+ if not SSL_ENABLED:
+ raise NotImplementedError, "ssl module not found"
+ self.ssl = True
+ client_flag |= SSL
+ for k in ('key', 'cert', 'ca'):
+ v = None
+ if ssl.has_key(k):
+ v = ssl[k]
+ setattr(self, k, v)
+
+ if read_default_group and not read_default_file:
+ if sys.platform.startswith("win"):
+ read_default_file = "c:\\my.ini"
+ else:
+ read_default_file = "/etc/my.cnf"
+
+ if read_default_file:
+ if not read_default_group:
+ read_default_group = "client"
+
+ cfg = ConfigParser.RawConfigParser()
+ cfg.read(os.path.expanduser(read_default_file))
+
+ def _config(key, default):
+ try:
+ return cfg.get(read_default_group,key)
+ except:
+ return default
+
+ user = _config("user",user)
+ passwd = _config("password",passwd)
+ host = _config("host", host)
+ db = _config("db",db)
+ unix_socket = _config("socket",unix_socket)
+ port = _config("port", port)
+ charset = _config("default-character-set", charset)
+
+ self.host = host
+ self.port = port
+ self.user = user
+ self.password = passwd
+ self.db = db
+ self.unix_socket = unix_socket
+ if charset:
+ self.charset = charset
+ self.use_unicode = True
+ else:
+ self.charset = DEFAULT_CHARSET
+ self.use_unicode = False
+
+ if use_unicode is not None:
+ self.use_unicode = use_unicode
+
+ client_flag |= CAPABILITIES
+ client_flag |= MULTI_STATEMENTS
+ if self.db:
+ client_flag |= CONNECT_WITH_DB
+ self.client_flag = client_flag
+
+ self.cursorclass = cursorclass
+ self.connect_timeout = connect_timeout
+
+ self._connect()
+
+ self.messages = []
+ self.set_charset(charset)
+ self.encoders = encoders
+ self.decoders = conv
+
+ self._result = None
+ self._affected_rows = 0
+ self.host_info = "Not connected"
+
+ self.autocommit(False)
+
+ if sql_mode is not None:
+ c = self.cursor()
+ c.execute("SET sql_mode=%s", (sql_mode,))
+
+ self.commit()
+
+ if init_command is not None:
+ c = self.cursor()
+ c.execute(init_command)
+
+ self.commit()
+
+
+ def close(self):
+ ''' Send the quit message and close the socket '''
+ if self.socket is None:
+ raise Error("Already closed")
+ send_data = struct.pack('= i + 1:
+ i += 1
+
+ self.server_capabilities = struct.unpack('= i+12-1:
+ rest_salt = data[i:i+12]
+ self.salt += rest_salt
+
    def get_server_info(self):
        # Raw server version string; set during the connection handshake
        # (handshake parsing is elsewhere in this class).
        return self.server_version
+
    # DB-API 2.0 optional extension: expose the module's exception classes
    # as attributes of the connection so callers can catch e.g.
    # conn.OperationalError without importing the driver module.
    Warning = Warning
    Error = Error
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    DataError = DataError
    OperationalError = OperationalError
    IntegrityError = IntegrityError
    InternalError = InternalError
    ProgrammingError = ProgrammingError
    NotSupportedError = NotSupportedError
+
+# TODO: move OK and EOF packet parsing/logic into a proper subclass
+# of MysqlPacket like has been done with FieldDescriptorPacket.
+class MySQLResult(object):
+
    def __init__(self, connection):
        from weakref import proxy
        # Weak proxy so a lingering result object does not keep the
        # connection (and its socket) alive.
        self.connection = proxy(connection)
        self.affected_rows = None
        self.insert_id = None
        self.server_status = 0
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
+
    def read(self):
        """Read the first response packet and dispatch on its type:
        an OK packet (DML/commands) or a result-set header (queries)."""
        self.first_packet = self.connection.read_packet()

        # TODO: use classes for different packet types?
        if self.first_packet.is_ok_packet():
            self._read_ok_packet()
        else:
            self._read_result_packet()
+
+ def _read_ok_packet(self):
+ self.first_packet.advance(1) # field_count (always '0')
+ self.affected_rows = self.first_packet.read_length_coded_binary()
+ self.insert_id = self.first_packet.read_length_coded_binary()
+ self.server_status = struct.unpack(' 2
+
+try:
+ set
+except NameError:
+ try:
+ from sets import BaseSet as set
+ except ImportError:
+ from sets import Set as set
+
# Characters that must be backslash-escaped inside SQL string literals,
# and the escaped form each one maps to.
ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
              '\'': '\\\'', '"': '\\"', '\\': '\\\\'}
+
def escape_item(val, charset):
    """Escape one Python value into its SQL-literal text form.

    Containers are escaped recursively; on Python 3, byte strings are
    decoded to text first.  The encoder's result is encoded back into
    'charset' unless it is already a plain (byte) string.
    """
    if type(val) in [tuple, list, set]:
        return escape_sequence(val, charset)
    if type(val) is dict:
        return escape_dict(val, charset)
    if PYTHON3 and hasattr(val, "decode") and not isinstance(val, unicode):
        # deal with py3k bytes
        val = val.decode(charset)
    encoder = encoders[type(val)]
    val = encoder(val)
    if type(val) is str:
        return val
    val = val.encode(charset)
    return val
+
def escape_dict(val, charset):
    """Escape every value of a dict, keeping its keys; returns a new dict."""
    return dict((key, escape_item(item, charset))
                for key, item in val.items())
+
def escape_sequence(val, charset):
    """Escape each element and render the sequence as a SQL tuple '(a,b)'."""
    escaped = (escape_item(item, charset) for item in val)
    return "(" + ",".join(escaped) + ")"
+
def escape_set(val, charset):
    """Escape each member and join with commas (body of a SET literal)."""
    return ','.join(escape_item(member, charset) for member in val)
+
def escape_bool(value):
    """Render a boolean as the SQL integer literal '1' or '0'."""
    return "%d" % value
+
def escape_object(value):
    """Fallback escaper: rely on the value's str() form (no quoting)."""
    return str(value)

# Integers need no quoting or escaping, so they share the str() fallback.
escape_int = escape_long = escape_object
+
def escape_float(value):
    """Render a float with up to 15 significant digits ('%.15g' form)."""
    return format(value, '.15g')
+
def escape_string(value):
    """Quote a string for SQL, backslash-escaping the dangerous characters."""
    # ESCAPE_REGEX only ever matches keys of ESCAPE_MAP, so direct
    # indexing is safe here.
    escaped = ESCAPE_REGEX.sub(lambda match: ESCAPE_MAP[match.group(0)], value)
    return "'%s'" % escaped
+
def escape_unicode(value):
    """Unicode text is escaped exactly like byte strings."""
    return escape_string(value)
+
def escape_None(value):
    """None maps to the SQL NULL literal."""
    return 'NULL'
+
def escape_timedelta(obj):
    """Quote a datetime.timedelta as a MySQL TIME literal (HH:MM:SS).

    Days are folded into the hour count; sub-second precision is dropped.
    """
    total = int(obj.seconds)
    hours = total // 3600 % 24 + int(obj.days) * 24
    minutes = total // 60 % 60
    seconds = total % 60
    return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds))
+
def escape_time(obj):
    """Quote a datetime.time as a MySQL TIME literal.

    Fractional seconds, when present, are appended as a 6-digit field.
    """
    s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute),
                            int(obj.second))
    if obj.microsecond:
        # BUG FIX: the previous '".%f" % obj.microsecond' float-formatted
        # the integer microsecond count (e.g. 500000 -> ".500000.000000"),
        # producing an invalid literal.  Render the microseconds as a
        # zero-padded six-digit field instead.
        s += ".%06d" % obj.microsecond

    return escape_string(s)
+
def escape_datetime(obj):
    """Quote a datetime.datetime as 'YYYY-MM-DD HH:MM:SS'.

    NOTE(review): microseconds are silently dropped by this format.
    """
    return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"))
+
def escape_date(obj):
    """Quote a datetime.date as 'YYYY-MM-DD'."""
    return escape_string(obj.strftime("%Y-%m-%d"))
+
def escape_struct_time(obj):
    """Quote a time.struct_time by converting its first six fields
    (year..second) into a datetime and formatting that."""
    return escape_datetime(datetime.datetime(*obj[:6]))
+
def convert_datetime(connection, field, obj):
    """Convert a DATETIME or TIMESTAMP column value to datetime.datetime.

    Accepts 'YYYY-MM-DD HH:MM:SS' or the ISO 'T'-separated form:
    '2007-02-25 23:06:20' -> datetime.datetime(2007, 2, 25, 23, 6, 20).
    Values without a time separator, and values that fail to parse
    (e.g. '2007-02-31 ...'), fall back to convert_date(), which itself
    returns None for illegal dates such as '0000-00-00 00:00:00'.
    """
    # Bytes from the wire are decoded using the connection charset.
    if not isinstance(obj, unicode):
        obj = obj.decode(connection.charset)
    if ' ' in obj:
        sep = ' '
    elif 'T' in obj:
        sep = 'T'
    else:
        return convert_date(connection, field, obj)

    try:
        ymd, hms = obj.split(sep, 1)
        return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ])
    except ValueError:
        return convert_date(connection, field, obj)
+
def convert_timedelta(connection, field, obj):
    """Convert a TIME column ('HH:MM:SS') to datetime.timedelta.

    '25:06:17'  -> datetime.timedelta(1, 3977)
    '-25:06:17' -> datetime.timedelta(-2, 83177)
    Unparseable values return None.

    Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
    can accept values as (+|-)DD HH:MM:SS. The latter format will not
    be parsed correctly by this function.
    """
    from math import modf
    try:
        if not isinstance(obj, unicode):
            obj = obj.decode(connection.charset)
        hours, minutes, seconds = tuple([int(x) for x in obj.split(':')])
        tdelta = datetime.timedelta(
            hours = int(hours),
            minutes = int(minutes),
            seconds = int(seconds),
            # NOTE(review): 'seconds' is already an int at this point, so
            # modf() is always 0 -- fractional input raises ValueError in
            # the int() conversion above and yields None instead.
            microseconds = int(modf(float(seconds))[0]*1000000),
            )
        return tdelta
    except ValueError:
        return None
+
def convert_time(connection, field, obj):
    """Convert a TIME column rendered as 'HH:MM:SS' to datetime.time.

    Returns None for values that do not parse: negative hours, garbage,
    or the '(+|-)DD HH:MM:SS' form that MySQL accepts on input.

    Note that MySQL's TIME column corresponds more closely to Python's
    timedelta and not time; use this converter for FIELD_TYPE.TIME only
    when columns should be read as time-of-day rather than an offset.
    """
    from math import modf
    try:
        hh, mm, ss = obj.split(':')
        fractional = modf(float(ss))[0]
        return datetime.time(hour=int(hh), minute=int(mm),
                             second=int(ss),
                             microsecond=int(fractional * 1000000))
    except ValueError:
        return None
+
def convert_date(connection, field, obj):
    """Convert a DATE column ('YYYY-MM-DD') to datetime.date.

    '2007-02-26' -> datetime.date(2007, 2, 26).
    Illegal values ('2007-02-31', '0000-00-00', garbage) return None.
    """
    try:
        # Bytes from the wire are decoded using the connection charset.
        if not isinstance(obj, unicode):
            obj = obj.decode(connection.charset)
        return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
    except ValueError:
        return None
+
def convert_mysql_timestamp(connection, field, timestamp):
    """Convert a MySQL TIMESTAMP to a datetime.datetime.

    MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME
    ('2007-02-25 22:32:17'); MySQL < 4.1 uses a bare digit string
    ('20070225223217').  Both parse to
    datetime.datetime(2007, 2, 25, 22, 32, 17).
    Illegal values ('2007-02-31 22:32:17', '00000000000000') return None.
    """
    if not isinstance(timestamp, unicode):
        timestamp = timestamp.decode(connection.charset)

    # A '-' at index 4 means the modern DATETIME-style format.
    if timestamp[4] == '-':
        return convert_datetime(connection, field, timestamp)
    timestamp += "0"*(14-len(timestamp)) # padding
    year, month, day, hour, minute, second = \
        int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
        int(timestamp[8:10]), int(timestamp[10:12]), int(timestamp[12:14])
    try:
        return datetime.datetime(year, month, day, hour, minute, second)
    except ValueError:
        return None
+
def convert_set(s):
    # Split a SET column's comma-separated text into a Python set.
    # NOTE(review): unlike the other converters this takes only the raw
    # value (no connection/field) -- convert_characters() calls it that
    # way, but the decoders table also maps FIELD_TYPE.SET directly here;
    # verify the decoder call convention before relying on that entry.
    return set(s.split(","))
+
def convert_bit(connection, field, b):
    """Pass BIT column data through unchanged (MySQLdb compatibility)."""
    #b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
    #return struct.unpack(">Q", b)[0]
    #
    # the snippet above is right, but MySQLdb doesn't process bits,
    # so we shouldn't either
    return b
+
def convert_characters(connection, field, data):
    """Decode a text/blob column according to its column charset.

    SET columns are split into a Python set; BINARY-flagged data is
    returned raw.  Otherwise the bytes are decoded to unicode when the
    connection requested unicode, or transcoded into the connection
    charset when it differs from the column charset.
    """
    field_charset = charset_by_id(field.charsetnr).name
    if field.flags & FLAG.SET:
        return convert_set(data.decode(field_charset))
    if field.flags & FLAG.BINARY:
        return data

    if connection.use_unicode:
        data = data.decode(field_charset)
    elif connection.charset != field_charset:
        data = data.decode(field_charset)
        data = data.encode(connection.charset)
    return data
+
def convert_int(connection, field, data):
    """Parse an integer column's text into int."""
    return int(data)
+
def convert_long(connection, field, data):
    """Parse a wide integer column's text into long (Python 2)."""
    return long(data)
+
def convert_float(connection, field, data):
    """Parse a floating-point column's text into float."""
    return float(data)
+
# Maps a Python type to the escape_* function that renders values of
# that type as SQL literal text (Python 2: includes long and unicode).
encoders = {
        bool: escape_bool,
        int: escape_int,
        long: escape_long,
        float: escape_float,
        str: escape_string,
        unicode: escape_unicode,
        tuple: escape_sequence,
        list:escape_sequence,
        set:escape_sequence,
        dict:escape_dict,
        type(None):escape_None,
        datetime.date: escape_date,
        datetime.datetime : escape_datetime,
        datetime.timedelta : escape_timedelta,
        datetime.time : escape_time,
        time.struct_time : escape_struct_time,
        }
+
# Maps a MySQL wire field type to the convert_* function that turns the
# raw column bytes into a Python value.  DECIMAL/NEWDECIMAL may be
# re-mapped to Decimal conversion below when the decimal module exists.
decoders = {
        FIELD_TYPE.BIT: convert_bit,
        FIELD_TYPE.TINY: convert_int,
        FIELD_TYPE.SHORT: convert_int,
        FIELD_TYPE.LONG: convert_long,
        FIELD_TYPE.FLOAT: convert_float,
        FIELD_TYPE.DOUBLE: convert_float,
        FIELD_TYPE.DECIMAL: convert_float,
        FIELD_TYPE.NEWDECIMAL: convert_float,
        FIELD_TYPE.LONGLONG: convert_long,
        FIELD_TYPE.INT24: convert_int,
        FIELD_TYPE.YEAR: convert_int,
        FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp,
        FIELD_TYPE.DATETIME: convert_datetime,
        FIELD_TYPE.TIME: convert_timedelta,
        FIELD_TYPE.DATE: convert_date,
        FIELD_TYPE.SET: convert_set,
        FIELD_TYPE.BLOB: convert_characters,
        FIELD_TYPE.TINY_BLOB: convert_characters,
        FIELD_TYPE.MEDIUM_BLOB: convert_characters,
        FIELD_TYPE.LONG_BLOB: convert_characters,
        FIELD_TYPE.STRING: convert_characters,
        FIELD_TYPE.VAR_STRING: convert_characters,
        FIELD_TYPE.VARCHAR: convert_characters,
        #FIELD_TYPE.BLOB: str,
        #FIELD_TYPE.STRING: str,
        #FIELD_TYPE.VAR_STRING: str,
        #FIELD_TYPE.VARCHAR: str
        }
conversions = decoders # for MySQLdb compatibility
+
# The decimal module is available from Python 2.4; when present, override
# the float-based DECIMAL/NEWDECIMAL conversion set above with exact
# Decimal handling, and register a Decimal encoder.
try:
    # python version > 2.3
    from decimal import Decimal
    def convert_decimal(connection, field, data):
        data = data.decode(connection.charset)
        return Decimal(data)
    decoders[FIELD_TYPE.DECIMAL] = convert_decimal
    decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal

    def escape_decimal(obj):
        # Decimal's text form needs no quoting or escaping.
        return unicode(obj)
    encoders[Decimal] = escape_decimal

except ImportError:
    pass
diff --git a/tools/testClient/pymysql/cursors.py b/tools/testClient/pymysql/cursors.py
new file mode 100644
index 00000000000..4e10f83f4fa
--- /dev/null
+++ b/tools/testClient/pymysql/cursors.py
@@ -0,0 +1,297 @@
+# -*- coding: utf-8 -*-
+import struct
+import re
+
+try:
+ import cStringIO as StringIO
+except ImportError:
+ import StringIO
+
+from err import Warning, Error, InterfaceError, DataError, \
+ DatabaseError, OperationalError, IntegrityError, InternalError, \
+ NotSupportedError, ProgrammingError
+
+insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE)
+
+class Cursor(object):
+ '''
+ This is the object you use to interact with the database.
+ '''
+ def __init__(self, connection):
+ '''
+ Do not create an instance of a Cursor yourself. Call
+ connections.Connection.cursor().
+ '''
+ from weakref import proxy
+ self.connection = proxy(connection)
+ self.description = None
+ self.rownumber = 0
+ self.rowcount = -1
+ self.arraysize = 1
+ self._executed = None
+ self.messages = []
+ self.errorhandler = connection.errorhandler
+ self._has_next = None
+ self._rows = ()
+
+ def __del__(self):
+ '''
+ When this gets GC'd close it.
+ '''
+ self.close()
+
+ def close(self):
+ '''
+ Closing a cursor just exhausts all remaining data.
+ '''
+ if not self.connection:
+ return
+ try:
+ while self.nextset():
+ pass
+ except:
+ pass
+
+ self.connection = None
+
+ def _get_db(self):
+ if not self.connection:
+ self.errorhandler(self, ProgrammingError, "cursor closed")
+ return self.connection
+
+ def _check_executed(self):
+ if not self._executed:
+ self.errorhandler(self, ProgrammingError, "execute() first")
+
+ def setinputsizes(self, *args):
+ """Does nothing, required by DB API."""
+
+ def setoutputsizes(self, *args):
+ """Does nothing, required by DB API."""
+
+ def nextset(self):
+ ''' Get the next query set '''
+ if self._executed:
+ self.fetchall()
+ del self.messages[:]
+
+ if not self._has_next:
+ return None
+ connection = self._get_db()
+ connection.next_result()
+ self._do_get_result()
+ return True
+
+ def execute(self, query, args=None):
+ ''' Execute a query '''
+ from sys import exc_info
+
+ conn = self._get_db()
+ charset = conn.charset
+ del self.messages[:]
+
+ # TODO: make sure that conn.escape is correct
+
+ if args is not None:
+ if isinstance(args, tuple) or isinstance(args, list):
+ escaped_args = tuple(conn.escape(arg) for arg in args)
+ elif isinstance(args, dict):
+ escaped_args = dict((key, conn.escape(val)) for (key, val) in args.items())
+ else:
+ #If it's not a dictionary let's try escaping it anyways.
+ #Worst case it will throw a Value error
+ escaped_args = conn.escape(args)
+
+ query = query % escaped_args
+
+ if isinstance(query, unicode):
+ query = query.encode(charset)
+
+ result = 0
+ try:
+ result = self._query(query)
+ except:
+ exc, value, tb = exc_info()
+ del tb
+ self.messages.append((exc,value))
+ self.errorhandler(self, exc, value)
+
+ self._executed = query
+ return result
+
+ def executemany(self, query, args):
+ ''' Run several data against one query '''
+ del self.messages[:]
+ #conn = self._get_db()
+ if not args:
+ return
+ #charset = conn.charset
+ #if isinstance(query, unicode):
+ # query = query.encode(charset)
+
+ self.rowcount = sum([ self.execute(query, arg) for arg in args ])
+ return self.rowcount
+
+
+ def callproc(self, procname, args=()):
+ """Execute stored procedure procname with args
+
+ procname -- string, name of procedure to execute on server
+
+ args -- Sequence of parameters to use with procedure
+
+ Returns the original args.
+
+ Compatibility warning: PEP-249 specifies that any modified
+ parameters must be returned. This is currently impossible
+ as they are only available by storing them in a server
+ variable and then retrieved by a query. Since stored
+ procedures return zero or more result sets, there is no
+ reliable way to get at OUT or INOUT parameters via callproc.
+ The server variables are named @_procname_n, where procname
+ is the parameter above and n is the position of the parameter
+ (from zero). Once all result sets generated by the procedure
+ have been fetched, you can issue a SELECT @_procname_0, ...
+ query using .execute() to get any OUT or INOUT values.
+
+ Compatibility warning: The act of calling a stored procedure
+ itself creates an empty result set. This appears after any
+ result sets generated by the procedure. This is non-standard
+ behavior with respect to the DB-API. Be sure to use nextset()
+ to advance through all result sets; otherwise you may get
+ disconnected.
+ """
+ conn = self._get_db()
+ for index, arg in enumerate(args):
+ q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
+ if isinstance(q, unicode):
+ q = q.encode(conn.charset)
+ self._query(q)
+ self.nextset()
+
+ q = "CALL %s(%s)" % (procname,
+ ','.join(['@_%s_%d' % (procname, i)
+ for i in range(len(args))]))
+ if isinstance(q, unicode):
+ q = q.encode(conn.charset)
+ self._query(q)
+ self._executed = q
+
+ return args
+
+ def fetchone(self):
+ ''' Fetch the next row '''
+ self._check_executed()
+ if self._rows is None or self.rownumber >= len(self._rows):
+ return None
+ result = self._rows[self.rownumber]
+ self.rownumber += 1
+ return result
+
+ def fetchmany(self, size=None):
+ ''' Fetch several rows '''
+ self._check_executed()
+ end = self.rownumber + (size or self.arraysize)
+ result = self._rows[self.rownumber:end]
+ if self._rows is None:
+ return None
+ self.rownumber = min(end, len(self._rows))
+ return result
+
+ def fetchall(self):
+ ''' Fetch all the rows '''
+ self._check_executed()
+ if self._rows is None:
+ return None
+ if self.rownumber:
+ result = self._rows[self.rownumber:]
+ else:
+ result = self._rows
+ self.rownumber = len(self._rows)
+ return result
+
+ def scroll(self, value, mode='relative'):
+ self._check_executed()
+ if mode == 'relative':
+ r = self.rownumber + value
+ elif mode == 'absolute':
+ r = value
+ else:
+ self.errorhandler(self, ProgrammingError,
+ "unknown scroll mode %s" % mode)
+
+ if r < 0 or r >= len(self._rows):
+ self.errorhandler(self, IndexError, "out of range")
+ self.rownumber = r
+
+ def _query(self, q):
+ conn = self._get_db()
+ self._last_executed = q
+ conn.query(q)
+ self._do_get_result()
+ return self.rowcount
+
+ def _do_get_result(self):
+ conn = self._get_db()
+ self.rowcount = conn._result.affected_rows
+
+ self.rownumber = 0
+ self.description = conn._result.description
+ self.lastrowid = conn._result.insert_id
+ self._rows = conn._result.rows
+ self._has_next = conn._result.has_next
+
+ def __iter__(self):
+ return iter(self.fetchone, None)
+
+ Warning = Warning
+ Error = Error
+ InterfaceError = InterfaceError
+ DatabaseError = DatabaseError
+ DataError = DataError
+ OperationalError = OperationalError
+ IntegrityError = IntegrityError
+ InternalError = InternalError
+ ProgrammingError = ProgrammingError
+ NotSupportedError = NotSupportedError
+
+class DictCursor(Cursor):
+    """A cursor which returns results as a dictionary"""
+
+    def execute(self, query, args=None):
+        # Run the query, then cache the result-set column names; they key
+        # the dicts produced by the fetch* overrides below.
+        result = super(DictCursor, self).execute(query, args)
+        if self.description:
+            self._fields = [ field[0] for field in self.description ]
+        return result
+
+    def fetchone(self):
+        ''' Fetch the next row '''
+        self._check_executed()
+        if self._rows is None or self.rownumber >= len(self._rows):
+            return None
+        # zip column names with the positional row to build the dict
+        result = dict(zip(self._fields, self._rows[self.rownumber]))
+        self.rownumber += 1
+        return result
+
+    def fetchmany(self, size=None):
+        ''' Fetch several rows '''
+        self._check_executed()
+        if self._rows is None:
+            return None
+        end = self.rownumber + (size or self.arraysize)
+        result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:end] ]
+        self.rownumber = min(end, len(self._rows))
+        # NOTE: returns a tuple of dicts, whereas Cursor.fetchmany returns a
+        # slice of the underlying row container.
+        return tuple(result)
+
+    def fetchall(self):
+        ''' Fetch all the rows '''
+        self._check_executed()
+        if self._rows is None:
+            return None
+        if self.rownumber:
+            # partially-consumed result set: only the remaining rows
+            result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:] ]
+        else:
+            result = [ dict(zip(self._fields, r)) for r in self._rows ]
+        self.rownumber = len(self._rows)
+        return tuple(result)
+
diff --git a/tools/testClient/pymysql/err.py b/tools/testClient/pymysql/err.py
new file mode 100644
index 00000000000..b4322c63354
--- /dev/null
+++ b/tools/testClient/pymysql/err.py
@@ -0,0 +1,147 @@
+import struct
+
+
+try:
+ StandardError, Warning
+except ImportError:
+ try:
+ from exceptions import StandardError, Warning
+ except ImportError:
+ import sys
+ e = sys.modules['exceptions']
+ StandardError = e.StandardError
+ Warning = e.Warning
+
+from constants import ER
+import sys
+
+# --- DB-API 2.0 (PEP 249) exception hierarchy ----------------------------
+# MySQLError is the root; Warning/Error split below it per the spec.
+class MySQLError(StandardError):
+
+    """Exception related to operation with MySQL."""
+
+
+# NOTE: deliberately shadows the Warning name imported above, so that
+# pymysql's Warning is both a standard Warning and a MySQLError.
+class Warning(Warning, MySQLError):
+
+    """Exception raised for important warnings like data truncations
+    while inserting, etc."""
+
+class Error(MySQLError):
+
+    """Exception that is the base class of all other error exceptions
+    (not Warning)."""
+
+
+class InterfaceError(Error):
+
+    """Exception raised for errors that are related to the database
+    interface rather than the database itself."""
+
+
+class DatabaseError(Error):
+
+    """Exception raised for errors that are related to the
+    database."""
+
+
+class DataError(DatabaseError):
+
+    """Exception raised for errors that are due to problems with the
+    processed data like division by zero, numeric value out of range,
+    etc."""
+
+
+class OperationalError(DatabaseError):
+
+    """Exception raised for errors that are related to the database's
+    operation and not necessarily under the control of the programmer,
+    e.g. an unexpected disconnect occurs, the data source name is not
+    found, a transaction could not be processed, a memory allocation
+    error occurred during processing, etc."""
+
+
+class IntegrityError(DatabaseError):
+
+    """Exception raised when the relational integrity of the database
+    is affected, e.g. a foreign key check fails, duplicate key,
+    etc."""
+
+
+class InternalError(DatabaseError):
+
+    """Exception raised when the database encounters an internal
+    error, e.g. the cursor is not valid anymore, the transaction is
+    out of sync, etc."""
+
+
+class ProgrammingError(DatabaseError):
+
+    """Exception raised for programming errors, e.g. table not found
+    or already exists, syntax error in the SQL statement, wrong number
+    of parameters specified, etc."""
+
+
+class NotSupportedError(DatabaseError):
+
+    """Exception raised in case a method or database API was used
+    which is not supported by the database, e.g. requesting a
+    .rollback() on a connection that does not support transaction or
+    has transactions turned off."""
+
+
+# Map server error codes (constants.ER) to the DB-API exception class that
+# should be raised for them; consulted when an error packet arrives.
+error_map = {}
+
+def _map_error(exc, *errors):
+    # register every given ER code under the exception class `exc`
+    for error in errors:
+        error_map[error] = exc
+
+_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR,
+           ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME,
+           ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE,
+           ER.INVALID_GROUP_FUNC_USE, ER.UNSUPPORTED_EXTENSION,
+           ER.TABLE_MUST_HAVE_COLUMNS, ER.CANT_DO_THIS_DURING_AN_TRANSACTION)
+_map_error(DataError, ER.WARN_DATA_TRUNCATED, ER.WARN_NULL_TO_NOTNULL,
+           ER.WARN_DATA_OUT_OF_RANGE, ER.NO_DEFAULT, ER.PRIMARY_CANT_HAVE_NULL,
+           ER.DATA_TOO_LONG, ER.DATETIME_FUNCTION_OVERFLOW)
+_map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW,
+           ER.NO_REFERENCED_ROW_2, ER.ROW_IS_REFERENCED, ER.ROW_IS_REFERENCED_2,
+           ER.CANNOT_ADD_FOREIGN)
+_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK,
+           ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE)
+_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
+           ER.TABLEACCESS_DENIED_ERROR, ER.COLUMNACCESS_DENIED_ERROR)
+
+# the helper and the ER constants are only needed at import time
+del _map_error, ER
+
+
+def _get_error_info(data):
+    errno = struct.unpack('<h', data[1:3])[0]
+    # NOTE(review): the patch text is garbled/truncated at this point --
+    # "<h', data[1:3])[0]" and everything from here up to the middle of
+    # pymysql/tests/test_DictCursor.py (including that file's diff header)
+    # is missing, apparently due to "<"-tag stripping during extraction.
+    # The stray "tuple)" fragment belonged to a comment in the test file.
+    # Restore this hunk from the upstream PyMySQL sources.
+ c.execute("SELECT * from dictcursor where name='bob'")
+ r = c.fetchall()
+ self.assertEqual((bob,),r,"fetch a 1 row result via fetchall failed via DictCursor")
+ # same test again but iterate over the
+ c.execute("SELECT * from dictcursor where name='bob'")
+ for r in c:
+ self.assertEqual(bob, r,"fetch a 1 row result via iteration failed via DictCursor")
+ # get all 3 row via fetchall
+ c.execute("SELECT * from dictcursor")
+ r = c.fetchall()
+ self.assertEqual((bob,jim,fred), r, "fetchall failed via DictCursor")
+ #same test again but do a list comprehension
+ c.execute("SELECT * from dictcursor")
+ r = [x for x in c]
+ self.assertEqual([bob,jim,fred], r, "list comprehension failed via DictCursor")
+ # get all 2 row via fetchmany
+ c.execute("SELECT * from dictcursor")
+ r = c.fetchmany(2)
+ self.assertEqual((bob,jim), r, "fetchmany failed via DictCursor")
+ finally:
+ c.execute("drop table dictcursor")
+
+__all__ = ["TestDictCursor"]
+
+if __name__ == "__main__":
+ import unittest
+ unittest.main()
diff --git a/tools/testClient/pymysql/tests/test_basic.py b/tools/testClient/pymysql/tests/test_basic.py
new file mode 100644
index 00000000000..c8fdd297f44
--- /dev/null
+++ b/tools/testClient/pymysql/tests/test_basic.py
@@ -0,0 +1,193 @@
+from pymysql.tests import base
+from pymysql import util
+
+import time
+import datetime
+
+class TestConversion(base.PyMySQLTestCase):
+    """Round-trip tests for parameter escaping and result decoding.
+
+    Requires a reachable MySQL server (connections come from the base
+    class); every test creates and drops its own fixture table.
+    """
+    def test_datatypes(self):
+        """ test every data type """
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table test_datatypes (b bit, i int, l bigint, f real, s varchar(32), u varchar(32), bb blob, d date, dt datetime, ts timestamp, td time, t time, st datetime)")
+        try:
+            # insert values
+            v = (True, -3, 123456789012, 5.7, "hello'\" world", u"Espa\xc3\xb1ol", "binary\x00data".encode(conn.charset), datetime.date(1988,2,2), datetime.datetime.now(), datetime.timedelta(5,6), datetime.time(16,32), time.localtime())
+            c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v)
+            c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
+            r = c.fetchone()
+            # BIT columns come back as a single byte
+            self.assertEqual(util.int2byte(1), r[0])
+            self.assertEqual(v[1:8], r[1:8])
+            # mysql throws away microseconds so we need to check datetimes
+            # specially. additionally times are turned into timedeltas.
+            self.assertEqual(datetime.datetime(*v[8].timetuple()[:6]), r[8])
+            self.assertEqual(v[9], r[9]) # just timedeltas
+            self.assertEqual(datetime.timedelta(0, 60 * (v[10].hour * 60 + v[10].minute)), r[10])
+            self.assertEqual(datetime.datetime(*v[-1][:6]), r[-1])
+
+            c.execute("delete from test_datatypes")
+
+            # check nulls
+            c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", [None] * 12)
+            c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
+            r = c.fetchone()
+            self.assertEqual(tuple([None] * 12), r)
+
+            c.execute("delete from test_datatypes")
+
+            # check sequence type (tuple parameter expands to an IN list)
+            c.execute("insert into test_datatypes (i, l) values (2,4), (6,8), (10,12)")
+            c.execute("select l from test_datatypes where i in %s order by i", ((2,6),))
+            r = c.fetchall()
+            self.assertEqual(((4,),(8,)), r)
+        finally:
+            c.execute("drop table test_datatypes")
+
+    def test_dict(self):
+        """ test dict escaping """
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table test_dict (a integer, b integer, c integer)")
+        try:
+            # named %(key)s placeholders are filled from the dict argument
+            c.execute("insert into test_dict (a,b,c) values (%(a)s, %(b)s, %(c)s)", {"a":1,"b":2,"c":3})
+            c.execute("select a,b,c from test_dict")
+            self.assertEqual((1,2,3), c.fetchone())
+        finally:
+            c.execute("drop table test_dict")
+
+    def test_string(self):
+        # a bare (non-sequence) string argument must be escaped correctly
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table test_dict (a text)")
+        test_value = "I am a test string"
+        try:
+            c.execute("insert into test_dict (a) values (%s)", test_value)
+            c.execute("select a from test_dict")
+            self.assertEqual((test_value,), c.fetchone())
+        finally:
+            c.execute("drop table test_dict")
+
+    def test_integer(self):
+        # a bare integer argument must be escaped correctly
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table test_dict (a integer)")
+        test_value = 12345
+        try:
+            c.execute("insert into test_dict (a) values (%s)", test_value)
+            c.execute("select a from test_dict")
+            self.assertEqual((test_value,), c.fetchone())
+        finally:
+            c.execute("drop table test_dict")
+
+
+    def test_big_blob(self):
+        """ test tons of data """
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table test_big_blob (b blob)")
+        try:
+            data = "pymysql" * 1024
+            c.execute("insert into test_big_blob (b) values (%s)", (data,))
+            c.execute("select b from test_big_blob")
+            self.assertEqual(data.encode(conn.charset), c.fetchone()[0])
+        finally:
+            c.execute("drop table test_big_blob")
+
+class TestCursor(base.PyMySQLTestCase):
+    """Cursor behavior tests (fetch semantics, aggregates, IN-expansion)."""
+    # this test case does not work quite right yet, however,
+    # we substitute in None for the erroneous field which is
+    # compatible with the DB-API 2.0 spec and has not broken
+    # any unit tests for anything we've tried.
+
+    #def test_description(self):
+    #    """ test description attribute """
+    #    # result is from MySQLdb module
+    #    r = (('Host', 254, 11, 60, 60, 0, 0),
+    #         ('User', 254, 16, 16, 16, 0, 0),
+    #         ('Password', 254, 41, 41, 41, 0, 0),
+    #         ('Select_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Insert_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Update_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Delete_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Create_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Drop_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Reload_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Shutdown_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Process_priv', 254, 1, 1, 1, 0, 0),
+    #         ('File_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Grant_priv', 254, 1, 1, 1, 0, 0),
+    #         ('References_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Index_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Alter_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Show_db_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Super_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Create_tmp_table_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Lock_tables_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Execute_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Repl_slave_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Repl_client_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Create_view_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Show_view_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Create_routine_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Alter_routine_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Create_user_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Event_priv', 254, 1, 1, 1, 0, 0),
+    #         ('Trigger_priv', 254, 1, 1, 1, 0, 0),
+    #         ('ssl_type', 254, 0, 9, 9, 0, 0),
+    #         ('ssl_cipher', 252, 0, 65535, 65535, 0, 0),
+    #         ('x509_issuer', 252, 0, 65535, 65535, 0, 0),
+    #         ('x509_subject', 252, 0, 65535, 65535, 0, 0),
+    #         ('max_questions', 3, 1, 11, 11, 0, 0),
+    #         ('max_updates', 3, 1, 11, 11, 0, 0),
+    #         ('max_connections', 3, 1, 11, 11, 0, 0),
+    #         ('max_user_connections', 3, 1, 11, 11, 0, 0))
+    #    conn = self.connections[0]
+    #    c = conn.cursor()
+    #    c.execute("select * from mysql.user")
+    #
+    #    self.assertEqual(r, c.description)
+
+    def test_fetch_no_result(self):
+        """ test a fetchone() with no rows """
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table test_nr (b varchar(32))")
+        try:
+            data = "pymysql"
+            # an INSERT produces no result rows, so fetchone() yields None
+            c.execute("insert into test_nr (b) values (%s)", (data,))
+            self.assertEqual(None, c.fetchone())
+        finally:
+            c.execute("drop table test_nr")
+
+    def test_aggregates(self):
+        """ test aggregate functions """
+        conn = self.connections[0]
+        c = conn.cursor()
+        try:
+            c.execute('create table test_aggregates (i integer)')
+            for i in xrange(0, 10):
+                c.execute('insert into test_aggregates (i) values (%s)', (i,))
+            c.execute('select sum(i) from test_aggregates')
+            r, = c.fetchone()
+            self.assertEqual(sum(range(0,10)), r)
+        finally:
+            c.execute('drop table test_aggregates')
+
+    def test_single_tuple(self):
+        """ test a single tuple """
+        conn = self.connections[0]
+        c = conn.cursor()
+        try:
+            c.execute("create table mystuff (id integer primary key)")
+            c.execute("insert into mystuff (id) values (1)")
+            c.execute("insert into mystuff (id) values (2)")
+            # a one-element tuple must expand to IN (1), not IN (1,)
+            c.execute("select id from mystuff where id in %s", ((1,),))
+            self.assertEqual([(1,)], list(c.fetchall()))
+        finally:
+            c.execute("drop table mystuff")
+
+__all__ = ["TestConversion","TestCursor"]
+
+if __name__ == "__main__":
+ import unittest
+ unittest.main()
diff --git a/tools/testClient/pymysql/tests/test_example.py b/tools/testClient/pymysql/tests/test_example.py
new file mode 100644
index 00000000000..2da05db31c6
--- /dev/null
+++ b/tools/testClient/pymysql/tests/test_example.py
@@ -0,0 +1,32 @@
+import pymysql
+from pymysql.tests import base
+
+class TestExample(base.PyMySQLTestCase):
+    """Smoke test mirroring the README example: connect, query mysql.user."""
+    def test_example(self):
+        # NOTE(review): connection parameters are hard-coded (127.0.0.1,
+        # root, empty password) instead of coming from self.databases like
+        # the other tests -- the test fails on any differently-configured
+        # server.
+        conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd='', db='mysql')
+
+
+        cur = conn.cursor()
+
+        cur.execute("SELECT Host,User FROM user")
+
+        # print cur.description
+
+        # r = cur.fetchall()
+        # print r
+        # ...or...
+        u = False
+
+        # the connecting user must appear somewhere in mysql.user
+        for r in cur.fetchall():
+            u = u or conn.user in r
+
+        self.assertTrue(u)
+
+        cur.close()
+        conn.close()
+
+__all__ = ["TestExample"]
+
+if __name__ == "__main__":
+ import unittest
+ unittest.main()
diff --git a/tools/testClient/pymysql/tests/test_issues.py b/tools/testClient/pymysql/tests/test_issues.py
new file mode 100644
index 00000000000..38d71639c90
--- /dev/null
+++ b/tools/testClient/pymysql/tests/test_issues.py
@@ -0,0 +1,268 @@
+import pymysql
+from pymysql.tests import base
+
+import sys
+
+try:
+    # Python 3 moved reload() from builtins to imp; on Python 2 `imp`
+    # has no reload attribute, the AttributeError is swallowed and the
+    # builtin reload stays in effect.
+    import imp
+    reload = imp.reload
+except AttributeError:
+    pass
+
+import datetime
+
+class TestOldIssues(base.PyMySQLTestCase):
+    """Regression tests for issues from the original PyMySQL tracker.
+
+    Each test builds its own fixture table, reproduces the reported bug,
+    and drops the table in a finally block.  A reachable MySQL server is
+    required (parameters come from base.PyMySQLTestCase).
+    """
+    def test_issue_3(self):
+        """ undefined methods datetime_or_None, date_or_None """
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)")
+        try:
+            # NULL date/time/datetime must decode to None ...
+            c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None))
+            c.execute("select d from issue3")
+            self.assertEqual(None, c.fetchone()[0])
+            c.execute("select t from issue3")
+            self.assertEqual(None, c.fetchone()[0])
+            c.execute("select dt from issue3")
+            self.assertEqual(None, c.fetchone()[0])
+            # ... but the TIMESTAMP column auto-fills, so a datetime comes back
+            c.execute("select ts from issue3")
+            self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
+        finally:
+            c.execute("drop table issue3")
+
+    def test_issue_4(self):
+        """ can't retrieve TIMESTAMP fields """
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table issue4 (ts timestamp)")
+        try:
+            c.execute("insert into issue4 (ts) values (now())")
+            c.execute("select ts from issue4")
+            self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
+        finally:
+            c.execute("drop table issue4")
+
+    def test_issue_5(self):
+        """ query on information_schema.tables fails """
+        con = self.connections[0]
+        cur = con.cursor()
+        cur.execute("select * from information_schema.tables")
+
+    def test_issue_6(self):
+        """ exception: TypeError: ord() expected a character, but string of length 0 found """
+        # NOTE(review): hard-coded localhost/root/empty-password connection,
+        # unlike the other tests which use self.databases.
+        conn = pymysql.connect(host="localhost",user="root",passwd="",db="mysql")
+        c = conn.cursor()
+        c.execute("select * from user")
+        conn.close()
+
+    def test_issue_8(self):
+        """ Primary Key and Index error when selecting data """
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh`
+datetime NOT NULL DEFAULT '0000-00-00 00:00:00', `echeance` int(1) NOT NULL
+DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY
+KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
+        try:
+            self.assertEqual(0, c.execute("SELECT * FROM test"))
+            c.execute("ALTER TABLE `test` ADD INDEX `idx_station` (`station`)")
+            self.assertEqual(0, c.execute("SELECT * FROM test"))
+        finally:
+            c.execute("drop table test")
+
+    def test_issue_9(self):
+        """ sets DeprecationWarning in Python 2.6 """
+        try:
+            reload(pymysql)
+        except DeprecationWarning:
+            self.fail()
+
+    def test_issue_10(self):
+        """ Allocate a variable to return when the exception handler is permissive """
+        conn = self.connections[0]
+        # permissive errorhandler: errors are swallowed instead of raised
+        conn.errorhandler = lambda cursor, errorclass, errorvalue: None
+        cur = conn.cursor()
+        cur.execute( "create table t( n int )" )
+        # duplicate CREATE must not raise with the permissive handler in place
+        cur.execute( "create table t( n int )" )
+
+    def test_issue_13(self):
+        """ can't handle large result fields """
+        conn = self.connections[0]
+        cur = conn.cursor()
+        try:
+            cur.execute("create table issue13 (t text)")
+            # ticket says 18k
+            size = 18*1024
+            cur.execute("insert into issue13 (t) values (%s)", ("x" * size,))
+            cur.execute("select t from issue13")
+            # use assert_ so that obscenely huge error messages don't print
+            r = cur.fetchone()[0]
+            self.assert_("x" * size == r)
+        finally:
+            cur.execute("drop table issue13")
+
+    def test_issue_14(self):
+        """ typo in converters.py """
+        self.assertEqual('1', pymysql.converters.escape_item(1, "utf8"))
+        self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8"))
+
+        self.assertEqual('1', pymysql.converters.escape_object(1))
+        self.assertEqual('1', pymysql.converters.escape_object(1L))
+
+    def test_issue_15(self):
+        """ query should be expanded before perform character encoding """
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table issue15 (t varchar(32))")
+        try:
+            c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',))
+            c.execute("select t from issue15")
+            self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0])
+        finally:
+            c.execute("drop table issue15")
+
+    def test_issue_16(self):
+        """ Patch for string and tuple escaping """
+        conn = self.connections[0]
+        c = conn.cursor()
+        c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))")
+        try:
+            c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')")
+            c.execute("select email from issue16 where name=%s", ("pete",))
+            self.assertEqual("floydophone", c.fetchone()[0])
+        finally:
+            c.execute("drop table issue16")
+
+    def test_issue_17(self):
+        """ could not connect mysql use passwod """
+        conn = self.connections[0]
+        host = self.databases[0]["host"]
+        db = self.databases[0]["db"]
+        c = conn.cursor()
+        # grant access to a table to a user with a password
+        # NOTE(review): requires GRANT privileges, and the created
+        # 'issue17user' account is never cleaned up afterwards.
+        try:
+            c.execute("create table issue17 (x varchar(32) primary key)")
+            c.execute("insert into issue17 (x) values ('hello, world!')")
+            c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db)
+            conn.commit()
+
+            conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db)
+            c2 = conn2.cursor()
+            c2.execute("select x from issue17")
+            self.assertEqual("hello, world!", c2.fetchone()[0])
+        finally:
+            c.execute("drop table issue17")
+
+def _uni(s, e):
+    # hack for py3
+    # NOTE(review): on Python 3 the name `unicode` does not exist, so the
+    # "py3" branch below would itself raise NameError as written.  Harmless
+    # today because this module uses py2-only syntax elsewhere (e.g.
+    # `except ..., e:`), but verify before any py3 port.
+    if sys.version_info[0] > 2:
+        return unicode(bytes(s, sys.getdefaultencoding()), e)
+    else:
+        return unicode(s, e)
+
+class TestNewIssues(base.PyMySQLTestCase):
+    """Regression tests for newer tracker issues (connection failures,
+    unicode identifiers, kill, user variables, large blobs)."""
+    def test_issue_34(self):
+        # connecting to a closed port must raise OperationalError 2003
+        try:
+            pymysql.connect(host="localhost", port=1237, user="root")
+            self.fail()
+        except pymysql.OperationalError, e:
+            self.assertEqual(2003, e.args[0])
+        except:
+            self.fail()
+
+    def test_issue_33(self):
+        # non-ASCII table name round-trips when charset=utf8 is set
+        conn = pymysql.connect(host="localhost", user="root", db=self.databases[0]["db"], charset="utf8")
+        c = conn.cursor()
+        try:
+            c.execute(_uni("create table hei\xc3\x9fe (name varchar(32))", "utf8"))
+            c.execute(_uni("insert into hei\xc3\x9fe (name) values ('Pi\xc3\xb1ata')", "utf8"))
+            c.execute(_uni("select name from hei\xc3\x9fe", "utf8"))
+            self.assertEqual(_uni("Pi\xc3\xb1ata","utf8"), c.fetchone()[0])
+        finally:
+            c.execute(_uni("drop table hei\xc3\x9fe", "utf8"))
+
+    # Will fail without manual intervention:
+    #def test_issue_35(self):
+    #
+    #    conn = self.connections[0]
+    #    c = conn.cursor()
+    #    print "sudo killall -9 mysqld within the next 10 seconds"
+    #    try:
+    #        c.execute("select sleep(10)")
+    #        self.fail()
+    #    except pymysql.OperationalError, e:
+    #        self.assertEqual(2013, e.args[0])
+
+    def test_issue_36(self):
+        conn = self.connections[0]
+        c = conn.cursor()
+        # kill connections[0]
+        original_count = c.execute("show processlist")
+        kill_id = None
+        # NOTE: loop variables shadow several builtins (id, time); kept as-is
+        for id,user,host,db,command,time,state,info in c.fetchall():
+            if info == "show processlist":
+                kill_id = id
+                break
+        # now nuke the connection
+        conn.kill(kill_id)
+        # make sure this connection has broken
+        try:
+            c.execute("show tables")
+            self.fail()
+        except:
+            pass
+        # check the process list from the other connection
+        self.assertEqual(original_count - 1, self.connections[1].cursor().execute("show processlist"))
+        del self.connections[0]
+
+    def test_issue_37(self):
+        conn = self.connections[0]
+        c = conn.cursor()
+        # selecting an unset user variable yields one row of NULL
+        self.assertEqual(1, c.execute("SELECT @foo"))
+        self.assertEqual((None,), c.fetchone())
+        # SET produces no result rows
+        self.assertEqual(0, c.execute("SET @foo = 'bar'"))
+        c.execute("set @foo = 'bar'")
+
+    def test_issue_38(self):
+        conn = self.connections[0]
+        c = conn.cursor()
+        datum = "a" * 1024 * 1023 # reduced size for most default mysql installs
+
+        try:
+            c.execute("create table issue38 (id integer, data mediumblob)")
+            c.execute("insert into issue38 values (1, %s)", (datum,))
+        finally:
+            c.execute("drop table issue38")
+
+    def disabled_test_issue_54(self):
+        # disabled: generates a ~1MB WHERE clause to stress the packet parser
+        conn = self.connections[0]
+        c = conn.cursor()
+        big_sql = "select * from issue54 where "
+        big_sql += " and ".join("%d=%d" % (i,i) for i in xrange(0, 100000))
+
+        try:
+            c.execute("create table issue54 (id integer primary key)")
+            c.execute("insert into issue54 (id) values (7)")
+            c.execute(big_sql)
+            self.assertEquals(7, c.fetchone()[0])
+        finally:
+            c.execute("drop table issue54")
+
+class TestGitHubIssues(base.PyMySQLTestCase):
+    """Regression tests for issues reported on GitHub."""
+    def test_issue_66(self):
+        # insert_id() should be 0 before any insert, then track the last
+        # auto-increment value generated on this connection.
+        conn = self.connections[0]
+        c = conn.cursor()
+        self.assertEquals(0, conn.insert_id())
+        try:
+            c.execute("create table issue66 (id integer primary key auto_increment, x integer)")
+            c.execute("insert into issue66 (x) values (1)")
+            c.execute("insert into issue66 (x) values (1)")
+            self.assertEquals(2, conn.insert_id())
+        finally:
+            c.execute("drop table issue66")
+
+__all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"]
+
+if __name__ == "__main__":
+ import unittest
+ unittest.main()
diff --git a/tools/testClient/pymysql/tests/thirdparty/__init__.py b/tools/testClient/pymysql/tests/thirdparty/__init__.py
new file mode 100644
index 00000000000..bfcc075fc4b
--- /dev/null
+++ b/tools/testClient/pymysql/tests/thirdparty/__init__.py
@@ -0,0 +1,5 @@
+from test_MySQLdb import *
+
+if __name__ == "__main__":
+ import unittest
+ unittest.main()
diff --git a/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/__init__.py b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/__init__.py
new file mode 100644
index 00000000000..b64f273cf08
--- /dev/null
+++ b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/__init__.py
@@ -0,0 +1,7 @@
+from test_MySQLdb_capabilities import test_MySQLdb as test_capabilities
+from test_MySQLdb_nonstandard import *
+from test_MySQLdb_dbapi20 import test_MySQLdb as test_dbapi2
+
+if __name__ == "__main__":
+ import unittest
+ unittest.main()
diff --git a/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py
new file mode 100644
index 00000000000..ddd012330e5
--- /dev/null
+++ b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python -O
+""" Script to test database capabilities and the DB-API interface
+ for functionality and memory leaks.
+
+ Adapted from a script by M-A Lemburg.
+
+"""
+from time import time
+import array
+import unittest
+
+
+class DatabaseTest(unittest.TestCase):
+
+ db_module = None
+ connect_args = ()
+ connect_kwargs = dict(use_unicode=True, charset="utf8")
+ create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
+ rows = 10
+ debug = False
+
+ def setUp(self):
+ import gc
+ db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
+ self.connection = db
+ self.cursor = db.cursor()
+ self.BLOBText = ''.join([chr(i) for i in range(256)] * 100);
+ self.BLOBUText = u''.join([unichr(i) for i in range(16834)])
+ self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16))
+
+ leak_test = True
+
+ def tearDown(self):
+ if self.leak_test:
+ import gc
+ del self.cursor
+ orphans = gc.collect()
+ self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans)
+
+ del self.connection
+ orphans = gc.collect()
+ self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans)
+
+ def table_exists(self, name):
+ try:
+ self.cursor.execute('select * from %s where 1=0' % name)
+ except:
+ return False
+ else:
+ return True
+
+ def quote_identifier(self, ident):
+ return '"%s"' % ident
+
+ def new_table_name(self):
+ i = id(self.cursor)
+ while True:
+ name = self.quote_identifier('tb%08x' % i)
+ if not self.table_exists(name):
+ return name
+ i = i + 1
+
+ def create_table(self, columndefs):
+
+ """ Create a table using a list of column definitions given in
+ columndefs.
+
+ generator must be a function taking arguments (row_number,
+ col_number) returning a suitable data object for insertion
+ into the table.
+
+ """
+ self.table = self.new_table_name()
+ self.cursor.execute('CREATE TABLE %s (%s) %s' %
+ (self.table,
+ ',\n'.join(columndefs),
+ self.create_table_extra))
+
+ def check_data_integrity(self, columndefs, generator):
+ # insert
+ self.create_table(columndefs)
+ insert_statement = ('INSERT INTO %s VALUES (%s)' %
+ (self.table,
+ ','.join(['%s'] * len(columndefs))))
+ data = [ [ generator(i,j) for j in range(len(columndefs)) ]
+ for i in range(self.rows) ]
+ if self.debug:
+ print data
+ self.cursor.executemany(insert_statement, data)
+ self.connection.commit()
+ # verify
+ self.cursor.execute('select * from %s' % self.table)
+ l = self.cursor.fetchall()
+ if self.debug:
+ print l
+ self.assertEquals(len(l), self.rows)
+ try:
+ for i in range(self.rows):
+ for j in range(len(columndefs)):
+ self.assertEquals(l[i][j], generator(i,j))
+ finally:
+ if not self.debug:
+ self.cursor.execute('drop table %s' % (self.table))
+
+ def test_transactions(self):
+ columndefs = ( 'col1 INT', 'col2 VARCHAR(255)')
+ def generator(row, col):
+ if col == 0: return row
+ else: return ('%i' % (row%10))*255
+ self.create_table(columndefs)
+ insert_statement = ('INSERT INTO %s VALUES (%s)' %
+ (self.table,
+ ','.join(['%s'] * len(columndefs))))
+ data = [ [ generator(i,j) for j in range(len(columndefs)) ]
+ for i in range(self.rows) ]
+ self.cursor.executemany(insert_statement, data)
+ # verify
+ self.connection.commit()
+ self.cursor.execute('select * from %s' % self.table)
+ l = self.cursor.fetchall()
+ self.assertEquals(len(l), self.rows)
+ for i in range(self.rows):
+ for j in range(len(columndefs)):
+ self.assertEquals(l[i][j], generator(i,j))
+ delete_statement = 'delete from %s where col1=%%s' % self.table
+ self.cursor.execute(delete_statement, (0,))
+ self.cursor.execute('select col1 from %s where col1=%s' % \
+ (self.table, 0))
+ l = self.cursor.fetchall()
+ self.assertFalse(l, "DELETE didn't work")
+ self.connection.rollback()
+ self.cursor.execute('select col1 from %s where col1=%s' % \
+ (self.table, 0))
+ l = self.cursor.fetchall()
+ self.assertTrue(len(l) == 1, "ROLLBACK didn't work")
+ self.cursor.execute('drop table %s' % (self.table))
+
+ def test_truncation(self):
+ columndefs = ( 'col1 INT', 'col2 VARCHAR(255)')
+ def generator(row, col):
+ if col == 0: return row
+ else: return ('%i' % (row%10))*((255-self.rows/2)+row)
+ self.create_table(columndefs)
+ insert_statement = ('INSERT INTO %s VALUES (%s)' %
+ (self.table,
+ ','.join(['%s'] * len(columndefs))))
+
+ try:
+ self.cursor.execute(insert_statement, (0, '0'*256))
+ except Warning:
+ if self.debug: print self.cursor.messages
+ except self.connection.DataError:
+ pass
+ else:
+ self.fail("Over-long column did not generate warnings/exception with single insert")
+
+ self.connection.rollback()
+
+ try:
+ for i in range(self.rows):
+ data = []
+ for j in range(len(columndefs)):
+ data.append(generator(i,j))
+ self.cursor.execute(insert_statement,tuple(data))
+ except Warning:
+ if self.debug: print self.cursor.messages
+ except self.connection.DataError:
+ pass
+ else:
+ self.fail("Over-long columns did not generate warnings/exception with execute()")
+
+ self.connection.rollback()
+
+ try:
+ data = [ [ generator(i,j) for j in range(len(columndefs)) ]
+ for i in range(self.rows) ]
+ self.cursor.executemany(insert_statement, data)
+ except Warning:
+ if self.debug: print self.cursor.messages
+ except self.connection.DataError:
+ pass
+ else:
+ self.fail("Over-long columns did not generate warnings/exception with executemany()")
+
+ self.connection.rollback()
+ self.cursor.execute('drop table %s' % (self.table))
+
+ def test_CHAR(self):
+ # Character data
+ def generator(row,col):
+ return ('%i' % ((row+col) % 10)) * 255
+ self.check_data_integrity(
+ ('col1 char(255)','col2 char(255)'),
+ generator)
+
+ def test_INT(self):
+ # Number data
+ def generator(row,col):
+ return row*row
+ self.check_data_integrity(
+ ('col1 INT',),
+ generator)
+
+ def test_DECIMAL(self):
+ # DECIMAL
+ def generator(row,col):
+ from decimal import Decimal
+ return Decimal("%d.%02d" % (row, col))
+ self.check_data_integrity(
+ ('col1 DECIMAL(5,2)',),
+ generator)
+
+ def test_DATE(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.DateFromTicks(ticks+row*86400-col*1313)
+ self.check_data_integrity(
+ ('col1 DATE',),
+ generator)
+
+ def test_TIME(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.TimeFromTicks(ticks+row*86400-col*1313)
+ self.check_data_integrity(
+ ('col1 TIME',),
+ generator)
+
+ def test_DATETIME(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
+ self.check_data_integrity(
+ ('col1 DATETIME',),
+ generator)
+
+ def test_TIMESTAMP(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
+ self.check_data_integrity(
+ ('col1 TIMESTAMP',),
+ generator)
+
+ def test_fractional_TIMESTAMP(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0)
+ self.check_data_integrity(
+ ('col1 TIMESTAMP',),
+ generator)
+
+ def test_LONG(self):
+ def generator(row,col):
+ if col == 0:
+ return row
+ else:
+ return self.BLOBUText # 'BLOB Text ' * 1024
+ self.check_data_integrity(
+ ('col1 INT', 'col2 LONG'),
+ generator)
+
+ def test_TEXT(self):
+ def generator(row,col):
+ if col == 0:
+ return row
+ else:
+ return self.BLOBUText[:5192] # 'BLOB Text ' * 1024
+ self.check_data_integrity(
+ ('col1 INT', 'col2 TEXT'),
+ generator)
+
+ def test_LONG_BYTE(self):
+ def generator(row,col):
+ if col == 0:
+ return row
+ else:
+ return self.BLOBBinary # 'BLOB\000Binary ' * 1024
+ self.check_data_integrity(
+ ('col1 INT','col2 LONG BYTE'),
+ generator)
+
+ def test_BLOB(self):
+ def generator(row,col):
+ if col == 0:
+ return row
+ else:
+ return self.BLOBBinary # 'BLOB\000Binary ' * 1024
+ self.check_data_integrity(
+ ('col1 INT','col2 BLOB'),
+ generator)
+
diff --git a/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py
new file mode 100644
index 00000000000..a419e34a46c
--- /dev/null
+++ b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py
@@ -0,0 +1,853 @@
+#!/usr/bin/env python
+''' Python DB API 2.0 driver compliance unit test suite.
+
+ This software is Public Domain and may be used without restrictions.
+
+ "Now we have booze and barflies entering the discussion, plus rumours of
+ DBAs on drugs... and I won't tell you what flashes through my mind each
+ time I read the subject line with 'Anal Compliance' in it. All around
+ this is turning out to be a thoroughly unwholesome unit test."
+
+ -- Ian Bicking
+'''
+
+__rcs_id__ = '$Id$'
+__version__ = '$Revision$'[11:-2]
+__author__ = 'Stuart Bishop '
+
+import unittest
+import time
+
+# $Log$
+# Revision 1.1.2.1 2006/02/25 03:44:32 adustman
+# Generic DB-API unit test module
+#
+# Revision 1.10 2003/10/09 03:14:14 zenzen
+# Add test for DB API 2.0 optional extension, where database exceptions
+# are exposed as attributes on the Connection object.
+#
+# Revision 1.9 2003/08/13 01:16:36 zenzen
+# Minor tweak from Stefan Fleiter
+#
+# Revision 1.8 2003/04/10 00:13:25 zenzen
+# Changes, as per suggestions by M.-A. Lemburg
+# - Add a table prefix, to ensure namespace collisions can always be avoided
+#
+# Revision 1.7 2003/02/26 23:33:37 zenzen
+# Break out DDL into helper functions, as per request by David Rushby
+#
+# Revision 1.6 2003/02/21 03:04:33 zenzen
+# Stuff from Henrik Ekelund:
+# added test_None
+# added test_nextset & hooks
+#
+# Revision 1.5 2003/02/17 22:08:43 zenzen
+# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
+# defaults to 1 & generic cursor.callproc test added
+#
+# Revision 1.4 2003/02/15 00:16:33 zenzen
+# Changes, as per suggestions and bug reports by M.-A. Lemburg,
+# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
+# - Class renamed
+# - Now a subclass of TestCase, to avoid requiring the driver stub
+# to use multiple inheritance
+# - Reversed the polarity of buggy test in test_description
+# - Test exception heirarchy correctly
+# - self.populate is now self._populate(), so if a driver stub
+# overrides self.ddl1 this change propogates
+# - VARCHAR columns now have a width, which will hopefully make the
+# DDL even more portible (this will be reversed if it causes more problems)
+# - cursor.rowcount being checked after various execute and fetchXXX methods
+# - Check for fetchall and fetchmany returning empty lists after results
+# are exhausted (already checking for empty lists if select retrieved
+# nothing
+# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
+#
+
+class DatabaseAPI20Test(unittest.TestCase):
+ ''' Test a database self.driver for DB API 2.0 compatibility.
+ This implementation tests Gadfly, but the TestCase
+ is structured so that other self.drivers can subclass this
+ test case to ensure compiliance with the DB-API. It is
+ expected that this TestCase may be expanded in the future
+ if ambiguities or edge conditions are discovered.
+
+ The 'Optional Extensions' are not yet being tested.
+
+ self.drivers should subclass this test, overriding setUp, tearDown,
+ self.driver, connect_args and connect_kw_args. Class specification
+ should be as follows:
+
+ import dbapi20
+ class mytest(dbapi20.DatabaseAPI20Test):
+ [...]
+
+ Don't 'import DatabaseAPI20Test from dbapi20', or you will
+ confuse the unit tester - just 'import dbapi20'.
+ '''
+
+ # The self.driver module. This should be the module where the 'connect'
+ # method is to be found
+ driver = None
+ connect_args = () # List of arguments to pass to connect
+ connect_kw_args = {} # Keyword arguments for connect
+ table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
+
+ ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
+ ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
+ xddl1 = 'drop table %sbooze' % table_prefix
+ xddl2 = 'drop table %sbarflys' % table_prefix
+
+ lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
+
+ # Some drivers may need to override these helpers, for example adding
+ # a 'commit' after the execute.
+ def executeDDL1(self,cursor):
+ cursor.execute(self.ddl1)
+
+ def executeDDL2(self,cursor):
+ cursor.execute(self.ddl2)
+
+ def setUp(self):
+ ''' self.drivers should override this method to perform required setup
+ if any is necessary, such as creating the database.
+ '''
+ pass
+
+ def tearDown(self):
+ ''' self.drivers should override this method to perform required cleanup
+ if any is necessary, such as deleting the test database.
+ The default drops the tables that may be created.
+ '''
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ for ddl in (self.xddl1,self.xddl2):
+ try:
+ cur.execute(ddl)
+ con.commit()
+ except self.driver.Error:
+ # Assume table didn't exist. Other tests will check if
+ # execute is busted.
+ pass
+ finally:
+ con.close()
+
+ def _connect(self):
+ try:
+ return self.driver.connect(
+ *self.connect_args,**self.connect_kw_args
+ )
+ except AttributeError:
+ self.fail("No connect method found in self.driver module")
+
+ def test_connect(self):
+ con = self._connect()
+ con.close()
+
+ def test_apilevel(self):
+ try:
+ # Must exist
+ apilevel = self.driver.apilevel
+ # Must equal 2.0
+ self.assertEqual(apilevel,'2.0')
+ except AttributeError:
+ self.fail("Driver doesn't define apilevel")
+
+ def test_threadsafety(self):
+ try:
+ # Must exist
+ threadsafety = self.driver.threadsafety
+ # Must be a valid value
+ self.assertTrue(threadsafety in (0,1,2,3))
+ except AttributeError:
+ self.fail("Driver doesn't define threadsafety")
+
+ def test_paramstyle(self):
+ try:
+ # Must exist
+ paramstyle = self.driver.paramstyle
+ # Must be a valid value
+ self.assertTrue(paramstyle in (
+ 'qmark','numeric','named','format','pyformat'
+ ))
+ except AttributeError:
+ self.fail("Driver doesn't define paramstyle")
+
+ def test_Exceptions(self):
+ # Make sure required exceptions exist, and are in the
+ # defined hierarchy.
+ self.assertTrue(issubclass(self.driver.Warning,StandardError))
+ self.assertTrue(issubclass(self.driver.Error,StandardError))
+ self.assertTrue(
+ issubclass(self.driver.InterfaceError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.DatabaseError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.OperationalError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.IntegrityError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.InternalError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.ProgrammingError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.NotSupportedError,self.driver.Error)
+ )
+
+ def test_ExceptionsAsConnectionAttributes(self):
+ # OPTIONAL EXTENSION
+ # Test for the optional DB API 2.0 extension, where the exceptions
+ # are exposed as attributes on the Connection object
+ # I figure this optional extension will be implemented by any
+ # driver author who is using this test suite, so it is enabled
+ # by default.
+ con = self._connect()
+ drv = self.driver
+ self.assertTrue(con.Warning is drv.Warning)
+ self.assertTrue(con.Error is drv.Error)
+ self.assertTrue(con.InterfaceError is drv.InterfaceError)
+ self.assertTrue(con.DatabaseError is drv.DatabaseError)
+ self.assertTrue(con.OperationalError is drv.OperationalError)
+ self.assertTrue(con.IntegrityError is drv.IntegrityError)
+ self.assertTrue(con.InternalError is drv.InternalError)
+ self.assertTrue(con.ProgrammingError is drv.ProgrammingError)
+ self.assertTrue(con.NotSupportedError is drv.NotSupportedError)
+
+
+ def test_commit(self):
+ con = self._connect()
+ try:
+ # Commit must work, even if it doesn't do anything
+ con.commit()
+ finally:
+ con.close()
+
+ def test_rollback(self):
+ con = self._connect()
+ # If rollback is defined, it should either work or throw
+ # the documented exception
+ if hasattr(con,'rollback'):
+ try:
+ con.rollback()
+ except self.driver.NotSupportedError:
+ pass
+
+ def test_cursor(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ finally:
+ con.close()
+
+ def test_cursor_isolation(self):
+ con = self._connect()
+ try:
+ # Make sure cursors created from the same connection have
+ # the documented transaction isolation level
+ cur1 = con.cursor()
+ cur2 = con.cursor()
+ self.executeDDL1(cur1)
+ cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ cur2.execute("select name from %sbooze" % self.table_prefix)
+ booze = cur2.fetchall()
+ self.assertEqual(len(booze),1)
+ self.assertEqual(len(booze[0]),1)
+ self.assertEqual(booze[0][0],'Victoria Bitter')
+ finally:
+ con.close()
+
+ def test_description(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ self.assertEqual(cur.description,None,
+ 'cursor.description should be none after executing a '
+ 'statement that can return no rows (such as DDL)'
+ )
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(len(cur.description),1,
+ 'cursor.description describes too many columns'
+ )
+ self.assertEqual(len(cur.description[0]),7,
+ 'cursor.description[x] tuples must have 7 elements'
+ )
+ self.assertEqual(cur.description[0][0].lower(),'name',
+ 'cursor.description[x][0] must return column name'
+ )
+ self.assertEqual(cur.description[0][1],self.driver.STRING,
+ 'cursor.description[x][1] must return column type. Got %r'
+ % cur.description[0][1]
+ )
+
+ # Make sure self.description gets reset
+ self.executeDDL2(cur)
+ self.assertEqual(cur.description,None,
+ 'cursor.description not being set to None when executing '
+ 'no-result statements (eg. DDL)'
+ )
+ finally:
+ con.close()
+
+ def test_rowcount(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ self.assertEqual(cur.rowcount,-1,
+ 'cursor.rowcount should be -1 after executing no-result '
+ 'statements'
+ )
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.assertTrue(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number or rows inserted, or '
+ 'set to -1 after executing an insert statement'
+ )
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ self.assertTrue(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number of rows returned, or '
+ 'set to -1 after executing a select statement'
+ )
+ self.executeDDL2(cur)
+ self.assertEqual(cur.rowcount,-1,
+ 'cursor.rowcount not being reset to -1 after executing '
+ 'no-result statements'
+ )
+ finally:
+ con.close()
+
+ lower_func = 'lower'
+ def test_callproc(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if self.lower_func and hasattr(cur,'callproc'):
+ r = cur.callproc(self.lower_func,('FOO',))
+ self.assertEqual(len(r),1)
+ self.assertEqual(r[0],'FOO')
+ r = cur.fetchall()
+ self.assertEqual(len(r),1,'callproc produced no result set')
+ self.assertEqual(len(r[0]),1,
+ 'callproc produced invalid result set'
+ )
+ self.assertEqual(r[0][0],'foo',
+ 'callproc produced invalid results'
+ )
+ finally:
+ con.close()
+
+ def test_close(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ finally:
+ con.close()
+
+ # cursor.execute should raise an Error if called after connection
+ # closed
+ self.assertRaises(self.driver.Error,self.executeDDL1,cur)
+
+ # connection.commit should raise an Error if called after connection'
+ # closed.'
+ self.assertRaises(self.driver.Error,con.commit)
+
+ # connection.close should raise an Error if called more than once
+ self.assertRaises(self.driver.Error,con.close)
+
+ def test_execute(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self._paraminsert(cur)
+ finally:
+ con.close()
+
+ def _paraminsert(self,cur):
+ self.executeDDL1(cur)
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.assertTrue(cur.rowcount in (-1,1))
+
+ if self.driver.paramstyle == 'qmark':
+ cur.execute(
+ 'insert into %sbooze values (?)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'numeric':
+ cur.execute(
+ 'insert into %sbooze values (:1)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'named':
+ cur.execute(
+ 'insert into %sbooze values (:beer)' % self.table_prefix,
+ {'beer':"Cooper's"}
+ )
+ elif self.driver.paramstyle == 'format':
+ cur.execute(
+ 'insert into %sbooze values (%%s)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'pyformat':
+ cur.execute(
+ 'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
+ {'beer':"Cooper's"}
+ )
+ else:
+ self.fail('Invalid paramstyle')
+ self.assertTrue(cur.rowcount in (-1,1))
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ res = cur.fetchall()
+ self.assertEqual(len(res),2,'cursor.fetchall returned too few rows')
+ beers = [res[0][0],res[1][0]]
+ beers.sort()
+ self.assertEqual(beers[0],"Cooper's",
+ 'cursor.fetchall retrieved incorrect data, or data inserted '
+ 'incorrectly'
+ )
+ self.assertEqual(beers[1],"Victoria Bitter",
+ 'cursor.fetchall retrieved incorrect data, or data inserted '
+ 'incorrectly'
+ )
+
+ def test_executemany(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ largs = [ ("Cooper's",) , ("Boag's",) ]
+ margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ]
+ if self.driver.paramstyle == 'qmark':
+ cur.executemany(
+ 'insert into %sbooze values (?)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'numeric':
+ cur.executemany(
+ 'insert into %sbooze values (:1)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'named':
+ cur.executemany(
+ 'insert into %sbooze values (:beer)' % self.table_prefix,
+ margs
+ )
+ elif self.driver.paramstyle == 'format':
+ cur.executemany(
+ 'insert into %sbooze values (%%s)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'pyformat':
+ cur.executemany(
+ 'insert into %sbooze values (%%(beer)s)' % (
+ self.table_prefix
+ ),
+ margs
+ )
+ else:
+ self.fail('Unknown paramstyle')
+ self.assertTrue(cur.rowcount in (-1,2),
+ 'insert using cursor.executemany set cursor.rowcount to '
+ 'incorrect value %r' % cur.rowcount
+ )
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ res = cur.fetchall()
+ self.assertEqual(len(res),2,
+ 'cursor.fetchall retrieved incorrect number of rows'
+ )
+ beers = [res[0][0],res[1][0]]
+ beers.sort()
+ self.assertEqual(beers[0],"Boag's",'incorrect data retrieved')
+ self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved')
+ finally:
+ con.close()
+
+ def test_fetchone(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchone should raise an Error if called before
+ # executing a select-type query
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ self.executeDDL1(cur)
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if a query retrieves '
+ 'no rows'
+ )
+ self.assertTrue(cur.rowcount in (-1,0))
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchone()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchone should have retrieved a single row'
+ )
+ self.assertEqual(r[0],'Victoria Bitter',
+ 'cursor.fetchone retrieved incorrect data'
+ )
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if no more rows available'
+ )
+ self.assertTrue(cur.rowcount in (-1,1))
+ finally:
+ con.close()
+
+ samples = [
+ 'Carlton Cold',
+ 'Carlton Draft',
+ 'Mountain Goat',
+ 'Redback',
+ 'Victoria Bitter',
+ 'XXXX'
+ ]
+
+ def _populate(self):
+ ''' Return a list of sql commands to setup the DB for the fetch
+ tests.
+ '''
+ populate = [
+ "insert into %sbooze values ('%s')" % (self.table_prefix,s)
+ for s in self.samples
+ ]
+ return populate
+
+ def test_fetchmany(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchmany should raise an Error if called without
+ # issuing a query
+ self.assertRaises(self.driver.Error,cur.fetchmany,4)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchmany retrieved incorrect number of rows, '
+ 'default of arraysize is one.'
+ )
+ cur.arraysize=10
+ r = cur.fetchmany(3) # Should get 3 rows
+ self.assertEqual(len(r),3,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should get 2 more
+ self.assertEqual(len(r),2,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should be an empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence after '
+ 'results are exhausted'
+ )
+ self.assertTrue(cur.rowcount in (-1,6))
+
+ # Same as above, using cursor.arraysize
+ cur.arraysize=4
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany() # Should get 4 rows
+ self.assertEqual(len(r),4,
+ 'cursor.arraysize not being honoured by fetchmany'
+ )
+ r = cur.fetchmany() # Should get 2 more
+ self.assertEqual(len(r),2)
+ r = cur.fetchmany() # Should be an empty sequence
+ self.assertEqual(len(r),0)
+ self.assertTrue(cur.rowcount in (-1,6))
+
+ cur.arraysize=6
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchmany() # Should get all rows
+ self.assertTrue(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows),6)
+ self.assertEqual(len(rows),6)
+ rows = [r[0] for r in rows]
+ rows.sort()
+
+ # Make sure we get the right data back out
+ for i in range(0,6):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved by cursor.fetchmany'
+ )
+
+ rows = cur.fetchmany() # Should return an empty list
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'called after the whole result set has been fetched'
+ )
+ self.assertTrue(cur.rowcount in (-1,6))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ r = cur.fetchmany() # Should get empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'query retrieved no rows'
+ )
+ self.assertTrue(cur.rowcount in (-1,0))
+
+ finally:
+ con.close()
+
+ def test_fetchall(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ # cursor.fetchall should raise an Error if called
+ # without executing a query that may return rows (such
+ # as a select)
+ self.assertRaises(self.driver.Error, cur.fetchall)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ # cursor.fetchall should raise an Error if called
+ # after executing a statement that cannot return rows
+ self.assertRaises(self.driver.Error,cur.fetchall)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+ self.assertEqual(len(rows),len(self.samples),
+ 'cursor.fetchall did not retrieve all rows'
+ )
+ rows = [r[0] for r in rows]
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'cursor.fetchall retrieved incorrect rows'
+ )
+ rows = cur.fetchall()
+ self.assertEqual(
+ len(rows),0,
+ 'cursor.fetchall should return an empty list if called '
+ 'after the whole result set has been fetched'
+ )
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,0))
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchall should return an empty list if '
+ 'a select query returns no rows'
+ )
+
+ finally:
+ con.close()
+
+ def test_mixedfetch(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows1 = cur.fetchone()
+ rows23 = cur.fetchmany(2)
+ rows4 = cur.fetchone()
+ rows56 = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows23),2,
+ 'fetchmany returned incorrect number of rows'
+ )
+ self.assertEqual(len(rows56),2,
+ 'fetchall returned incorrect number of rows'
+ )
+
+ rows = [rows1[0]]
+ rows.extend([rows23[0][0],rows23[1][0]])
+ rows.append(rows4[0])
+ rows.extend([rows56[0][0],rows56[1][0]])
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved or inserted'
+ )
+ finally:
+ con.close()
+
+ def help_nextset_setUp(self,cur):
+ ''' Should create a procedure called deleteme
+ that returns two result sets, first the
+ number of rows in booze then "name from booze"
+ '''
+ raise NotImplementedError,'Helper not implemented'
+ #sql="""
+ # create procedure deleteme as
+ # begin
+ # select count(*) from booze
+ # select name from booze
+ # end
+ #"""
+ #cur.execute(sql)
+
+ def help_nextset_tearDown(self,cur):
+ 'If cleaning up is needed after nextSetTest'
+ raise NotImplementedError,'Helper not implemented'
+ #cur.execute("drop procedure deleteme")
+
+ def test_nextset(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if not hasattr(cur,'nextset'):
+ return
+
+ try:
+ self.executeDDL1(cur)
+ sql=self._populate()
+ for sql in self._populate():
+ cur.execute(sql)
+
+ self.help_nextset_setUp(cur)
+
+ cur.callproc('deleteme')
+ numberofrows=cur.fetchone()
+ assert numberofrows[0]== len(self.samples)
+ assert cur.nextset()
+ names=cur.fetchall()
+ assert len(names) == len(self.samples)
+ s=cur.nextset()
+ assert s == None,'No more return sets, should return None'
+ finally:
+ self.help_nextset_tearDown(cur)
+
+ finally:
+ con.close()
+
+ def test_nextset(self):
+ raise NotImplementedError,'Drivers need to override this test'
+
+ def test_arraysize(self):
+ # Not much here - rest of the tests for this are in test_fetchmany
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.assertTrue(hasattr(cur,'arraysize'),
+ 'cursor.arraysize must be defined'
+ )
+ finally:
+ con.close()
+
+ def test_setinputsizes(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setinputsizes( (25,) )
+ self._paraminsert(cur) # Make sure cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize_basic(self):
+ # Basic test is to make sure setoutputsize doesn't blow up
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setoutputsize(1000)
+ cur.setoutputsize(2000,0)
+ self._paraminsert(cur) # Make sure the cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize(self):
+ # Real test for setoutputsize is driver dependent
+ raise NotImplementedError,'Driver need to override this test'
+
+ def test_None(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchall()
+ self.assertEqual(len(r),1)
+ self.assertEqual(len(r[0]),1)
+ self.assertEqual(r[0][0],None,'NULL value not returned as None')
+ finally:
+ con.close()
+
+ def test_Date(self):
+ d1 = self.driver.Date(2002,12,25)
+ d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(d1),str(d2))
+
+ def test_Time(self):
+ t1 = self.driver.Time(13,45,30)
+ t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(t1),str(t2))
+
+ def test_Timestamp(self):
+ t1 = self.driver.Timestamp(2002,12,25,13,45,30)
+ t2 = self.driver.TimestampFromTicks(
+ time.mktime((2002,12,25,13,45,30,0,0,0))
+ )
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(t1),str(t2))
+
+ def test_Binary(self):
+ b = self.driver.Binary('Something')
+ b = self.driver.Binary('')
+
+ def test_STRING(self):
+ self.assertTrue(hasattr(self.driver,'STRING'),
+ 'module.STRING must be defined'
+ )
+
+ def test_BINARY(self):
+ self.assertTrue(hasattr(self.driver,'BINARY'),
+ 'module.BINARY must be defined.'
+ )
+
+ def test_NUMBER(self):
+ self.assertTrue(hasattr(self.driver,'NUMBER'),
+ 'module.NUMBER must be defined.'
+ )
+
+ def test_DATETIME(self):
+ self.assertTrue(hasattr(self.driver,'DATETIME'),
+ 'module.DATETIME must be defined.'
+ )
+
+ def test_ROWID(self):
+ self.assertTrue(hasattr(self.driver,'ROWID'),
+ 'module.ROWID must be defined.'
+ )
+
diff --git a/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py
new file mode 100644
index 00000000000..e0bc93439c2
--- /dev/null
+++ b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+import capabilities
+import unittest
+import pymysql
+from pymysql.tests import base
+import warnings
+
+warnings.filterwarnings('error')
+
+class test_MySQLdb(capabilities.DatabaseTest):
+
+ db_module = pymysql
+ connect_args = ()
+ connect_kwargs = base.PyMySQLTestCase.databases[0].copy()
+ connect_kwargs.update(dict(read_default_file='~/.my.cnf',
+ use_unicode=True,
+ charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL"))
+
+ create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
+ leak_test = False
+
+ def quote_identifier(self, ident):
+ return "`%s`" % ident
+
+ def test_TIME(self):
+ from datetime import timedelta
+ def generator(row,col):
+ return timedelta(0, row*8000)
+ self.check_data_integrity(
+ ('col1 TIME',),
+ generator)
+
+ def test_TINYINT(self):
+ # Number data
+ def generator(row,col):
+ v = (row*row) % 256
+ if v > 127:
+ v = v-256
+ return v
+ self.check_data_integrity(
+ ('col1 TINYINT',),
+ generator)
+
+ def test_stored_procedures(self):
+ db = self.connection
+ c = self.cursor
+ try:
+ self.create_table(('pos INT', 'tree CHAR(20)'))
+ c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table,
+ list(enumerate('ash birch cedar larch pine'.split())))
+ db.commit()
+
+ c.execute("""
+ CREATE PROCEDURE test_sp(IN t VARCHAR(255))
+ BEGIN
+ SELECT pos FROM %s WHERE tree = t;
+ END
+ """ % self.table)
+ db.commit()
+
+ c.callproc('test_sp', ('larch',))
+ rows = c.fetchall()
+ self.assertEquals(len(rows), 1)
+ self.assertEquals(rows[0][0], 3)
+ c.nextset()
+ finally:
+ c.execute("DROP PROCEDURE IF EXISTS test_sp")
+ c.execute('drop table %s' % (self.table))
+
+ def test_small_CHAR(self):
+ # Character data
+ def generator(row,col):
+ i = ((row+1)*(col+1)+62)%256
+ if i == 62: return ''
+ if i == 63: return None
+ return chr(i)
+ self.check_data_integrity(
+ ('col1 char(1)','col2 char(1)'),
+ generator)
+
+ def test_bug_2671682(self):
+ from pymysql.constants import ER
+ try:
+ self.cursor.execute("describe some_non_existent_table");
+ except self.connection.ProgrammingError, msg:
+ self.assertTrue(msg.args[0] == ER.NO_SUCH_TABLE)
+
+ def test_insert_values(self):
+ from pymysql.cursors import insert_values
+ query = """INSERT FOO (a, b, c) VALUES (a, b, c)"""
+ matched = insert_values.search(query)
+ self.assertTrue(matched)
+ values = matched.group(1)
+ self.assertTrue(values == "(a, b, c)")
+
+ def test_ping(self):
+ self.connection.ping()
+
+ def test_literal_int(self):
+ self.assertTrue("2" == self.connection.literal(2))
+
+ def test_literal_float(self):
+ self.assertTrue("3.1415" == self.connection.literal(3.1415))
+
+ def test_literal_string(self):
+ self.assertTrue("'foo'" == self.connection.literal("foo"))
+
+
+if __name__ == '__main__':
+ if test_MySQLdb.leak_test:
+ import gc
+ gc.enable()
+ gc.set_debug(gc.DEBUG_LEAK)
+ unittest.main()
+
diff --git a/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py
new file mode 100644
index 00000000000..83c002fdf39
--- /dev/null
+++ b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python
+import dbapi20
+import unittest
+import pymysql
+from pymysql.tests import base
+
+class test_MySQLdb(dbapi20.DatabaseAPI20Test):
+ driver = pymysql
+ connect_args = ()
+ connect_kw_args = base.PyMySQLTestCase.databases[0].copy()
+ connect_kw_args.update(dict(read_default_file='~/.my.cnf',
+ charset='utf8',
+ sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL"))
+
+ def test_setoutputsize(self): pass
+ def test_setoutputsize_basic(self): pass
+ def test_nextset(self): pass
+
+ """The tests on fetchone and fetchall and rowcount bogusly
+ test for an exception if the statement cannot return a
+ result set. MySQL always returns a result set; it's just that
+ some things return empty result sets."""
+
+ def test_fetchall(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ # cursor.fetchall should raise an Error if called
+ # without executing a query that may return rows (such
+ # as a select)
+ self.assertRaises(self.driver.Error, cur.fetchall)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ # cursor.fetchall should raise an Error if called
+ # after executing a statement that cannot return rows
+## self.assertRaises(self.driver.Error,cur.fetchall)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+ self.assertEqual(len(rows),len(self.samples),
+ 'cursor.fetchall did not retrieve all rows'
+ )
+ rows = [r[0] for r in rows]
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'cursor.fetchall retrieved incorrect rows'
+ )
+ rows = cur.fetchall()
+ self.assertEqual(
+ len(rows),0,
+ 'cursor.fetchall should return an empty list if called '
+ 'after the whole result set has been fetched'
+ )
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,0))
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchall should return an empty list if '
+ 'a select query returns no rows'
+ )
+
+ finally:
+ con.close()
+
+ def test_fetchone(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchone should raise an Error if called before
+ # executing a select-type query
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ self.executeDDL1(cur)
+## self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if a query retrieves '
+ 'no rows'
+ )
+ self.assertTrue(cur.rowcount in (-1,0))
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+## self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchone()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchone should have retrieved a single row'
+ )
+ self.assertEqual(r[0],'Victoria Bitter',
+ 'cursor.fetchone retrieved incorrect data'
+ )
+## self.assertEqual(cur.fetchone(),None,
+## 'cursor.fetchone should return None if no more rows available'
+## )
+ self.assertTrue(cur.rowcount in (-1,1))
+ finally:
+ con.close()
+
+ # Same complaint as for fetchall and fetchone
+ def test_rowcount(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+## self.assertEqual(cur.rowcount,-1,
+## 'cursor.rowcount should be -1 after executing no-result '
+## 'statements'
+## )
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+## self.assertTrue(cur.rowcount in (-1,1),
+## 'cursor.rowcount should == number of rows inserted, or '
+## 'set to -1 after executing an insert statement'
+## )
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ self.assertTrue(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number of rows returned, or '
+ 'set to -1 after executing a select statement'
+ )
+ self.executeDDL2(cur)
+## self.assertEqual(cur.rowcount,-1,
+## 'cursor.rowcount not being reset to -1 after executing '
+## 'no-result statements'
+## )
+ finally:
+ con.close()
+
+ def test_callproc(self):
+ pass # performed in test_MySQL_capabilities
+
+ def help_nextset_setUp(self,cur):
+ ''' Should create a procedure called deleteme
+ that returns two result sets, first the
+ number of rows in booze then "name from booze"
+ '''
+ sql="""
+ create procedure deleteme()
+ begin
+ select count(*) from %(tp)sbooze;
+ select name from %(tp)sbooze;
+ end
+ """ % dict(tp=self.table_prefix)
+ cur.execute(sql)
+
+ def help_nextset_tearDown(self,cur):
+ 'If cleaning up is needed after nextSetTest'
+ cur.execute("drop procedure deleteme")
+
+ def test_nextset(self):
+ from warnings import warn
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if not hasattr(cur,'nextset'):
+ return
+
+ try:
+ self.executeDDL1(cur)
+ sql=self._populate()
+ for sql in self._populate():
+ cur.execute(sql)
+
+ self.help_nextset_setUp(cur)
+
+ cur.callproc('deleteme')
+ numberofrows=cur.fetchone()
+ assert numberofrows[0]== len(self.samples)
+ assert cur.nextset()
+ names=cur.fetchall()
+ assert len(names) == len(self.samples)
+ s=cur.nextset()
+ if s:
+ empty = cur.fetchall()
+ self.assertEquals(len(empty), 0,
+ "non-empty result set after other result sets")
+ #warn("Incompatibility: MySQL returns an empty result set for the CALL itself",
+ # Warning)
+ #assert s == None,'No more return sets, should return None'
+ finally:
+ self.help_nextset_tearDown(cur)
+
+ finally:
+ con.close()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py
new file mode 100644
index 00000000000..f49369cb4f7
--- /dev/null
+++ b/tools/testClient/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py
@@ -0,0 +1,90 @@
+import unittest
+
+import pymysql
+_mysql = pymysql
+from pymysql.constants import FIELD_TYPE
+from pymysql.tests import base
+
+
+class TestDBAPISet(unittest.TestCase):
+ def test_set_equality(self):
+ self.assertTrue(pymysql.STRING == pymysql.STRING)
+
+ def test_set_inequality(self):
+ self.assertTrue(pymysql.STRING != pymysql.NUMBER)
+
+ def test_set_equality_membership(self):
+ self.assertTrue(FIELD_TYPE.VAR_STRING == pymysql.STRING)
+
+ def test_set_inequality_membership(self):
+ self.assertTrue(FIELD_TYPE.DATE != pymysql.STRING)
+
+
+class CoreModule(unittest.TestCase):
+ """Core _mysql module features."""
+
+ def test_NULL(self):
+ """Should have a NULL constant."""
+ self.assertEqual(_mysql.NULL, 'NULL')
+
+ def test_version(self):
+ """Version information sanity."""
+ self.assertTrue(isinstance(_mysql.__version__, str))
+
+ self.assertTrue(isinstance(_mysql.version_info, tuple))
+ self.assertEqual(len(_mysql.version_info), 5)
+
+ def test_client_info(self):
+ self.assertTrue(isinstance(_mysql.get_client_info(), str))
+
+ def test_thread_safe(self):
+ self.assertTrue(isinstance(_mysql.thread_safe(), int))
+
+
+class CoreAPI(unittest.TestCase):
+ """Test _mysql interaction internals."""
+
+ def setUp(self):
+ kwargs = base.PyMySQLTestCase.databases[0].copy()
+ kwargs["read_default_file"] = "~/.my.cnf"
+ self.conn = _mysql.connect(**kwargs)
+
+ def tearDown(self):
+ self.conn.close()
+
+ def test_thread_id(self):
+ tid = self.conn.thread_id()
+ self.assertTrue(isinstance(tid, int),
+ "thread_id didn't return an int.")
+
+ self.assertRaises(TypeError, self.conn.thread_id, ('evil',),
+ "thread_id shouldn't accept arguments.")
+
+ def test_affected_rows(self):
+ self.assertEquals(self.conn.affected_rows(), 0,
+ "Should return 0 before we do anything.")
+
+
+ #def test_debug(self):
+ ## FIXME Only actually tests if you lack SUPER
+ #self.assertRaises(pymysql.OperationalError,
+ #self.conn.dump_debug_info)
+
+ def test_charset_name(self):
+ self.assertTrue(isinstance(self.conn.character_set_name(), str),
+ "Should return a string.")
+
+ def test_host_info(self):
+ self.assertTrue(isinstance(self.conn.get_host_info(), str),
+ "Should return a string.")
+
+ def test_proto_info(self):
+ self.assertTrue(isinstance(self.conn.get_proto_info(), int),
+ "Should return an int.")
+
+ def test_server_info(self):
+ self.assertTrue(isinstance(self.conn.get_server_info(), basestring),
+ "Should return an str.")
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tools/testClient/pymysql/times.py b/tools/testClient/pymysql/times.py
new file mode 100644
index 00000000000..c47db09eb9c
--- /dev/null
+++ b/tools/testClient/pymysql/times.py
@@ -0,0 +1,16 @@
+from time import localtime
+from datetime import date, datetime, time, timedelta
+
+Date = date
+Time = time
+TimeDelta = timedelta
+Timestamp = datetime
+
+def DateFromTicks(ticks):
+ return date(*localtime(ticks)[:3])
+
+def TimeFromTicks(ticks):
+ return time(*localtime(ticks)[3:6])
+
+def TimestampFromTicks(ticks):
+ return datetime(*localtime(ticks)[:6])
diff --git a/tools/testClient/pymysql/util.py b/tools/testClient/pymysql/util.py
new file mode 100644
index 00000000000..cc622e57b74
--- /dev/null
+++ b/tools/testClient/pymysql/util.py
@@ -0,0 +1,19 @@
+import struct
+
+def byte2int(b):
+ if isinstance(b, int):
+ return b
+ else:
+ return struct.unpack("!B", b)[0]
+
+def int2byte(i):
+ return struct.pack("!B", i)
+
+def join_bytes(bs):
+ if len(bs) == 0:
+ return ""
+ else:
+ rv = bs[0]
+ for b in bs[1:]:
+ rv += b
+ return rv
diff --git a/tools/testClient/remoteSSHClient.py b/tools/testClient/remoteSSHClient.py
new file mode 100644
index 00000000000..86a7e5c3747
--- /dev/null
+++ b/tools/testClient/remoteSSHClient.py
@@ -0,0 +1,36 @@
+import paramiko
+import cloudstackException
+class remoteSSHClient(object):
+ def __init__(self, host, port, user, passwd):
+ self.host = host
+ self.port = port
+ self.user = user
+ self.passwd = passwd
+ self.ssh = paramiko.SSHClient()
+ self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ try:
+ self.ssh.connect(str(host),int(port), user, passwd)
+ except paramiko.SSHException, sshex:
+ raise cloudstackException.InvalidParameterException(repr(sshex))
+
+ def execute(self, command):
+ stdin, stdout, stderr = self.ssh.exec_command(command)
+ output = stdout.readlines()
+ errors = stderr.readlines()
+ results = []
+ if output is not None and len(output) == 0:
+ if errors is not None and len(errors) > 0:
+ for error in errors:
+ results.append(error.rstrip())
+
+ else:
+ for strOut in output:
+ results.append(strOut.rstrip())
+
+ return results
+
+
+if __name__ == "__main__":
+ ssh = remoteSSHClient("192.168.137.2", 22, "root", "password")
+ print ssh.execute("ls -l")
+ print ssh.execute("rm x")
\ No newline at end of file
diff --git a/tools/testClient/unitTest/test_advanced.py b/tools/testClient/unitTest/test_advanced.py
new file mode 100644
index 00000000000..c8992bbe893
--- /dev/null
+++ b/tools/testClient/unitTest/test_advanced.py
@@ -0,0 +1,95 @@
+import xml.dom.minidom
+import inspect
+import uuid
+import time
+from cloudstackAPI import *
+import cloudstackTestClient
+if __name__ == "__main__":
+ randomName = str(uuid.uuid4())
+ testClient = cloudstackTestClient.cloudstackTestClient("localhost")
+ api = testClient.getApiClient()
+
+ czcmd = createZone.createZoneCmd()
+ czcmd.dns1 = "8.8.8.8"
+ czcmd.internaldns1 = "192.168.110.254"
+ czcmd.name = "test" + randomName
+ czcmd.networktype = "Advanced"
+ czcmd.guestcidraddress = "10.1.1.0/24"
+ czcmd.vlan = "1160-1200"
+ czresponse = api.createZone(czcmd)
+ zoneId = czresponse.id
+
+ cvlancmd = createVlanIpRange.createVlanIpRangeCmd()
+ cvlancmd.zoneid = zoneId
+ cvlancmd.vlan = "2020"
+ cvlancmd.gateway = "172.16.112.1"
+ cvlancmd.netmask = "255.255.0.0"
+ cvlancmd.startip = "172.16.112.2"
+ cvlancmd.endip = "172.16.112.100"
+ api.createVlanIpRange(cvlancmd)
+
+ cpodcmd = createPod.createPodCmd()
+ cpodcmd.zoneid = zoneId
+ cpodcmd.gateway = "192.168.137.1"
+ cpodcmd.name = "testpod"+ randomName
+ cpodcmd.netmask = "255.255.255.0"
+ cpodcmd.startip = "192.168.137.200"
+ cpodcmd.endip = "192.168.137.220"
+ cpodresponse = api.createPod(cpodcmd)
+ podId = cpodresponse.id
+
+ aclustercmd = addCluster.addClusterCmd()
+ aclustercmd.clustername = "testcluster"+ randomName
+ aclustercmd.hypervisor = "KVM"
+ aclustercmd.podid = podId
+ aclustercmd.zoneid = zoneId
+ aclustercmd.clustertype = "CloudManaged"
+ clusterresponse = api.addCluster(aclustercmd)
+ clusterId = clusterresponse[0].id
+ '''
+ for i in range(1):
+ addhostcmd = addHost.addHostCmd()
+ addhostcmd.zoneid = zoneId
+ addhostcmd.podid = podId
+ addhostcmd.clusterid = clusterId
+ addhostcmd.hypervisor = "KVM"
+ addhostcmd.username = "root"
+ addhostcmd.password = "password"
+ addhostcmd.url = "http://192.168.137.2"
+ addhostresponse = api.addHost(addhostcmd)
+ print addhostresponse[0].id, addhostresponse[0].ipaddress
+ '''
+ createspcmd = createStoragePool.createStoragePoolCmd()
+ createspcmd.zoneid = zoneId
+ createspcmd.podid = podId
+ createspcmd.clusterid = clusterId
+ createspcmd.url = "nfs://nfs2.lab.vmops.com/export/home/edison/kvm2"
+ createspcmd.name = "storage pool" + randomName
+ createspresponse = api.createStoragePool(createspcmd)
+
+ addsscmd = addSecondaryStorage.addSecondaryStorageCmd()
+ addsscmd.url = "nfs://nfs2.lab.vmops.com/export/home/edison/xen/secondary"
+ addsscmd.zoneid = zoneId
+ api.addSecondaryStorage(addsscmd)
+
+ listtmcmd = listTemplates.listTemplatesCmd()
+ listtmcmd.id = 4
+ listtmcmd.zoneid = zoneId
+ listtmcmd.templatefilter = "featured"
+ listtmresponse = api.listTemplates(listtmcmd)
+ while True:
+ if listtmresponse is not None and listtmresponse[0].isready == "true":
+ break
+ time.sleep(30)
+ listtmresponse = api.listTemplates(listtmcmd)
+
+ vmId = []
+ for i in range(2):
+ cmd = deployVirtualMachine.deployVirtualMachineCmd()
+ cmd.zoneid = zoneId
+ cmd.hypervisor = "KVM"
+ cmd.serviceofferingid = "1"
+ cmd.templateid = listtmresponse[0].id
+ res = api.deployVirtualMachine(cmd)
+
+ vmId.append(res.id)
\ No newline at end of file
diff --git a/tools/testClient/unitTest/test_async.py b/tools/testClient/unitTest/test_async.py
new file mode 100644
index 00000000000..96018f6bc3e
--- /dev/null
+++ b/tools/testClient/unitTest/test_async.py
@@ -0,0 +1,49 @@
+from cloudstackAPI import *
+import cloudstackException
+import cloudstackTestClient
+import sys
+
+
+class jobs():
+ def __init__(self, zoneId):
+ self.zoneId = zoneId
+
+ def run(self):
+ try:
+ cmd = destroyVirtualMachine.destroyVirtualMachineCmd()
+ cmd.id = 4
+ self.apiClient.destroyVirtualMachine(cmd)
+ except cloudstackException.cloudstackAPIException, e:
+ print str(e)
+ except :
+ print sys.exc_info()
+
+if __name__ == "__main__":
+ ''' To enable logging for the test client:
+ logger = logging.getLogger("test_async")
+ fh = logging.FileHandler("test.log")
+ logger.addHandler(fh)
+ logger.setLevel(logging.DEBUG)
+ testclient = cloudstackTestClient.cloudstackTestClient(mgtSvr="localhost", logging=logger)
+ '''
+ testclient = cloudstackTestClient.cloudstackTestClient(mgtSvr="localhost")
+ testclient.dbConfigure()
+ api = testclient.getApiClient()
+
+ #testclient.submitJobs(jobs(1), 10, 10, 1)
+
+ cmds = []
+ for i in range(20):
+ cmd = destroyVirtualMachine.destroyVirtualMachineCmd()
+ cmd.id = 4 + i
+ cmds.append(cmd)
+
+ asyncJobResult = testclient.submitCmdsAndWait(cmds)
+
+ for handle, jobStatus in asyncJobResult.iteritems():
+ if jobStatus.status:
+ print jobStatus.result.id, jobStatus.result.templatename, jobStatus.startTime, jobStatus.endTime
+ else:
+ print jobStatus.result, jobStatus.startTime, jobStatus.endTime
+
+ print jobStatus.duration
\ No newline at end of file
diff --git a/tools/testClient/unitTest/test_basic_zone.py b/tools/testClient/unitTest/test_basic_zone.py
new file mode 100644
index 00000000000..d53daf2cc70
--- /dev/null
+++ b/tools/testClient/unitTest/test_basic_zone.py
@@ -0,0 +1,215 @@
+import uuid
+from cloudstackAPI import *
+import cloudstackException
+import cloudstackTestClient
+import time
+if __name__ == "__main__":
+ hypervisor = "KVM"
+ hostNum = 30
+ templateId = 10
+ vmNum = 10
+
+ randomName = str(uuid.uuid4())
+
+ testClient = cloudstackTestClient.cloudstackTestClient("localhost")
+ api = testClient.getApiClient()
+
+
+ czcmd = createZone.createZoneCmd()
+ czcmd.dns1 = "8.8.8.8"
+ czcmd.internaldns1 = "192.168.110.254"
+ czcmd.name = "test" + randomName
+ czcmd.networktype = "Basic"
+
+ czresponse = api.createZone(czcmd)
+ zoneId = czresponse.id
+
+ cpodcmd = createPod.createPodCmd()
+ cpodcmd.zoneid = zoneId
+ cpodcmd.gateway = "192.168.137.1"
+ cpodcmd.name = "testpod"+ randomName
+ cpodcmd.netmask = "255.255.255.0"
+ cpodcmd.startip = "192.168.137.200"
+ cpodcmd.endip = "192.168.137.230"
+ cpodresponse = api.createPod(cpodcmd)
+ podId = cpodresponse.id
+
+ cvlancmd = createVlanIpRange.createVlanIpRangeCmd()
+ cvlancmd.zoneid = zoneId
+ cvlancmd.podid = podId
+ cvlancmd.gateway = "192.168.137.1"
+ cvlancmd.netmask = "255.255.255.0"
+ cvlancmd.startip = "192.168.137.100"
+ cvlancmd.endip = "192.168.137.190"
+ cvlancmd.forvirtualnetwork = "false"
+ cvlancmd.vlan = "untagged"
+
+ api.createVlanIpRange(cvlancmd)
+
+ aclustercmd = addCluster.addClusterCmd()
+ aclustercmd.clustername = "testcluster"+ randomName
+ aclustercmd.hypervisor = hypervisor
+ aclustercmd.podid = podId
+ aclustercmd.zoneid = zoneId
+ aclustercmd.clustertype = "CloudManaged"
+ clusterresponse = api.addCluster(aclustercmd)
+ clusterId = clusterresponse[0].id
+ '''
+ for i in range(hostNum):
+ addhostcmd = addHost.addHostCmd()
+ addhostcmd.zoneid = zoneId
+ addhostcmd.podid = podId
+ addhostcmd.clusterid = clusterId
+ addhostcmd.hypervisor = hypervisor
+ addhostcmd.username = "root"
+ addhostcmd.password = "password"
+ if hypervisor == "Simulator":
+ addhostcmd.url = "http://sim"
+ else:
+ addhostcmd.url = "http://192.168.137.4"
+ addhostresponse = api.addHost(addhostcmd)
+ print addhostresponse[0].id, addhostresponse[0].ipaddress
+
+ createspcmd = createStoragePool.createStoragePoolCmd()
+ createspcmd.zoneid = zoneId
+ createspcmd.podid = podId
+ createspcmd.clusterid = clusterId
+ createspcmd.url = "nfs://nfs2.lab.vmops.com/export/home/edison/primary"
+ createspcmd.name = "storage pool" + randomName
+ createspresponse = api.createStoragePool(createspcmd)
+
+ addsscmd = addSecondaryStorage.addSecondaryStorageCmd()
+ addsscmd.url = "nfs://nfs2.lab.vmops.com/export/home/edison/xen/secondary"
+ addsscmd.zoneid = zoneId
+ api.addSecondaryStorage(addsscmd)
+
+ listtmcmd = listTemplates.listTemplatesCmd()
+ listtmcmd.id = templateId
+ listtmcmd.zoneid = zoneId
+ listtmcmd.templatefilter = "featured"
+ listtmresponse = api.listTemplates(listtmcmd)
+ while True:
+ if listtmresponse is not None and listtmresponse[0].isready == "true":
+ break
+ time.sleep(30)
+ listtmresponse = api.listTemplates(listtmcmd)
+
+ vmId = []
+ for i in range(vmNum):
+ cmd = deployVirtualMachine.deployVirtualMachineCmd()
+ cmd.zoneid = zoneId
+ cmd.hypervisor = hypervisor
+ cmd.serviceofferingid = "1"
+ cmd.templateid = listtmresponse[0].id
+ res = api.deployVirtualMachine(cmd)
+
+ vmId.append(res.id)
+
+
+ registerTempl = registerTemplate.registerTemplateCmd()
+ registerTempl.displaytext = "test template4"
+ registerTempl.format = "QCOW2"
+ registerTempl.hypervisor = "Simulator"
+ registerTempl.name = "test template4"
+ registerTempl.ostypeid = "100"
+ registerTempl.url = "http://www.google.com/template.qcow2"
+ registerTempl.zoneid = 1
+ registerTemlResponse = api.registerTemplate(registerTempl)
+ newtemplateId = registerTemlResponse[0].id
+
+ listtempl = listTemplates.listTemplatesCmd()
+ listtempl.id = newtemplateId
+ listtempl.templatefilter = "self"
+ listemplResponse = api.listTemplates(listtempl)
+ while True:
+ if listemplResponse is not None:
+
+ if listemplResponse[0].isready == "true":
+ break
+ else:
+ print listemplResponse[0].status
+
+ time.sleep(30)
+ listemplResponse = api.listTemplates(listtempl)
+
+
+
+ for i in range(10):
+ cmd = deployVirtualMachine.deployVirtualMachineCmd()
+ cmd.zoneid = 1
+ cmd.hypervisor = hypervisor
+ cmd.serviceofferingid = "1"
+ #cmd.templateid = listemplResponse[0].id
+ cmd.templateid = 200
+ res = api.deployVirtualMachine(cmd)
+
+
+ createvolume = createVolume.createVolumeCmd()
+ createvolume.zoneid = 1
+ createvolume.diskofferingid = 9
+ createvolume.name = "test"
+
+ createvolumeresponse = api.createVolume(createvolume)
+ volumeId = createvolumeresponse.id
+
+ attachvolume = attachVolume.attachVolumeCmd()
+ attachvolume.id = volumeId
+ attachvolume.virtualmachineid = 9
+ attachvolumeresponse = api.attachVolume(attachvolume)
+
+ deattachevolume = detachVolume.detachVolumeCmd()
+ deattachevolume.id = volumeId
+ deattachvolumeresponse = api.detachVolume(deattachevolume)
+
+ createsnapshot = createSnapshot.createSnapshotCmd()
+ createsnapshot.volumeid = volumeId
+ createsnapshotresponse = api.createSnapshot(createsnapshot)
+ snapshotId = createsnapshotresponse.id
+
+ createtmpl = createTemplate.createTemplateCmd()
+ createtmpl.snapshotid = snapshotId
+ createtmpl.name = randomName[:10]
+ createtmpl.displaytext = randomName[:10]
+ createtmpl.ostypeid = 100
+ createtmpl.ispublic = "false"
+ createtmpl.passwordenabled = "false"
+ createtmpl.isfeatured = "false"
+ createtmplresponse = api.createTemplate(createtmpl)
+ templateId = createtmplresponse.id
+
+ createvolume = createVolume.createVolumeCmd()
+ createvolume.snapshotid = snapshotId
+ createvolume.name = "test"
+ createvolumeresponse = api.createVolume(createvolume)
+ volumeId = createvolumeresponse.id
+
+ cmd = deployVirtualMachine.deployVirtualMachineCmd()
+ cmd.zoneid = 1
+ cmd.hypervisor = hypervisor
+ cmd.serviceofferingid = "1"
+ cmd.templateid = templateId
+ cmd.name = "fdf"
+ res = api.deployVirtualMachine(cmd)
+
+ attachvolume = attachVolume.attachVolumeCmd()
+ attachvolume.id = volumeId
+ attachvolume.virtualmachineid = 1
+ attachvolumeresponse = api.attachVolume(attachvolume)
+
+ deattachevolume = detachVolume.detachVolumeCmd()
+ deattachevolume.id = volumeId
+ deattachvolumeresponse = api.detachVolume(deattachevolume)
+
+ deletetmpl = deleteTemplate.deleteTemplateCmd()
+ deletetmpl.id = templateId
+ deletetmpl.zoneid = 1
+ api.deleteTemplate(deletetmpl)
+
+ deletevolume = deleteVolume.deleteVolumeCmd()
+ deletevolume.id = volumeId
+ api.deleteVolume(deletevolume)
+
+ deletesnapshot = deleteSnapshot.deleteSnapshotCmd()
+ deletesnapshot.id = snapshotId
+
+ '''
\ No newline at end of file
diff --git a/tools/testClient/unitTest/test_local_storage.py b/tools/testClient/unitTest/test_local_storage.py
new file mode 100644
index 00000000000..70dd864bb0d
--- /dev/null
+++ b/tools/testClient/unitTest/test_local_storage.py
@@ -0,0 +1,243 @@
+import uuid
+import time
+import cloudstackAPIException
+import cloudstackTestClient
+from cloudstackAPI import *
+if __name__ == "__main__":
+ hypervisor = "Simulator"
+ hostNum = 100
+ clusterNum = hostNum/10
+ templateId = 10
+ vmNum = 10
+
+ randomName = str(uuid.uuid4())
+
+ testclient = cloudstackTestClient.cloudstackTestClient("localhost")
+ api = testclient.getApiClient()
+
+
+ updatecfg = updateConfiguration.updateConfigurationCmd()
+ updatecfg.name = "expunge.delay"
+ updatecfg.value = "100"
+ ret = api.updateConfiguration(updatecfg)
+
+ updatecfg.name = "expunge.interval"
+ updatecfg.value = "100"
+ ret = api.updateConfiguration(updatecfg)
+
+ updatecfg.name = "ping.interval"
+ updatecfg.value = "180"
+ ret = api.updateConfiguration(updatecfg)
+
+ updatecfg.name = "system.vm.use.local.storage"
+ updatecfg.value = "true"
+ ret = api.updateConfiguration(updatecfg)
+
+ updatecfg.name = "use.local.storage"
+ updatecfg.value = "true"
+ ret = api.updateConfiguration(updatecfg)
+
+
+
+ czcmd = createZone.createZoneCmd()
+ czcmd.dns1 = "8.8.8.8"
+ czcmd.internaldns1 = "192.168.110.254"
+ czcmd.name = "test" + randomName
+ czcmd.networktype = "Basic"
+
+ czresponse = api.createZone(czcmd)
+ zoneId = czresponse.id
+
+ cpodcmd = createPod.createPodCmd()
+ cpodcmd.zoneid = zoneId
+ cpodcmd.gateway = "192.168.137.1"
+ cpodcmd.name = "testpod"+ randomName
+ cpodcmd.netmask = "255.255.255.0"
+ cpodcmd.startip = "192.168.137.200"
+ cpodcmd.endip = "192.168.137.230"
+ cpodresponse = api.createPod(cpodcmd)
+ podId = cpodresponse.id
+
+ cvlancmd = createVlanIpRange.createVlanIpRangeCmd()
+ cvlancmd.zoneid = zoneId
+ cvlancmd.podid = podId
+ cvlancmd.gateway = "192.168.137.1"
+ cvlancmd.netmask = "255.255.255.0"
+ cvlancmd.startip = "192.168.137.100"
+ cvlancmd.endip = "192.168.137.190"
+ cvlancmd.forvirtualnetwork = "false"
+ cvlancmd.vlan = "untagged"
+
+ api.createVlanIpRange(cvlancmd)
+
+ aclustercmd = addCluster.addClusterCmd()
+ aclustercmd.clustername = "testcluster"+ randomName
+ aclustercmd.hypervisor = hypervisor
+ aclustercmd.podid = podId
+ aclustercmd.zoneid = zoneId
+ aclustercmd.clustertype = "CloudManaged"
+ clusterresponse = api.addCluster(aclustercmd)
+ clusterId = clusterresponse[0].id
+
+ for i in range(hostNum):
+ addhostcmd = addHost.addHostCmd()
+ addhostcmd.zoneid = zoneId
+ addhostcmd.podid = podId
+ addhostcmd.clusterid = clusterId
+ addhostcmd.hypervisor = hypervisor
+ addhostcmd.username = "root"
+ addhostcmd.password = "password"
+ if hypervisor == "Simulator":
+ addhostcmd.url = "http://sim"
+ else:
+ addhostcmd.url = "http://192.168.137.4"
+ addhostresponse = api.addHost(addhostcmd)
+ print addhostresponse[0].id, addhostresponse[0].ipaddress
+
+
+ createspcmd = createStoragePool.createStoragePoolCmd()
+ createspcmd.zoneid = zoneId
+ createspcmd.podid = podId
+ createspcmd.clusterid = clusterId
+ createspcmd.url = "nfs://nfs2.lab.vmops.com/export/home/edison/primary"
+ createspcmd.name = "storage pool" + randomName
+ createspresponse = api.createStoragePool(createspcmd)
+
+
+
+ addsscmd = addSecondaryStorage.addSecondaryStorageCmd()
+ addsscmd.url = "nfs://nfs2.lab.vmops.com/export/home/edison/xen/secondary"
+ addsscmd.zoneid = zoneId
+ api.addSecondaryStorage(addsscmd)
+
+ listtmcmd = listTemplates.listTemplatesCmd()
+ listtmcmd.id = templateId
+ listtmcmd.zoneid = zoneId
+ listtmcmd.templatefilter = "featured"
+ listtmresponse = api.listTemplates(listtmcmd)
+ while True:
+ if listtmresponse is not None and listtmresponse[0].isready == "true":
+ break
+ time.sleep(30)
+ listtmresponse = api.listTemplates(listtmcmd)
+
+ vmId = []
+ for i in range(vmNum):
+ cmd = deployVirtualMachine.deployVirtualMachineCmd()
+ cmd.zoneid = zoneId
+ cmd.hypervisor = hypervisor
+ cmd.serviceofferingid = "1"
+ cmd.templateid = listtmresponse[0].id
+ res = api.deployVirtualMachine(cmd)
+
+ vmId.append(res.id)
+
+ registerTempl = registerTemplate.registerTemplateCmd()
+ registerTempl.displaytext = "test template4"
+ registerTempl.format = "QCOW2"
+ registerTempl.hypervisor = "Simulator"
+ registerTempl.name = "test template4"
+ registerTempl.ostypeid = "100"
+ registerTempl.url = "http://www.google.com/template.qcow2"
+ registerTempl.zoneid = 1
+ registerTemlResponse = api.registerTemplate(registerTempl)
+ newtemplateId = registerTemlResponse[0].id
+
+ listtempl = listTemplates.listTemplatesCmd()
+ listtempl.id = newtemplateId
+ listtempl.templatefilter = "self"
+ listemplResponse = api.listTemplates(listtempl)
+ while True:
+ if listemplResponse is not None:
+
+ if listemplResponse[0].isready == "true":
+ break
+ else:
+ print listemplResponse[0].status
+
+ time.sleep(30)
+ listemplResponse = api.listTemplates(listtempl)
+
+
+
+ for i in range(10):
+ cmd = deployVirtualMachine.deployVirtualMachineCmd()
+ cmd.zoneid = 1
+ cmd.hypervisor = hypervisor
+ cmd.serviceofferingid = "1"
+ #cmd.templateid = listemplResponse[0].id
+ cmd.templateid = 200
+ res = api.deployVirtualMachine(cmd)
+
+
+ createvolume = createVolume.createVolumeCmd()
+ createvolume.zoneid = 1
+ createvolume.diskofferingid = 9
+ createvolume.name = "test"
+
+ createvolumeresponse = api.createVolume(createvolume)
+ volumeId = createvolumeresponse.id
+
+ attachvolume = attachVolume.attachVolumeCmd()
+ attachvolume.id = volumeId
+ attachvolume.virtualmachineid = 9
+ attachvolumeresponse = api.attachVolume(attachvolume)
+
+ deattachevolume = detachVolume.detachVolumeCmd()
+ deattachevolume.id = volumeId
+ deattachvolumeresponse = api.detachVolume(deattachevolume)
+
+ createsnapshot = createSnapshot.createSnapshotCmd()
+ createsnapshot.volumeid = volumeId
+ createsnapshotresponse = api.createSnapshot(createsnapshot)
+ snapshotId = createsnapshotresponse.id
+
+ createtmpl = createTemplate.createTemplateCmd()
+ createtmpl.snapshotid = snapshotId
+ createtmpl.name = randomName[:10]
+ createtmpl.displaytext = randomName[:10]
+ createtmpl.ostypeid = 100
+ createtmpl.ispublic = "false"
+ createtmpl.passwordenabled = "false"
+ createtmpl.isfeatured = "false"
+ createtmplresponse = api.createTemplate(createtmpl)
+ templateId = createtmplresponse.id
+
+ createvolume = createVolume.createVolumeCmd()
+ createvolume.snapshotid = snapshotId
+ createvolume.name = "test"
+ createvolumeresponse = api.createVolume(createvolume)
+ volumeId = createvolumeresponse.id
+
+ cmd = deployVirtualMachine.deployVirtualMachineCmd()
+ cmd.zoneid = 1
+ cmd.hypervisor = hypervisor
+ cmd.serviceofferingid = "1"
+ cmd.templateid = templateId
+ cmd.name = "fdf"
+ res = api.deployVirtualMachine(cmd)
+
+
+
+ attachvolume = attachVolume.attachVolumeCmd()
+ attachvolume.id = volumeId
+ attachvolume.virtualmachineid = 1
+ attachvolumeresponse = api.attachVolume(attachvolume)
+
+ deattachevolume = detachVolume.detachVolumeCmd()
+ deattachevolume.id = volumeId
+ deattachvolumeresponse = api.detachVolume(deattachevolume)
+
+ deletetmpl = deleteTemplate.deleteTemplateCmd()
+ deletetmpl.id = templateId
+ deletetmpl.zoneid = 1
+ api.deleteTemplate(deletetmpl)
+
+ deletevolume = deleteVolume.deleteVolumeCmd()
+ deletevolume.id = volumeId
+ api.deleteVolume(deletevolume)
+
+ deletesnapshot = deleteSnapshot.deleteSnapshotCmd()
+ deletesnapshot.id = snapshotId
+
+
\ No newline at end of file