mirror of
https://github.com/apache/cloudstack.git
synced 2025-10-26 08:42:29 +01:00
929 lines
31 KiB
Python
929 lines
31 KiB
Python
# Python implementation of the MySQL client-server protocol
|
|
# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
|
|
|
|
try:
|
|
import hashlib
|
|
sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs)
|
|
except ImportError:
|
|
import sha
|
|
sha_new = sha.new
|
|
|
|
import socket
|
|
try:
|
|
import ssl
|
|
SSL_ENABLED = True
|
|
except ImportError:
|
|
SSL_ENABLED = False
|
|
|
|
import struct
|
|
import sys
|
|
import os
|
|
import ConfigParser
|
|
|
|
try:
|
|
import cStringIO as StringIO
|
|
except ImportError:
|
|
import StringIO
|
|
|
|
from charset import MBLENGTH, charset_by_name, charset_by_id
|
|
from cursors import Cursor
|
|
from constants import FIELD_TYPE, FLAG
|
|
from constants import SERVER_STATUS
|
|
from constants.CLIENT import *
|
|
from constants.COMMAND import *
|
|
from util import join_bytes, byte2int, int2byte
|
|
from converters import escape_item, encoders, decoders
|
|
from err import raise_mysql_exception, Warning, Error, \
|
|
InterfaceError, DataError, DatabaseError, OperationalError, \
|
|
IntegrityError, InternalError, NotSupportedError, ProgrammingError
|
|
|
|
DEBUG = False
|
|
|
|
NULL_COLUMN = 251
|
|
UNSIGNED_CHAR_COLUMN = 251
|
|
UNSIGNED_SHORT_COLUMN = 252
|
|
UNSIGNED_INT24_COLUMN = 253
|
|
UNSIGNED_INT64_COLUMN = 254
|
|
UNSIGNED_CHAR_LENGTH = 1
|
|
UNSIGNED_SHORT_LENGTH = 2
|
|
UNSIGNED_INT24_LENGTH = 3
|
|
UNSIGNED_INT64_LENGTH = 8
|
|
|
|
DEFAULT_CHARSET = 'latin1'
|
|
|
|
|
|
def dump_packet(data):
|
|
|
|
def is_ascii(data):
|
|
if byte2int(data) >= 65 and byte2int(data) <= 122: #data.isalnum():
|
|
return data
|
|
return '.'
|
|
print "packet length %d" % len(data)
|
|
print "method call[1]: %s" % sys._getframe(1).f_code.co_name
|
|
print "method call[2]: %s" % sys._getframe(2).f_code.co_name
|
|
print "method call[3]: %s" % sys._getframe(3).f_code.co_name
|
|
print "method call[4]: %s" % sys._getframe(4).f_code.co_name
|
|
print "method call[5]: %s" % sys._getframe(5).f_code.co_name
|
|
print "-" * 88
|
|
dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0]
|
|
for d in dump_data:
|
|
print ' '.join(map(lambda x:"%02X" % byte2int(x), d)) + \
|
|
' ' * (16 - len(d)) + ' ' * 2 + \
|
|
' '.join(map(lambda x:"%s" % is_ascii(x), d))
|
|
print "-" * 88
|
|
print ""
|
|
|
|
def _scramble(password, message):
|
|
if password == None or len(password) == 0:
|
|
return int2byte(0)
|
|
if DEBUG: print 'password=' + password
|
|
stage1 = sha_new(password).digest()
|
|
stage2 = sha_new(stage1).digest()
|
|
s = sha_new()
|
|
s.update(message)
|
|
s.update(stage2)
|
|
result = s.digest()
|
|
return _my_crypt(result, stage1)
|
|
|
|
def _my_crypt(message1, message2):
|
|
length = len(message1)
|
|
result = struct.pack('B', length)
|
|
for i in xrange(length):
|
|
x = (struct.unpack('B', message1[i:i+1])[0] ^ \
|
|
struct.unpack('B', message2[i:i+1])[0])
|
|
result += struct.pack('B', x)
|
|
return result
|
|
|
|
# old_passwords support ported from libmysql/password.c
|
|
SCRAMBLE_LENGTH_323 = 8
|
|
|
|
class RandStruct_323(object):
|
|
def __init__(self, seed1, seed2):
|
|
self.max_value = 0x3FFFFFFFL
|
|
self.seed1 = seed1 % self.max_value
|
|
self.seed2 = seed2 % self.max_value
|
|
|
|
def my_rnd(self):
|
|
self.seed1 = (self.seed1 * 3L + self.seed2) % self.max_value
|
|
self.seed2 = (self.seed1 + self.seed2 + 33L) % self.max_value
|
|
return float(self.seed1) / float(self.max_value)
|
|
|
|
def _scramble_323(password, message):
|
|
hash_pass = _hash_password_323(password)
|
|
hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323])
|
|
hash_pass_n = struct.unpack(">LL", hash_pass)
|
|
hash_message_n = struct.unpack(">LL", hash_message)
|
|
|
|
rand_st = RandStruct_323(hash_pass_n[0] ^ hash_message_n[0],
|
|
hash_pass_n[1] ^ hash_message_n[1])
|
|
outbuf = StringIO.StringIO()
|
|
for _ in xrange(min(SCRAMBLE_LENGTH_323, len(message))):
|
|
outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64))
|
|
extra = int2byte(int(rand_st.my_rnd() * 31))
|
|
out = outbuf.getvalue()
|
|
outbuf = StringIO.StringIO()
|
|
for c in out:
|
|
outbuf.write(int2byte(byte2int(c) ^ byte2int(extra)))
|
|
return outbuf.getvalue()
|
|
|
|
def _hash_password_323(password):
|
|
nr = 1345345333L
|
|
add = 7L
|
|
nr2 = 0x12345671L
|
|
|
|
for c in [byte2int(x) for x in password if x not in (' ', '\t')]:
|
|
nr^= (((nr & 63)+add)*c)+ (nr << 8) & 0xFFFFFFFF
|
|
nr2= (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
|
|
add= (add + c) & 0xFFFFFFFF
|
|
|
|
r1 = nr & ((1L << 31) - 1L) # kill sign bits
|
|
r2 = nr2 & ((1L << 31) - 1L)
|
|
|
|
# pack
|
|
return struct.pack(">LL", r1, r2)
|
|
|
|
def pack_int24(n):
|
|
return struct.pack('BBB', n&0xFF, (n>>8)&0xFF, (n>>16)&0xFF)
|
|
|
|
def unpack_uint16(n):
|
|
return struct.unpack('<H', n[0:2])[0]
|
|
|
|
|
|
# TODO: stop using bit-shifting in these functions...
|
|
# TODO: rename to "uint" to make it clear they're unsigned...
|
|
def unpack_int24(n):
|
|
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
|
|
(struct.unpack('B',n[2])[0] << 16)
|
|
|
|
def unpack_int32(n):
|
|
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
|
|
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B', n[3])[0] << 24)
|
|
|
|
def unpack_int64(n):
|
|
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0]<<8) +\
|
|
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B',n[3])[0]<<24)+\
|
|
(struct.unpack('B',n[4])[0] << 32) + (struct.unpack('B',n[5])[0]<<40)+\
|
|
(struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56)
|
|
|
|
def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
|
|
err = errorclass, errorvalue
|
|
if DEBUG:
|
|
raise
|
|
|
|
if cursor:
|
|
cursor.messages.append(err)
|
|
else:
|
|
connection.messages.append(err)
|
|
del cursor
|
|
del connection
|
|
|
|
if not issubclass(errorclass, Error):
|
|
raise Error(errorclass, errorvalue)
|
|
else:
|
|
raise errorclass, errorvalue
|
|
|
|
|
|
class MysqlPacket(object):
|
|
"""Representation of a MySQL response packet. Reads in the packet
|
|
from the network socket, removes packet header and provides an interface
|
|
for reading/parsing the packet results."""
|
|
|
|
def __init__(self, connection):
|
|
self.connection = connection
|
|
self.__position = 0
|
|
self.__recv_packet()
|
|
|
|
def __recv_packet(self):
|
|
"""Parse the packet header and read entire packet payload into buffer."""
|
|
packet_header = self.connection.rfile.read(4)
|
|
if len(packet_header) < 4:
|
|
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
|
|
|
if DEBUG: dump_packet(packet_header)
|
|
packet_length_bin = packet_header[:3]
|
|
self.__packet_number = byte2int(packet_header[3])
|
|
# TODO: check packet_num is correct (+1 from last packet)
|
|
|
|
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
|
|
bytes_to_read = struct.unpack('<I', bin_length)[0]
|
|
recv_data = self.connection.rfile.read(bytes_to_read)
|
|
if len(recv_data) < bytes_to_read:
|
|
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
|
if DEBUG: dump_packet(recv_data)
|
|
self.__data = recv_data
|
|
|
|
def packet_number(self): return self.__packet_number
|
|
|
|
def get_all_data(self): return self.__data
|
|
|
|
def read(self, size):
|
|
"""Read the first 'size' bytes in packet and advance cursor past them."""
|
|
result = self.peek(size)
|
|
self.advance(size)
|
|
return result
|
|
|
|
def read_all(self):
|
|
"""Read all remaining data in the packet.
|
|
|
|
(Subsequent read() or peek() will return errors.)
|
|
"""
|
|
result = self.__data[self.__position:]
|
|
self.__position = None # ensure no subsequent read() or peek()
|
|
return result
|
|
|
|
def advance(self, length):
|
|
"""Advance the cursor in data buffer 'length' bytes."""
|
|
new_position = self.__position + length
|
|
if new_position < 0 or new_position > len(self.__data):
|
|
raise Exception('Invalid advance amount (%s) for cursor. '
|
|
'Position=%s' % (length, new_position))
|
|
self.__position = new_position
|
|
|
|
def rewind(self, position=0):
|
|
"""Set the position of the data buffer cursor to 'position'."""
|
|
if position < 0 or position > len(self.__data):
|
|
raise Exception("Invalid position to rewind cursor to: %s." % position)
|
|
self.__position = position
|
|
|
|
def peek(self, size):
|
|
"""Look at the first 'size' bytes in packet without moving cursor."""
|
|
result = self.__data[self.__position:(self.__position+size)]
|
|
if len(result) != size:
|
|
error = ('Result length not requested length:\n'
|
|
'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
|
|
% (size, len(result), self.__position, len(self.__data)))
|
|
if DEBUG:
|
|
print error
|
|
self.dump()
|
|
raise AssertionError(error)
|
|
return result
|
|
|
|
def get_bytes(self, position, length=1):
|
|
"""Get 'length' bytes starting at 'position'.
|
|
|
|
Position is start of payload (first four packet header bytes are not
|
|
included) starting at index '0'.
|
|
|
|
No error checking is done. If requesting outside end of buffer
|
|
an empty string (or string shorter than 'length') may be returned!
|
|
"""
|
|
return self.__data[position:(position+length)]
|
|
|
|
def read_length_coded_binary(self):
|
|
"""Read a 'Length Coded Binary' number from the data buffer.
|
|
|
|
Length coded numbers can be anywhere from 1 to 9 bytes depending
|
|
on the value of the first byte.
|
|
"""
|
|
c = byte2int(self.read(1))
|
|
if c == NULL_COLUMN:
|
|
return None
|
|
if c < UNSIGNED_CHAR_COLUMN:
|
|
return c
|
|
elif c == UNSIGNED_SHORT_COLUMN:
|
|
return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH))
|
|
elif c == UNSIGNED_INT24_COLUMN:
|
|
return unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
|
|
elif c == UNSIGNED_INT64_COLUMN:
|
|
# TODO: what was 'longlong'? confirm it wasn't used?
|
|
return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
|
|
|
|
def read_length_coded_string(self):
|
|
"""Read a 'Length Coded String' from the data buffer.
|
|
|
|
A 'Length Coded String' consists first of a length coded
|
|
(unsigned, positive) integer represented in 1-9 bytes followed by
|
|
that many bytes of binary data. (For example "cat" would be "3cat".)
|
|
"""
|
|
length = self.read_length_coded_binary()
|
|
if length is None:
|
|
return None
|
|
return self.read(length)
|
|
|
|
def is_ok_packet(self):
|
|
return byte2int(self.get_bytes(0)) == 0
|
|
|
|
def is_eof_packet(self):
|
|
return byte2int(self.get_bytes(0)) == 254 # 'fe'
|
|
|
|
def is_resultset_packet(self):
|
|
field_count = byte2int(self.get_bytes(0))
|
|
return field_count >= 1 and field_count <= 250
|
|
|
|
def is_error_packet(self):
|
|
return byte2int(self.get_bytes(0)) == 255
|
|
|
|
def check_error(self):
|
|
if self.is_error_packet():
|
|
self.rewind()
|
|
self.advance(1) # field_count == error (we already know that)
|
|
errno = unpack_uint16(self.read(2))
|
|
if DEBUG: print "errno = %d" % errno
|
|
raise_mysql_exception(self.__data)
|
|
|
|
def dump(self):
|
|
dump_packet(self.__data)
|
|
|
|
|
|
class FieldDescriptorPacket(MysqlPacket):
|
|
"""A MysqlPacket that represents a specific column's metadata in the result.
|
|
|
|
Parsing is automatically done and the results are exported via public
|
|
attributes on the class such as: db, table_name, name, length, type_code.
|
|
"""
|
|
|
|
def __init__(self, *args):
|
|
MysqlPacket.__init__(self, *args)
|
|
self.__parse_field_descriptor()
|
|
|
|
def __parse_field_descriptor(self):
|
|
"""Parse the 'Field Descriptor' (Metadata) packet.
|
|
|
|
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
|
|
"""
|
|
self.catalog = self.read_length_coded_string()
|
|
self.db = self.read_length_coded_string()
|
|
self.table_name = self.read_length_coded_string()
|
|
self.org_table = self.read_length_coded_string()
|
|
self.name = self.read_length_coded_string().decode(self.connection.charset)
|
|
self.org_name = self.read_length_coded_string()
|
|
self.advance(1) # non-null filler
|
|
self.charsetnr = struct.unpack('<H', self.read(2))[0]
|
|
self.length = struct.unpack('<I', self.read(4))[0]
|
|
self.type_code = byte2int(self.read(1))
|
|
self.flags = struct.unpack('<H', self.read(2))[0]
|
|
self.scale = byte2int(self.read(1)) # "decimals"
|
|
self.advance(2) # filler (always 0x00)
|
|
|
|
# 'default' is a length coded binary and is still in the buffer?
|
|
# not used for normal result sets...
|
|
|
|
def description(self):
|
|
"""Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
|
|
desc = []
|
|
desc.append(self.name)
|
|
desc.append(self.type_code)
|
|
desc.append(None) # TODO: display_length; should this be self.length?
|
|
desc.append(self.get_column_length()) # 'internal_size'
|
|
desc.append(self.get_column_length()) # 'precision' # TODO: why!?!?
|
|
desc.append(self.scale)
|
|
|
|
# 'null_ok' -- can this be True/False rather than 1/0?
|
|
# if so just do: desc.append(bool(self.flags % 2 == 0))
|
|
if self.flags % 2 == 0:
|
|
desc.append(1)
|
|
else:
|
|
desc.append(0)
|
|
return tuple(desc)
|
|
|
|
def get_column_length(self):
|
|
if self.type_code == FIELD_TYPE.VAR_STRING:
|
|
mblen = MBLENGTH.get(self.charsetnr, 1)
|
|
return self.length // mblen
|
|
return self.length
|
|
|
|
def __str__(self):
|
|
return ('%s %s.%s.%s, type=%s'
|
|
% (self.__class__, self.db, self.table_name, self.name,
|
|
self.type_code))
|
|
|
|
|
|
class Connection(object):
|
|
"""
|
|
Representation of a socket with a mysql server.
|
|
|
|
The proper way to get an instance of this class is to call
|
|
connect()."""
|
|
errorhandler = defaulterrorhandler
|
|
|
|
def __init__(self, host="localhost", user=None, passwd="",
|
|
db=None, port=3306, unix_socket=None,
|
|
charset='', sql_mode=None,
|
|
read_default_file=None, conv=decoders, use_unicode=None,
|
|
client_flag=0, cursorclass=Cursor, init_command=None,
|
|
connect_timeout=None, ssl=None, read_default_group=None,
|
|
compress=None, named_pipe=None):
|
|
"""
|
|
Establish a connection to the MySQL database. Accepts several
|
|
arguments:
|
|
|
|
host: Host where the database server is located
|
|
user: Username to log in as
|
|
passwd: Password to use.
|
|
db: Database to use, None to not use a particular one.
|
|
port: MySQL port to use, default is usually OK.
|
|
unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
|
|
charset: Charset you want to use.
|
|
sql_mode: Default SQL_MODE to use.
|
|
read_default_file: Specifies my.cnf file to read these parameters from under the [client] section.
|
|
conv: Decoders dictionary to use instead of the default one. This is used to provide custom marshalling of types. See converters.
|
|
use_unicode: Whether or not to default to unicode strings. This option defaults to true for Py3k.
|
|
client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT.
|
|
cursorclass: Custom cursor class to use.
|
|
init_command: Initial SQL statement to run when connection is established.
|
|
connect_timeout: Timeout before throwing an exception when connecting.
|
|
ssl: A dict of arguments similar to mysql_ssl_set()'s parameters. For now the capath and cipher arguments are not supported.
|
|
read_default_group: Group to read from in the configuration file.
|
|
compress; Not supported
|
|
named_pipe: Not supported
|
|
"""
|
|
|
|
if use_unicode is None and sys.version_info[0] > 2:
|
|
use_unicode = True
|
|
|
|
if compress or named_pipe:
|
|
raise NotImplementedError, "compress and named_pipe arguments are not supported"
|
|
|
|
if ssl and (ssl.has_key('capath') or ssl.has_key('cipher')):
|
|
raise NotImplementedError, 'ssl options capath and cipher are not supported'
|
|
|
|
self.ssl = False
|
|
if ssl:
|
|
if not SSL_ENABLED:
|
|
raise NotImplementedError, "ssl module not found"
|
|
self.ssl = True
|
|
client_flag |= SSL
|
|
for k in ('key', 'cert', 'ca'):
|
|
v = None
|
|
if ssl.has_key(k):
|
|
v = ssl[k]
|
|
setattr(self, k, v)
|
|
|
|
if read_default_group and not read_default_file:
|
|
if sys.platform.startswith("win"):
|
|
read_default_file = "c:\\my.ini"
|
|
else:
|
|
read_default_file = "/etc/my.cnf"
|
|
|
|
if read_default_file:
|
|
if not read_default_group:
|
|
read_default_group = "client"
|
|
|
|
cfg = ConfigParser.RawConfigParser()
|
|
cfg.read(os.path.expanduser(read_default_file))
|
|
|
|
def _config(key, default):
|
|
try:
|
|
return cfg.get(read_default_group,key)
|
|
except:
|
|
return default
|
|
|
|
user = _config("user",user)
|
|
passwd = _config("password",passwd)
|
|
host = _config("host", host)
|
|
db = _config("db",db)
|
|
unix_socket = _config("socket",unix_socket)
|
|
port = _config("port", port)
|
|
charset = _config("default-character-set", charset)
|
|
|
|
self.host = host
|
|
self.port = port
|
|
self.user = user
|
|
self.password = passwd
|
|
self.db = db
|
|
self.unix_socket = unix_socket
|
|
if charset:
|
|
self.charset = charset
|
|
self.use_unicode = True
|
|
else:
|
|
self.charset = DEFAULT_CHARSET
|
|
self.use_unicode = False
|
|
|
|
if use_unicode is not None:
|
|
self.use_unicode = use_unicode
|
|
|
|
client_flag |= CAPABILITIES
|
|
client_flag |= MULTI_STATEMENTS
|
|
if self.db:
|
|
client_flag |= CONNECT_WITH_DB
|
|
self.client_flag = client_flag
|
|
|
|
self.cursorclass = cursorclass
|
|
self.connect_timeout = connect_timeout
|
|
|
|
self._connect()
|
|
|
|
self.messages = []
|
|
self.set_charset(charset)
|
|
self.encoders = encoders
|
|
self.decoders = conv
|
|
|
|
self._result = None
|
|
self._affected_rows = 0
|
|
self.host_info = "Not connected"
|
|
|
|
self.autocommit(False)
|
|
|
|
if sql_mode is not None:
|
|
c = self.cursor()
|
|
c.execute("SET sql_mode=%s", (sql_mode,))
|
|
|
|
self.commit()
|
|
|
|
if init_command is not None:
|
|
c = self.cursor()
|
|
c.execute(init_command)
|
|
|
|
self.commit()
|
|
|
|
|
|
def close(self):
|
|
''' Send the quit message and close the socket '''
|
|
if self.socket is None:
|
|
raise Error("Already closed")
|
|
send_data = struct.pack('<i',1) + int2byte(COM_QUIT)
|
|
self.wfile.write(send_data)
|
|
self.wfile.close()
|
|
self.rfile.close()
|
|
self.socket.close()
|
|
self.socket = None
|
|
self.rfile = None
|
|
self.wfile = None
|
|
|
|
def autocommit(self, value):
|
|
''' Set whether or not to commit after every execute() '''
|
|
try:
|
|
self._execute_command(COM_QUERY, "SET AUTOCOMMIT = %s" % \
|
|
self.escape(value))
|
|
self.read_packet()
|
|
except:
|
|
exc,value,tb = sys.exc_info()
|
|
self.errorhandler(None, exc, value)
|
|
|
|
def commit(self):
|
|
''' Commit changes to stable storage '''
|
|
try:
|
|
self._execute_command(COM_QUERY, "COMMIT")
|
|
self.read_packet()
|
|
except:
|
|
exc,value,tb = sys.exc_info()
|
|
self.errorhandler(None, exc, value)
|
|
|
|
def rollback(self):
|
|
''' Roll back the current transaction '''
|
|
try:
|
|
self._execute_command(COM_QUERY, "ROLLBACK")
|
|
self.read_packet()
|
|
except:
|
|
exc,value,tb = sys.exc_info()
|
|
self.errorhandler(None, exc, value)
|
|
|
|
def escape(self, obj):
|
|
''' Escape whatever value you pass to it '''
|
|
return escape_item(obj, self.charset)
|
|
|
|
def literal(self, obj):
|
|
''' Alias for escape() '''
|
|
return escape_item(obj, self.charset)
|
|
|
|
def cursor(self, cursor=None):
|
|
''' Create a new cursor to execute queries with '''
|
|
if cursor:
|
|
return cursor(self)
|
|
return self.cursorclass(self)
|
|
|
|
def __enter__(self):
|
|
''' Context manager that returns a Cursor '''
|
|
return self.cursor()
|
|
|
|
def __exit__(self, exc, value, traceback):
|
|
''' On successful exit, commit. On exception, rollback. '''
|
|
if exc:
|
|
self.rollback()
|
|
else:
|
|
self.commit()
|
|
|
|
# The following methods are INTERNAL USE ONLY (called from Cursor)
|
|
def query(self, sql):
|
|
if DEBUG:
|
|
print "sending query: %s" % sql
|
|
self._execute_command(COM_QUERY, sql)
|
|
self._affected_rows = self._read_query_result()
|
|
return self._affected_rows
|
|
|
|
def next_result(self):
|
|
self._affected_rows = self._read_query_result()
|
|
return self._affected_rows
|
|
|
|
def affected_rows(self):
|
|
return self._affected_rows
|
|
|
|
def kill(self, thread_id):
|
|
arg = struct.pack('<I', thread_id)
|
|
try:
|
|
self._execute_command(COM_PROCESS_KILL, arg)
|
|
except:
|
|
exc,value,tb = sys.exc_info()
|
|
self.errorhandler(None, exc, value)
|
|
return
|
|
pkt = self.read_packet()
|
|
return pkt.is_ok_packet()
|
|
|
|
def ping(self, reconnect=True):
|
|
''' Check if the server is alive '''
|
|
try:
|
|
self._execute_command(COM_PING, "")
|
|
except:
|
|
if reconnect:
|
|
self._connect()
|
|
return self.ping(False)
|
|
else:
|
|
exc,value,tb = sys.exc_info()
|
|
self.errorhandler(None, exc, value)
|
|
return
|
|
|
|
pkt = self.read_packet()
|
|
return pkt.is_ok_packet()
|
|
|
|
def set_charset(self, charset):
|
|
try:
|
|
if charset:
|
|
self._execute_command(COM_QUERY, "SET NAMES %s" %
|
|
self.escape(charset))
|
|
self.read_packet()
|
|
self.charset = charset
|
|
except:
|
|
exc,value,tb = sys.exc_info()
|
|
self.errorhandler(None, exc, value)
|
|
|
|
def _connect(self):
|
|
try:
|
|
if self.unix_socket and (self.host == 'localhost' or self.host == '127.0.0.1'):
|
|
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
|
t = sock.gettimeout()
|
|
sock.settimeout(self.connect_timeout)
|
|
sock.connect(self.unix_socket)
|
|
sock.settimeout(t)
|
|
self.host_info = "Localhost via UNIX socket"
|
|
if DEBUG: print 'connected using unix_socket'
|
|
else:
|
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
t = sock.gettimeout()
|
|
sock.settimeout(self.connect_timeout)
|
|
sock.connect((self.host, self.port))
|
|
sock.settimeout(t)
|
|
self.host_info = "socket %s:%d" % (self.host, self.port)
|
|
if DEBUG: print 'connected using socket'
|
|
self.socket = sock
|
|
self.rfile = self.socket.makefile("rb")
|
|
self.wfile = self.socket.makefile("wb")
|
|
self._get_server_information()
|
|
self._request_authentication()
|
|
except socket.error, e:
|
|
raise OperationalError(2003, "Can't connect to MySQL server on %r (%s)" % (self.host, e.args[0]))
|
|
|
|
def read_packet(self, packet_type=MysqlPacket):
|
|
"""Read an entire "mysql packet" in its entirety from the network
|
|
and return a MysqlPacket type that represents the results."""
|
|
|
|
packet = packet_type(self)
|
|
packet.check_error()
|
|
return packet
|
|
|
|
def _read_query_result(self):
|
|
result = MySQLResult(self)
|
|
result.read()
|
|
self._result = result
|
|
return result.affected_rows
|
|
|
|
def insert_id(self):
|
|
if self._result:
|
|
return self._result.insert_id
|
|
else:
|
|
return 0
|
|
|
|
def _send_command(self, command, sql):
|
|
#send_data = struct.pack('<i', len(sql) + 1) + command + sql
|
|
# could probably be more efficient, at least it's correct
|
|
if not self.socket:
|
|
self.errorhandler(None, InterfaceError, "(0, '')")
|
|
|
|
if isinstance(sql, unicode):
|
|
sql = sql.encode(self.charset)
|
|
|
|
prelude = struct.pack('<i', len(sql)+1) + int2byte(command)
|
|
self.wfile.write(prelude + sql)
|
|
self.wfile.flush()
|
|
if DEBUG: dump_packet(prelude + sql)
|
|
|
|
def _execute_command(self, command, sql):
|
|
self._send_command(command, sql)
|
|
|
|
def _request_authentication(self):
|
|
self._send_authentication()
|
|
|
|
def _send_authentication(self):
|
|
self.client_flag |= CAPABILITIES
|
|
if self.server_version.startswith('5'):
|
|
self.client_flag |= MULTI_RESULTS
|
|
|
|
if self.user is None:
|
|
raise ValueError, "Did not specify a username"
|
|
|
|
charset_id = charset_by_name(self.charset).id
|
|
self.user = self.user.encode(self.charset)
|
|
|
|
data_init = struct.pack('<i', self.client_flag) + struct.pack("<I", 1) + \
|
|
int2byte(charset_id) + int2byte(0)*23
|
|
|
|
next_packet = 1
|
|
|
|
if self.ssl:
|
|
data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
|
|
next_packet += 1
|
|
|
|
if DEBUG: dump_packet(data)
|
|
|
|
self.wfile.write(data)
|
|
self.wfile.flush()
|
|
self.socket = ssl.wrap_self.socketet(self.socket, keyfile=self.key,
|
|
certfile=self.cert,
|
|
ssl_version=ssl.PROTOCOL_TLSv1,
|
|
cert_reqs=ssl.CERT_REQUIRED,
|
|
ca_certs=self.ca)
|
|
self.rfile = self.socket.makefile("rb")
|
|
self.wfile = self.socket.makefile("wb")
|
|
|
|
data = data_init + self.user+int2byte(0) + _scramble(self.password.encode(self.charset), self.salt)
|
|
|
|
if self.db:
|
|
self.db = self.db.encode(self.charset)
|
|
data += self.db + int2byte(0)
|
|
|
|
data = pack_int24(len(data)) + int2byte(next_packet) + data
|
|
next_packet += 2
|
|
|
|
if DEBUG: dump_packet(data)
|
|
|
|
self.wfile.write(data)
|
|
self.wfile.flush()
|
|
|
|
auth_packet = MysqlPacket(self)
|
|
auth_packet.check_error()
|
|
if DEBUG: auth_packet.dump()
|
|
|
|
# if old_passwords is enabled the packet will be 1 byte long and
|
|
# have the octet 254
|
|
|
|
if auth_packet.is_eof_packet():
|
|
# send legacy handshake
|
|
#raise NotImplementedError, "old_passwords are not supported. Check to see if mysqld was started with --old-passwords, if old-passwords=1 in a my.cnf file, or if there are some short hashes in your mysql.user table."
|
|
# TODO: is this the correct charset?
|
|
data = _scramble_323(self.password.encode(self.charset), self.salt.encode(self.charset)) + int2byte(0)
|
|
data = pack_int24(len(data)) + int2byte(next_packet) + data
|
|
|
|
self.wfile.write(data)
|
|
self.wfile.flush()
|
|
auth_packet = MysqlPacket(self)
|
|
auth_packet.check_error()
|
|
if DEBUG: auth_packet.dump()
|
|
|
|
|
|
# _mysql support
|
|
def thread_id(self):
|
|
return self.server_thread_id[0]
|
|
|
|
def character_set_name(self):
|
|
return self.charset
|
|
|
|
def get_host_info(self):
|
|
return self.host_info
|
|
|
|
def get_proto_info(self):
|
|
return self.protocol_version
|
|
|
|
def _get_server_information(self):
|
|
i = 0
|
|
packet = MysqlPacket(self)
|
|
data = packet.get_all_data()
|
|
|
|
if DEBUG: dump_packet(data)
|
|
#packet_len = byte2int(data[i:i+1])
|
|
#i += 4
|
|
self.protocol_version = byte2int(data[i:i+1])
|
|
|
|
i += 1
|
|
server_end = data.find(int2byte(0), i)
|
|
# TODO: is this the correct charset? should it be default_charset?
|
|
self.server_version = data[i:server_end].decode(self.charset)
|
|
|
|
i = server_end + 1
|
|
self.server_thread_id = struct.unpack('<h', data[i:i+2])
|
|
|
|
i += 4
|
|
self.salt = data[i:i+8]
|
|
|
|
i += 9
|
|
if len(data) >= i + 1:
|
|
i += 1
|
|
|
|
self.server_capabilities = struct.unpack('<h', data[i:i+2])[0]
|
|
|
|
i += 1
|
|
self.server_language = byte2int(data[i:i+1])
|
|
self.server_charset = charset_by_id(self.server_language).name
|
|
|
|
i += 16
|
|
if len(data) >= i+12-1:
|
|
rest_salt = data[i:i+12]
|
|
self.salt += rest_salt
|
|
|
|
def get_server_info(self):
|
|
return self.server_version
|
|
|
|
Warning = Warning
|
|
Error = Error
|
|
InterfaceError = InterfaceError
|
|
DatabaseError = DatabaseError
|
|
DataError = DataError
|
|
OperationalError = OperationalError
|
|
IntegrityError = IntegrityError
|
|
InternalError = InternalError
|
|
ProgrammingError = ProgrammingError
|
|
NotSupportedError = NotSupportedError
|
|
|
|
# TODO: move OK and EOF packet parsing/logic into a proper subclass
|
|
# of MysqlPacket like has been done with FieldDescriptorPacket.
|
|
class MySQLResult(object):
|
|
|
|
def __init__(self, connection):
|
|
from weakref import proxy
|
|
self.connection = proxy(connection)
|
|
self.affected_rows = None
|
|
self.insert_id = None
|
|
self.server_status = 0
|
|
self.warning_count = 0
|
|
self.message = None
|
|
self.field_count = 0
|
|
self.description = None
|
|
self.rows = None
|
|
self.has_next = None
|
|
|
|
def read(self):
|
|
self.first_packet = self.connection.read_packet()
|
|
|
|
# TODO: use classes for different packet types?
|
|
if self.first_packet.is_ok_packet():
|
|
self._read_ok_packet()
|
|
else:
|
|
self._read_result_packet()
|
|
|
|
def _read_ok_packet(self):
|
|
self.first_packet.advance(1) # field_count (always '0')
|
|
self.affected_rows = self.first_packet.read_length_coded_binary()
|
|
self.insert_id = self.first_packet.read_length_coded_binary()
|
|
self.server_status = struct.unpack('<H', self.first_packet.read(2))[0]
|
|
self.warning_count = struct.unpack('<H', self.first_packet.read(2))[0]
|
|
self.message = self.first_packet.read_all()
|
|
|
|
def _read_result_packet(self):
|
|
self.field_count = byte2int(self.first_packet.read(1))
|
|
self._get_descriptions()
|
|
self._read_rowdata_packet()
|
|
|
|
# TODO: implement this as an iteratable so that it is more
|
|
# memory efficient and lower-latency to client...
|
|
def _read_rowdata_packet(self):
|
|
"""Read a rowdata packet for each data row in the result set."""
|
|
rows = []
|
|
while True:
|
|
packet = self.connection.read_packet()
|
|
if packet.is_eof_packet():
|
|
self.warning_count = packet.read(2)
|
|
server_status = struct.unpack('<h', packet.read(2))[0]
|
|
self.has_next = (server_status
|
|
& SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS)
|
|
break
|
|
|
|
row = []
|
|
for field in self.fields:
|
|
if field.type_code in self.connection.decoders:
|
|
converter = self.connection.decoders[field.type_code]
|
|
|
|
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
|
|
data = packet.read_length_coded_string()
|
|
converted = None
|
|
if data != None:
|
|
converted = converter(self.connection, field, data)
|
|
|
|
row.append(converted)
|
|
|
|
rows.append(tuple(row))
|
|
|
|
self.affected_rows = len(rows)
|
|
self.rows = tuple(rows)
|
|
if DEBUG: self.rows
|
|
|
|
def _get_descriptions(self):
|
|
"""Read a column descriptor packet for each column in the result."""
|
|
self.fields = []
|
|
description = []
|
|
for i in xrange(self.field_count):
|
|
field = self.connection.read_packet(FieldDescriptorPacket)
|
|
self.fields.append(field)
|
|
description.append(field.description())
|
|
|
|
eof_packet = self.connection.read_packet()
|
|
assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
|
|
self.description = tuple(description)
|