From 6733134ad47783b6e2d01b388d8456e039fe1c77 Mon Sep 17 00:00:00 2001 From: "87414772@qq.com" <87414772@qq.com> Date: Tue, 24 Jun 2014 13:11:05 +0800 Subject: [PATCH 01/13] mysql lib --- 3rd/lua-mysqlaux/lua_mysqlaux.c | 372 +++++++++++++++ Makefile | 5 +- examples/config.mysql | 11 + examples/main_mysql.lua | 10 + lualib/mysql.lua | 801 ++++++++++++++++++++++++++++++++ test/testmysql.lua | 32 ++ 6 files changed, 1230 insertions(+), 1 deletion(-) create mode 100755 3rd/lua-mysqlaux/lua_mysqlaux.c create mode 100644 examples/config.mysql create mode 100644 examples/main_mysql.lua create mode 100755 lualib/mysql.lua create mode 100644 test/testmysql.lua diff --git a/3rd/lua-mysqlaux/lua_mysqlaux.c b/3rd/lua-mysqlaux/lua_mysqlaux.c new file mode 100755 index 000000000..6b639d7cb --- /dev/null +++ b/3rd/lua-mysqlaux/lua_mysqlaux.c @@ -0,0 +1,372 @@ +// +// lua_mysqlaux.c +// +// Created by changfeng on 6/17/14. +// Copyright (c) 2014 changfeng. All rights reserved. +// +#include +#include +#include + +#include +#include + +#define SHA1SIZE 20 +#define ROTL(bits,word) (((word) << (bits)) | ((word) >> (32-(bits)))) +typedef unsigned int uint32_t; + +struct sha +{ + uint32_t digest[5]; + uint32_t w[80]; + uint32_t a,b,c,d,e,f; + int err; +}; + + +static uint32_t padded_length_in_bits(uint32_t len) +{ + if(len%64 == 56) + { + len++; + } + while((len%64)!=56) + { + len++; + } + return len*8; +} + + +static int calculate_sha1(struct sha *sha1, unsigned char *text, uint32_t length) +{ + unsigned int i,j; + unsigned char *buffer=NULL, *pbuffer=NULL; + uint32_t bits=0; + uint32_t temp=0,k=0; + uint32_t lb = length*8; + + if (!sha1) + { + return 0; + } + // initialize the default digest values + sha1->digest[0] = 0x67452301; + sha1->digest[1] = 0xEFCDAB89; + sha1->digest[2] = 0x98BADCFE; + sha1->digest[3] = 0x10325476; + sha1->digest[4] = 0xC3D2E1F0; + sha1->a=sha1->b=sha1->c=sha1->d=sha1->e=sha1->f=0; + if (!text || !length) + { + return 0; + } + + bits = padded_length_in_bits(length); + buffer = (unsigned char *) malloc((bits/8)+8); + memset(buffer,0,(bits/8)+8); + if(buffer == NULL) + { + return 1; + } + pbuffer = buffer; + memcpy(buffer, text, length); + + + //add 1 on the last of the message.. + *(buffer+length) = 0x80; + for(i=length+1; i<(bits/8); i++) + { + *(buffer+i) = 0x00; + } + + *(buffer +(bits/8)+4+0) = (lb>>24) & 0xFF; + *(buffer +(bits/8)+4+1) = (lb>>16) & 0xFF; + *(buffer +(bits/8)+4+2) = (lb>>8) & 0xFF; + *(buffer +(bits/8)+4+3) = (lb>>0) & 0xFF; + + + //main loop + for(i=0; i<((bits+64)/512); i++) + { + //first empty the block for each pass.. + for(j=0; j<80; j++) + { + sha1->w[j] = 0x00; + } + + + //fill the first 16 words with the characters read directly from the buffer. + for(j=0; j<16; j++) + { + sha1->w[j] =buffer[j*4+0]; + sha1->w[j] = sha1->w[j]<<8; + sha1->w[j] |= buffer[j*4+1]; + sha1->w[j] = sha1->w[j]<<8; + sha1->w[j] |= buffer[j*4+2]; + sha1->w[j] = sha1->w[j]<<8; + sha1->w[j] |= buffer[j*4+3]; + } + + //fill the rest 64 words using the formula + for(j=16; j<80; j++) + { + sha1->w[j] = (ROTL(1,(sha1->w[j-3] ^ sha1->w[j-8] ^ sha1->w[j-14] ^ sha1->w[j-16]))); + } + + + //initialize hash for this chunck reading that has been stored in the structure digest + sha1->a = sha1->digest[0]; + sha1->b = sha1->digest[1]; + sha1->c = sha1->digest[2]; + sha1->d = sha1->digest[3]; + sha1->e = sha1->digest[4]; + + //for all the 80 32bit blocks calculate f and use k accordingly per specification. + for(j=0; j<80; j++) + { + if((j>=0) && (j<20)) + { + sha1->f = ((sha1->b)&(sha1->c)) | ((~(sha1->b))&(sha1->d)); + k = 0x5A827999; + + } + else if((j>=20) && (j<40)) + { + sha1->f = (sha1->b)^(sha1->c)^(sha1->d); + k = 0x6ED9EBA1; + } + else if((j>=40) && (j<60)) + { + sha1->f = ((sha1->b)&(sha1->c)) | ((sha1->b)&(sha1->d)) | ((sha1->c)&(sha1->d)); + k = 0x8F1BBCDC; + } + else if((j>=60) && (j<80)) + { + sha1->f = (sha1->b)^(sha1->c)^(sha1->d); + k = 0xCA62C1D6; + } + + temp = ROTL(5,(sha1->a)) + (sha1->f) + (sha1->e) + k + sha1->w[j]; + sha1->e = (sha1->d); + sha1->d = (sha1->c); + sha1->c = ROTL(30,(sha1->b)); + sha1->b = (sha1->a); + sha1->a = temp; + + //reset temp to 0 to be in safe side only, not mandatory. + temp =0x00; + + + } + + // append to total hash. + sha1->digest[0] += sha1->a; + sha1->digest[1] += sha1->b; + sha1->digest[2] += sha1->c; + sha1->digest[3] += sha1->d; + sha1->digest[4] += sha1->e; + + + //since we used 512bit size block per each pass, let us update the buffer pointer accordingly. + buffer = buffer+64; + + } + free(pbuffer); + return 0; +} + +static void int2ch4(int intVal,unsigned char *result) +{ + result[0]= (unsigned char)((intVal>>24) & 0x000000ff); + result[1]= (unsigned char)((intVal>>16) & 0x000000ff); + result[2]= (unsigned char)((intVal>> 8) & 0x000000ff); + result[3]= (unsigned char)((intVal>> 0) & 0x000000ff); +} + + +static int sha1_bin (lua_State *L) { + void * msg = NULL; + size_t len =0; + + if( lua_gettop(L) != 1 ){ + return 0; + } + if( lua_isnil(L,1) ) { + msg = NULL; + len =0; + }else{ + msg=luaL_checklstring(L,1,&len); + } + struct sha tmpsha; + calculate_sha1( &tmpsha, msg, (uint32_t)len); + unsigned char tmpret[SHA1SIZE+8]; + memset(tmpret,0,SHA1SIZE+8); + int i=0; + for ( i=0; i<5; i++) + { + int2ch4(tmpsha.digest[i], tmpret+i*4); + } + + lua_pushlstring(L, (char *)tmpret, SHA1SIZE); + return 1; +} + +static unsigned int num_escape_sql_str(unsigned char *dst, unsigned char *src, size_t size) +{ + unsigned int n =0; + while (size) { + /* the highest bit of all the UTF-8 chars + * is always 1 */ + if ((*src & 0x80) == 0) { + switch (*src) { + case '\0': + case '\b': + case '\n': + case '\r': + case '\t': + case 26: /* \z */ + case '\\': + case '\'': + case '"': + n++; + break; + default: + break; + } + } + src++; + size--; + } + return n; +} +static unsigned char* +escape_sql_str(unsigned char *dst, unsigned char *src, size_t size) +{ + + while (size) { + if ((*src & 0x80) == 0) { + switch (*src) { + case '\0': + *dst++ = '\\'; + *dst++ = '0'; + break; + + case '\b': + *dst++ = '\\'; + *dst++ = 'b'; + break; + + case '\n': + *dst++ = '\\'; + *dst++ = 'n'; + break; + + case '\r': + *dst++ = '\\'; + *dst++ = 'r'; + break; + + case '\t': + *dst++ = '\\'; + *dst++ = 't'; + break; + + case 26: + *dst++ = '\\'; + *dst++ = 'z'; + break; + + case '\\': + *dst++ = '\\'; + *dst++ = '\\'; + break; + + case '\'': + *dst++ = '\\'; + *dst++ = '\''; + break; + + case '"': + *dst++ = '\\'; + *dst++ = '"'; + break; + + default: + *dst++ = *src; + break; + } + } else { + *dst++ = *src; + } + src++; + size--; + } /* while (size) */ + + return dst; +} + + + + +static int +quote_sql_str(lua_State *L) +{ + size_t len, dlen, escape; + unsigned char *p; + unsigned char *src, *dst; + + if (lua_gettop(L) != 1) { + return luaL_error(L, "expecting one argument"); + } + + src = (unsigned char *) luaL_checklstring(L, 1, &len); + + if (len == 0) { + dst = (unsigned char *) "''"; + dlen = sizeof("''") - 1; + lua_pushlstring(L, (char *) dst, dlen); + return 1; + } + + escape = num_escape_sql_str(NULL, src, len); + + dlen = sizeof("''") - 1 + len + escape; + p = lua_newuserdata(L, dlen); + + dst = p; + + *p++ = '\''; + + if (escape == 0) { + memcpy(p, src, len); + p+=len; + } else { + p = (unsigned char *) escape_sql_str(p, src, len); + } + + *p++ = '\''; + + if (p != dst + dlen) { + return luaL_error(L, "quote sql string error"); + } + + lua_pushlstring(L, (char *) dst, p - dst); + + return 1; +} + + +static struct luaL_Reg mysqlauxlib[] = { + {"sha1_bin", sha1_bin}, + {"quote_sql_str",quote_sql_str}, + {NULL, NULL} +}; + + +int luaopen_mysqlaux_c (lua_State *L) { + lua_newtable(L); + luaL_setfuncs(L, mysqlauxlib, 0); + return 1; +} + diff --git a/Makefile b/Makefile index 3c5fd4465..03a4903f3 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ jemalloc : $(MALLOC_STATICLIB) CSERVICE = snlua logger gate master harbor dummy LUA_CLIB = skynet socketdriver int64 bson mongo md5 netpack \ cjson clientsocket memory profile multicast \ - cluster + cluster mysqlaux SKYNET_SRC = skynet_main.c skynet_handle.c skynet_module.c skynet_mq.c \ skynet_server.c skynet_start.c skynet_timer.c skynet_error.c \ @@ -110,6 +110,9 @@ $(LUA_CLIB_PATH)/multicast.so : lualib-src/lua-multicast.c | $(LUA_CLIB_PATH) $(LUA_CLIB_PATH)/cluster.so : lualib-src/lua-cluster.c | $(LUA_CLIB_PATH) $(CC) $(CFLAGS) $(SHARED) -Iskynet-src $^ -o $@ +$(LUA_CLIB_PATH)/mysqlaux.so : 3rd/lua-mysqlaux/lua_mysqlaux.c | $(LUA_CLIB_PATH) + $(CC) $(CFLAGS) $(SHARED) $^ -o $@ + clean : rm -f $(SKYNET_BUILD_PATH)/skynet $(CSERVICE_PATH)/*.so $(LUA_CLIB_PATH)/*.so diff --git a/examples/config.mysql b/examples/config.mysql new file mode 100644 index 000000000..8f23a02d6 --- /dev/null +++ b/examples/config.mysql @@ -0,0 +1,11 @@ +root = "./" +thread = 8 +logger = nil +harbor = 0 +start = "main_mysql" -- main script +bootstrap = "snlua bootstrap" -- The service for bootstrap +luaservice = root.."service/?.lua;"..root.."test/?.lua;"..root.."examples/?.lua" +lualoader = "lualib/loader.lua" +snax = root.."examples/?.lua;"..root.."test/?.lua" +cpath = root.."cservice/?.so" +-- daemon = "./skynet.pid" diff --git a/examples/main_mysql.lua b/examples/main_mysql.lua new file mode 100644 index 000000000..8a92b2cd3 --- /dev/null +++ b/examples/main_mysql.lua @@ -0,0 +1,10 @@ +local skynet = require "skynet" + + +skynet.start(function() + print("Main Server start") + local console = skynet.newservice("testmysql") + + print("Main Server exit") + skynet.exit() +end) diff --git a/lualib/mysql.lua b/lualib/mysql.lua new file mode 100755 index 000000000..137a12faf --- /dev/null +++ b/lualib/mysql.lua @@ -0,0 +1,801 @@ +-- Copyright (C) 2012 Yichun Zhang (agentzh) +-- Copyright (C) 2014 Chang Feng +local socketchannel = require "socketchannel" +local bit = require "bit32" +local mysqlaux = require "mysqlaux.c" + + + +local sub = string.sub +local strbyte = string.byte +local strchar = string.char +local strfind = string.find +local strrep = string.rep +local null = nil +local band = bit.band +local bxor = bit.bxor +local bor = bit.bor +local lshift = bit.lshift +local rshift = bit.rshift +local sha1= mysqlaux.sha1_bin +local concat = table.concat +local unpack = unpack +local setmetatable = setmetatable +local error = error +local tonumber = tonumber +local new_tab = function (narr, nrec) return {} end + + +local _M = { _VERSION = '0.13' } +-- constants + +local STATE_CONNECTED = 1 +local STATE_COMMAND_SENT = 2 + +local COM_QUERY = 0x03 + +local SERVER_MORE_RESULTS_EXISTS = 8 + +-- 16MB - 1, the default max allowed packet size used by libmysqlclient +local FULL_PACKET_SIZE = 16777215 + + +local mt = { __index = _M } + + +-- mysql field value type converters +local converters = new_tab(0, 8) + +for i = 0x01, 0x05 do + -- tiny, short, long, float, double + converters[i] = tonumber +end +-- converters[0x08] = tonumber -- long long +converters[0x09] = tonumber -- int24 +converters[0x0d] = tonumber -- year +converters[0xf6] = tonumber -- newdecimal + + +local function _get_byte2(data, i) + local a, b = strbyte(data, i, i + 1) + return bor(a, lshift(b, 8)), i + 2 +end + + +local function _get_byte3(data, i) + local a, b, c = strbyte(data, i, i + 2) + return bor(a, lshift(b, 8), lshift(c, 16)), i + 3 +end + + +local function _get_byte4(data, i) + local a, b, c, d = strbyte(data, i, i + 3) + return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)), i + 4 +end + + +local function _get_byte8(data, i) + local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7) + + -- XXX workaround for the lack of 64-bit support in bitop: + local lo = bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)) + local hi = bor(e, lshift(f, 8), lshift(g, 16), lshift(h, 24)) + return lo + hi * 4294967296, i + 8 + + -- return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24), lshift(e, 32), + -- lshift(f, 40), lshift(g, 48), lshift(h, 56)), i + 8 +end + + +local function _set_byte2(n) + return strchar(band(n, 0xff), band(rshift(n, 8), 0xff)) +end + + +local function _set_byte3(n) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff)) +end + + +local function _set_byte4(n) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff), + band(rshift(n, 24), 0xff)) +end + + +local function _from_cstring(data, i) + local last = strfind(data, "\0", i, true) + if not last then + return nil, nil + end + + return sub(data, i, last), last + 1 +end + + +local function _to_cstring(data) + return data .. "\0" +end + + +local function _to_binary_coded_string(data) + return strchar(#data) .. data +end + + +local function _dump(data) + local len = #data + local bytes = new_tab(len, 0) + for i = 1, len do + bytes[i] = strbyte(data, i) + end + return concat(bytes, " ") +end + + + +local function _dumphex(bytes) + local result ={} + + for i = 1, string.len(bytes) do + local charcode = tonumber(strbyte(bytes, i, i)) + local hexstr = string.format("%02X", charcode) + result[i]=hexstr + end + + local res=table.concat(result, " ") + return res +end + + +local function _compute_token(password, scramble) + if password == "" then + return "" + end + --_dump(scramble) + --print("password=",password) + --print("password:", password, "scramble: ", _dumphex(scramble) ) + local stage1 = sha1(password) + --print("stage1:", _dumphex(stage1) ) + local stage2 = sha1(stage1) + --print("stage2:", _dumphex(stage2) ) + local stage3 = sha1(scramble .. stage2) + --print("stage3:", _dumphex(stage3) ) + local n = #stage1 + local bytes = new_tab(n, 0) + for i = 1, n do + bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i))) + end + + return concat(bytes) +end + +local function _compose_packet(self, req, size) + self.packet_no = self.packet_no + 1 + + local packet = _set_byte3(size) .. strchar(self.packet_no) .. req + return packet +end + + +local function _send_packet(self, req, size) + local sock = self.sock + + self.packet_no = self.packet_no + 1 + + --print("packet no: ", self.packet_no) + + local packet = _set_byte3(size) .. strchar(self.packet_no) .. req + + --print("sending packet...") + + --return sock:send(packet) + return socket.write(self.sock,packet) +end + + +local function _recv_packet(self,sock) + + + local data = sock:read( 4) + if not data then + return nil, nil, "failed to receive packet header: " + end + --print("_recv_packet data type:" ,type(data) ) + --print("packet header: ", _dump(data)) + + local len, pos = _get_byte3(data, 1) + + --print("recv_packet packet length: ", len) + + if len == 0 then + return nil, nil, "empty packet" + end + + if len > self._max_packet_size then + return nil, nil, "packet size too big: " .. len + end + + local num = strbyte(data, pos) + + --print("recv packet: packet no: ", num) + + self.packet_no = num + + --data, err = sock:receive(len) + + data = sock:read(len) + + + if not data then + return nil, nil, "failed to read packet content: " + end + + --print("packet content: ", _dump(data)) + --print("packet content (ascii): ", data) + + local field_count = strbyte(data, 1) + --print("field count:",field_count) + local typ + if field_count == 0x00 then + typ = "OK" + elseif field_count == 0xff then + typ = "ERR" + elseif field_count == 0xfe then + typ = "EOF" + elseif field_count <= 250 then + typ = "DATA" + end + + --print("recv packet: typ= ", typ) + return data, typ +end + + +local function _from_length_coded_bin(data, pos) + local first = strbyte(data, pos) + + --print("LCB: first: ", first) + + if not first then + return nil, pos + end + + if first >= 0 and first <= 250 then + return first, pos + 1 + end + + if first == 251 then + return null, pos + 1 + end + + if first == 252 then + pos = pos + 1 + return _get_byte2(data, pos) + end + + if first == 253 then + pos = pos + 1 + return _get_byte3(data, pos) + end + + if first == 254 then + pos = pos + 1 + return _get_byte8(data, pos) + end + + return false, pos + 1 +end + + +local function _from_length_coded_str(data, pos) + local len + len, pos = _from_length_coded_bin(data, pos) + if len == nil or len == null then + return null, pos + end + + return sub(data, pos, pos + len - 1), pos + len +end + + +local function _parse_ok_packet(packet) + local res = new_tab(0, 5) + local pos + + res.affected_rows, pos = _from_length_coded_bin(packet, 2) + + --print("affected rows: ", res.affected_rows, ", pos:", pos) + + res.insert_id, pos = _from_length_coded_bin(packet, pos) + + --print("insert id: ", res.insert_id, ", pos:", pos) + + res.server_status, pos = _get_byte2(packet, pos) + + --print("server status: ", res.server_status, ", pos:", pos) + + res.warning_count, pos = _get_byte2(packet, pos) + + --print("warning count: ", res.warning_count, ", pos: ", pos) + + local message = sub(packet, pos) + if message and message ~= "" then + res.message = message + end + + --print("message: ", res.message, ", pos:", pos) + + return res +end + + +local function _parse_eof_packet(packet) + local pos = 2 + + local warning_count, pos = _get_byte2(packet, pos) + local status_flags = _get_byte2(packet, pos) + + return warning_count, status_flags +end + + +local function _parse_err_packet(packet) + local errno, pos = _get_byte2(packet, 2) + local marker = sub(packet, pos, pos) + local sqlstate + if marker == '#' then + -- with sqlstate + pos = pos + 1 + sqlstate = sub(packet, pos, pos + 5 - 1) + pos = pos + 5 + end + + local message = sub(packet, pos) + return errno, message, sqlstate +end + + +local function _parse_result_set_header_packet(packet) + local field_count, pos = _from_length_coded_bin(packet, 1) + + local extra + extra = _from_length_coded_bin(packet, pos) + + return field_count, extra +end + + +local function _parse_field_packet(data) + local col = new_tab(0, 2) + local catalog, db, table, orig_table, orig_name, charsetnr, length + local pos + catalog, pos = _from_length_coded_str(data, 1) + + --print("catalog: ", col.catalog, ", pos:", pos) + + db, pos = _from_length_coded_str(data, pos) + table, pos = _from_length_coded_str(data, pos) + orig_table, pos = _from_length_coded_str(data, pos) + col.name, pos = _from_length_coded_str(data, pos) + + orig_name, pos = _from_length_coded_str(data, pos) + + pos = pos + 1 -- ignore the filler + + charsetnr, pos = _get_byte2(data, pos) + + length, pos = _get_byte4(data, pos) + + col.type = strbyte(data, pos) + + --[[ + pos = pos + 1 + + col.flags, pos = _get_byte2(data, pos) + + col.decimals = strbyte(data, pos) + pos = pos + 1 + + local default = sub(data, pos + 2) + if default and default ~= "" then + col.default = default + end + --]] + + return col +end + + +local function _parse_row_data_packet(data, cols, compact) + local pos = 1 + local ncols = #cols + local row + if compact then + row = new_tab(ncols, 0) + else + row = new_tab(0, ncols) + end + for i = 1, ncols do + local value + value, pos = _from_length_coded_str(data, pos) + local col = cols[i] + local typ = col.type + local name = col.name + + --print("row field value: ", value, ", type: ", typ) + + if value ~= null then + local conv = converters[typ] + if conv then + value = conv(value) + end + end + + if compact then + row[i] = value + + else + row[name] = value + end + end + + return row +end + + +local function _recv_field_packet(self, sock) + local packet, typ, err = _recv_packet(self, sock) + if not packet then + return nil, err + end + + if typ == "ERR" then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end + + if typ ~= 'DATA' then + return nil, "bad field packet type: " .. typ + end + + -- typ == 'DATA' + + return _parse_field_packet(packet) +end + +local function _recv_decode_packet_resp(self) + return function(sock) + return true, _recv_packet(self,sock) + end +end + +local function _recv_auth_resp(self) + return function(sock) + --print("recv auth resp") + local packet, typ, err = _recv_packet(self,sock) + if not packet then + --print("recv auth resp : failed to receive the result packet") + error ("failed to receive the result packet"..err) + end + + --print("receive auth response packet type: ",typ) + if typ == 'ERR' then + local errno, msg, sqlstate = _parse_err_packet(packet) + error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + end + + if typ == 'EOF' then + error "old pre-4.1 authentication protocol not supported" + end + + if typ ~= 'OK' then + error "bad packet type: " + end + return true, true + end +end + + +local function _mysql_login(self,user,password,database) + + return function(sockchannel) + local packet, typ, err = sockchannel:response( _recv_decode_packet_resp(self) ) + --local aat={} + if not packet then + error( err ) + end + + if typ == "ERR" then + local errno, msg, sqlstate = _parse_err_packet(packet) + error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + end + + self.protocol_ver = strbyte(packet) + + --print("protocol version: ", self.protocol_ver) + + local server_ver, pos = _from_cstring(packet, 2) + if not server_ver then + error "bad handshake initialization packet: bad server version" + end + + --print("server version: ", server_ver) + + self._server_ver = server_ver + + + local thread_id, pos = _get_byte4(packet, pos) + + --print("thread id: ", thread_id) + + local scramble1 = sub(packet, pos, pos + 8 - 1) + --print("scramble1:",_dump(scramble1), "pos:",pos) + if not scramble1 then + error "1st part of scramble not found" + end + + pos = pos + 9 -- skip filler + + -- two lower bytes + self._server_capabilities, pos = _get_byte2(packet, pos) + + --print("server capabilities: ", self._server_capabilities) + + self._server_lang = strbyte(packet, pos) + pos = pos + 1 + + --print("server lang: ", self._server_lang) + + self._server_status, pos = _get_byte2(packet, pos) + + --print("server status: ", self._server_status) + + local more_capabilities + more_capabilities, pos = _get_byte2(packet, pos) + + self._server_capabilities = bor(self._server_capabilities, + lshift(more_capabilities, 16)) + + --print("server capabilities: ", self._server_capabilities) + + + -- local len = strbyte(packet, pos) + local len = 21 - 8 - 1 + + --print("scramble len: ", len) + + pos = pos + 1 + 10 + + local scramble_part2 = sub(packet, pos, pos + len - 1) + if not scramble_part2 then + error "2nd part of scramble not found" + end + + + local scramble = scramble1..scramble_part2 + --print("scramble:",_dump(scramble) ) + local token = _compute_token(password, scramble) + + -- local client_flags = self._server_capabilities + local client_flags = 260047; + + --print("token: ", _dump(token)) + + local req = _set_byte4(client_flags) + .. _set_byte4(self._max_packet_size) + .. "\0" -- TODO: add support for charset encoding + .. strrep("\0", 23) + .. _to_cstring(user) + .. _to_binary_coded_string(token) + .. _to_cstring(database) + + local packet_len = 4 + 4 + 1 + 23 + #user + 1 + + #token + 1 + #database + 1 + + --print("packet content length: ", packet_len) + --print("packet content: ", _dump(concat(req, ""))) + + local authpacket=_compose_packet(self,req,packet_len) + --print("mysql login authpacket len=",#authpacket) + return sockchannel:request(authpacket,_recv_auth_resp(self)) + end +end +function _M.connect( opts) + + local self = setmetatable( {}, mt) + + local max_packet_size = opts.max_packet_size + if not max_packet_size then + max_packet_size = 1024 * 1024 -- default 1 MB + end + self._max_packet_size = max_packet_size + self.compact = opts.compact_arrays + + + local database = opts.database or "" + local user = opts.user or "" + local password = opts.password or "" + + local channel = socketchannel.channel { + host = opts.host, + port = opts.port or 3306, + auth = _mysql_login(self,user,password,database ), + } + -- try connect first only once + channel:connect(true) + self.sockchannel = channel + + + return self +end + + + +function _M.close(self) + self.sockchannel:close() + setmetatable(self, nil) +end + + +function _M.server_ver(self) + return self._server_ver +end + +local function _compose_query(self, query) + + self.packet_no = -1 + + local cmd_packet = strchar(COM_QUERY) .. query + local packet_len = 1 + #query + + local querypacket = _compose_packet(self, cmd_packet, packet_len) + --print("compose query packet, len= ", #querypacket) + return querypacket +end + + + +local function read_result(self, sock) + --print("read_result") + local packet, typ, err = _recv_packet(self, sock) + if not packet then + --print("read result", err) + error( err ) + end + + if typ == "ERR" then + local errno, msg, sqlstate = _parse_err_packet(packet) + --print("read result ", msg, errno, sqlstate) + --return nil, msg, errno, sqlstate + error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + end + + if typ == 'OK' then + local res = _parse_ok_packet(packet) + if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + --print("read result ", res, "again") + return res, "again" + end + --print("parse ok packet res=",res) + return res + end + + if typ ~= 'DATA' then + --print("read result", "packet type " ,typ , " not supported") + error( "packet type " .. typ .. " not supported" ) + end + + -- typ == 'DATA' + + --print("read the result set header packet") + + local field_count, extra = _parse_result_set_header_packet(packet) + + --print("field count: ", field_count) + + local cols = new_tab(field_count, 0) + for i = 1, field_count do + local col, err, errno, sqlstate = _recv_field_packet(self, sock) + if not col then + --return nil, err, errno, sqlstate + error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + end + + cols[i] = col + end + + local packet, typ, err = _recv_packet(self, sock) + if not packet then + error( err) + end + + if typ ~= 'EOF' then + error ( "unexpected packet type " .. typ .. " while eof packet is ".. "expected" ) + end + + -- typ == 'EOF' + + local compact = self.compact + + local rows = new_tab( 4, 0) + local i = 0 + while true do + --print("reading a row") + + packet, typ, err = _recv_packet(self, sock) + if not packet then + error (err) + end + + if typ == 'EOF' then + local warning_count, status_flags = _parse_eof_packet(packet) + + --print("status flags: ", status_flags) + + if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return rows, "again" + end + + break + end + + -- if typ ~= 'DATA' then + -- return nil, 'bad row packet type: ' .. typ + -- end + + -- typ == 'DATA' + + local row = _parse_row_data_packet(packet, cols, compact) + i = i + 1 + rows[i] = row + end + + return rows +end + +local function _query_resp(self) + return function(sock) + --return true ,read_result(self,sock) + local res, more = read_result(self,sock) + if more ~= "again" then + return true, res + end + local mulitresultset = {res} + mulitresultset.mulitresultset = true + local i =2 + while more =="again" do + res, more = read_result(self,sock) + if not res then + return true, mulitresultset + end + mulitresultset[i]=res + i=i+1 + end + return true, mulitresultset + end +end +function _M.query(self, query) + local querypacket = _compose_query(self, query) + local sockchannel = self.sockchannel + if not self.query_resp then + self.query_resp = _query_resp(self) + end + return sockchannel:request( querypacket, self.query_resp ) +end + + +function _M.set_compact_arrays(self, value) + self.compact = value +end + + +function _M.quote_sql_str( str) + return mysqlaux.quote_sql_str(str) +end + +return _M diff --git a/test/testmysql.lua b/test/testmysql.lua new file mode 100644 index 000000000..b5621404b --- /dev/null +++ b/test/testmysql.lua @@ -0,0 +1,32 @@ +local skynet = require "skynet" +local mysql = require "mysql" + +skynet.start(function() + + local db=mysql.connect{ + host="192.168.1.218", + port=3306, + database="Battle_Data", + user="root", + password="1" + } + if not db then + print("failed to connect") + end + print("testmysql success to connect to mysql server") + + --local res=db:query("select * from test1;select * from test1") + local res=db:query("select * from G_BuildData_0 limit 10") + print(res) + for k,v in pairs(res) do + print("k=",k,"v=",v) + if type(v)=="table" then + for kk, vv in pairs(v) do + print("kk=",kk,"vv=",v) + end + end + end + + skynet.exit() +end) + From bf686da72304f2945e188a49c908da8c1188a902 Mon Sep 17 00:00:00 2001 From: changfeng <87414772@qq.com> Date: Wed, 25 Jun 2014 20:55:20 +0800 Subject: [PATCH 02/13] improve mysql lib test --- test/testmysql.lua | 69 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 13 deletions(-) diff --git a/test/testmysql.lua b/test/testmysql.lua index b5621404b..8944d15c6 100644 --- a/test/testmysql.lua +++ b/test/testmysql.lua @@ -1,12 +1,56 @@ local skynet = require "skynet" local mysql = require "mysql" +local function dump(obj) + local getIndent, quoteStr, wrapKey, wrapVal, dumpObj + getIndent = function(level) + return string.rep("\t", level) + end + quoteStr = function(str) + return '"' .. string.gsub(str, '"', '\\"') .. '"' + end + wrapKey = function(val) + if type(val) == "number" then + return "[" .. val .. "]" + elseif type(val) == "string" then + return "[" .. quoteStr(val) .. "]" + else + return "[" .. tostring(val) .. "]" + end + end + wrapVal = function(val, level) + if type(val) == "table" then + return dumpObj(val, level) + elseif type(val) == "number" then + return val + elseif type(val) == "string" then + return quoteStr(val) + else + return tostring(val) + end + end + dumpObj = function(obj, level) + if type(obj) ~= "table" then + return wrapVal(obj) + end + level = level + 1 + local tokens = {} + tokens[#tokens + 1] = "{" + for k, v in pairs(obj) do + tokens[#tokens + 1] = getIndent(level) .. wrapKey(k) .. " = " .. wrapVal(v, level) .. "," + end + tokens[#tokens + 1] = getIndent(level - 1) .. "}" + return table.concat(tokens, "\n") + end + return dumpObj(obj, 0) +end + skynet.start(function() local db=mysql.connect{ - host="192.168.1.218", + host="127.0.0.1", port=3306, - database="Battle_Data", + database="skynet", user="root", password="1" } @@ -15,17 +59,16 @@ skynet.start(function() end print("testmysql success to connect to mysql server") - --local res=db:query("select * from test1;select * from test1") - local res=db:query("select * from G_BuildData_0 limit 10") - print(res) - for k,v in pairs(res) do - print("k=",k,"v=",v) - if type(v)=="table" then - for kk, vv in pairs(v) do - print("kk=",kk,"vv=",v) - end - end - end + local res = db:query("drop table if exists cats") + res = db:query("create table cats " + .."(id serial primary key, ".. "name varchar(5))") + print( dump( res ) ) + res = db:query("insert into cats (name) " + .. "values (\'Bob\'),(\'\'),(null)") + print ( dump( res ) ) + res = db:query("select * from cats order by id asc") + print ( dump( res ) ) + skynet.exit() end) From 315945b2bd374356e174857a884cc244f104ccc3 Mon Sep 17 00:00:00 2001 From: changfeng <87414772@qq.com> Date: Wed, 25 Jun 2014 22:15:06 +0800 Subject: [PATCH 03/13] no message --- test/testmysql.lua | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/testmysql.lua b/test/testmysql.lua index 8944d15c6..fe4bcb7e7 100644 --- a/test/testmysql.lua +++ b/test/testmysql.lua @@ -69,7 +69,15 @@ skynet.start(function() res = db:query("select * from cats order by id asc") print ( dump( res ) ) - + -- multiresultset test + res = db:query("select * from cats order by id asc ; select * from cats") + print ( dump( res ) ) + + print ( mysql.quote_sql_str([[\mysql escape %string test'test"]]) ) + + -- bad sql statement + res = pcall( db.query, db, "select * from notexisttable" ) + print( dump(res) ) skynet.exit() end) From 984f727385d92236baf9f4d2e7bb8b25f93e5410 Mon Sep 17 00:00:00 2001 From: changfeng <87414772@qq.com> Date: Wed, 25 Jun 2014 22:38:38 +0800 Subject: [PATCH 04/13] imporve mysql lib test --- test/testmysql.lua | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/testmysql.lua b/test/testmysql.lua index fe4bcb7e7..0debe1f97 100644 --- a/test/testmysql.lua +++ b/test/testmysql.lua @@ -63,9 +63,11 @@ skynet.start(function() res = db:query("create table cats " .."(id serial primary key, ".. "name varchar(5))") print( dump( res ) ) + res = db:query("insert into cats (name) " .. "values (\'Bob\'),(\'\'),(null)") print ( dump( res ) ) + res = db:query("select * from cats order by id asc") print ( dump( res ) ) @@ -76,8 +78,12 @@ skynet.start(function() print ( mysql.quote_sql_str([[\mysql escape %string test'test"]]) ) -- bad sql statement - res = pcall( db.query, db, "select * from notexisttable" ) + local ok, res = pcall( db.query, db, "select * from notexisttable" ) print( dump(res) ) + + res = db:query("select * from cats order by id asc") + print ( dump( res ) ) + skynet.exit() end) From c0e1365dc25fb69f5e944e0d225dad923ee5528f Mon Sep 17 00:00:00 2001 From: changfeng <87414772@qq.com> Date: Wed, 25 Jun 2014 23:11:00 +0800 Subject: [PATCH 05/13] modify error process in socket channel --- lualib/socketchannel.lua | 5 ++++- test/testmysql.lua | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lualib/socketchannel.lua b/lualib/socketchannel.lua index 32615ac0c..c2170cbc6 100644 --- a/lualib/socketchannel.lua +++ b/lualib/socketchannel.lua @@ -132,9 +132,12 @@ local function dispatch_by_order(self) else close_channel_socket(self) local errmsg - if result ~= socket_error then + if result_ok ~= socket_error then errmsg = result_ok end + self.__result[co] = socket_error + self.__result_data[co] = errmsg + skynet.wakeup(co) wakeup_all(self, errmsg) end end diff --git a/test/testmysql.lua b/test/testmysql.lua index 0debe1f97..058192739 100644 --- a/test/testmysql.lua +++ b/test/testmysql.lua @@ -79,11 +79,11 @@ skynet.start(function() -- bad sql statement local ok, res = pcall( db.query, db, "select * from notexisttable" ) - print( dump(res) ) + print( "ok= ",ok, dump(res) ) res = db:query("select * from cats order by id asc") print ( dump( res ) ) - + skynet.exit() end) From ad5c37200b09d5332025f71f146dc17fb2b7efbb Mon Sep 17 00:00:00 2001 From: changfeng <87414772@qq.com> Date: Thu, 26 Jun 2014 00:25:10 +0800 Subject: [PATCH 06/13] improve mysql lib --- lualib/mysql.lua | 128 ++++++++++++++++++++++++++------------------- test/testmysql.lua | 5 +- 2 files changed, 76 insertions(+), 57 deletions(-) diff --git a/lualib/mysql.lua b/lualib/mysql.lua index 137a12faf..1d4d54a53 100755 --- a/lualib/mysql.lua +++ b/lualib/mysql.lua @@ -605,47 +605,8 @@ local function _mysql_login(self,user,password,database) return sockchannel:request(authpacket,_recv_auth_resp(self)) end end -function _M.connect( opts) - - local self = setmetatable( {}, mt) - - local max_packet_size = opts.max_packet_size - if not max_packet_size then - max_packet_size = 1024 * 1024 -- default 1 MB - end - self._max_packet_size = max_packet_size - self.compact = opts.compact_arrays - - - local database = opts.database or "" - local user = opts.user or "" - local password = opts.password or "" - - local channel = socketchannel.channel { - host = opts.host, - port = opts.port or 3306, - auth = _mysql_login(self,user,password,database ), - } - -- try connect first only once - channel:connect(true) - self.sockchannel = channel - - - return self -end - - - -function _M.close(self) - self.sockchannel:close() - setmetatable(self, nil) -end -function _M.server_ver(self) - return self._server_ver -end - local function _compose_query(self, query) self.packet_no = -1 @@ -665,14 +626,15 @@ local function read_result(self, sock) local packet, typ, err = _recv_packet(self, sock) if not packet then --print("read result", err) - error( err ) + return nil, err + --error( err ) end if typ == "ERR" then local errno, msg, sqlstate = _parse_err_packet(packet) --print("read result ", msg, errno, sqlstate) - --return nil, msg, errno, sqlstate - error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + return nil, msg, errno, sqlstate + --error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) end if typ == 'OK' then @@ -687,7 +649,8 @@ local function read_result(self, sock) if typ ~= 'DATA' then --print("read result", "packet type " ,typ , " not supported") - error( "packet type " .. typ .. " not supported" ) + return nil, "packet type " .. typ .. " not supported" + --error( "packet type " .. typ .. " not supported" ) end -- typ == 'DATA' @@ -702,8 +665,8 @@ local function read_result(self, sock) for i = 1, field_count do local col, err, errno, sqlstate = _recv_field_packet(self, sock) if not col then - --return nil, err, errno, sqlstate - error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + return nil, err, errno, sqlstate + --error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) end cols[i] = col @@ -711,11 +674,13 @@ local function read_result(self, sock) local packet, typ, err = _recv_packet(self, sock) if not packet then - error( err) + --error( err) + return nil, err end if typ ~= 'EOF' then - error ( "unexpected packet type " .. typ .. " while eof packet is ".. "expected" ) + --error ( "unexpected packet type " .. typ .. " while eof packet is ".. "expected" ) + return nil, "unexpected packet type " .. typ .. " while eof packet is ".. "expected" end -- typ == 'EOF' @@ -729,7 +694,8 @@ local function read_result(self, sock) packet, typ, err = _recv_packet(self, sock) if not packet then - error (err) + --error (err) + return nil, err end if typ == 'EOF' then @@ -761,15 +727,25 @@ end local function _query_resp(self) return function(sock) --return true ,read_result(self,sock) - local res, more = read_result(self,sock) - if more ~= "again" then + --local res, more = read_result(self,sock) + local res, err, errno, sqlstate = read_result(self,sock) + if not res then + local badresult ={} + badresult.badresult = true + badresult.err = err + badresult.errno = errno + badresult.sqlstate = sqlstate + return true , badresult + end + if err ~= "again" then return true, res end local mulitresultset = {res} mulitresultset.mulitresultset = true local i =2 - while more =="again" do - res, more = read_result(self,sock) + while err =="again" do + --res, more = read_result(self,sock) + res, err, errno, sqlstate = read_result(self,sock) if not res then return true, mulitresultset end @@ -779,6 +755,44 @@ local function _query_resp(self) return true, mulitresultset end end + +function _M.connect( opts) + + local self = setmetatable( {}, mt) + + local max_packet_size = opts.max_packet_size + if not max_packet_size then + max_packet_size = 1024 * 1024 -- default 1 MB + end + self._max_packet_size = max_packet_size + self.compact = opts.compact_arrays + + + local database = opts.database or "" + local user = opts.user or "" + local password = opts.password or "" + + local channel = socketchannel.channel { + host = opts.host, + port = opts.port or 3306, + auth = _mysql_login(self,user,password,database ), + } + -- try connect first only once + channel:connect(true) + self.sockchannel = channel + + + return self +end + + + +function _M.disconnect(self) + self.sockchannel:close() + setmetatable(self, nil) +end + + function _M.query(self, query) local querypacket = _compose_query(self, query) local sockchannel = self.sockchannel @@ -788,9 +802,8 @@ function _M.query(self, query) return sockchannel:request( querypacket, self.query_resp ) end - -function _M.set_compact_arrays(self, value) - self.compact = value +function _M.server_ver(self) + return self._server_ver end @@ -798,4 +811,9 @@ function _M.quote_sql_str( str) return mysqlaux.quote_sql_str(str) end +function _M.set_compact_arrays(self, value) + self.compact = value +end + + return _M diff --git a/test/testmysql.lua b/test/testmysql.lua index 058192739..992021090 100644 --- a/test/testmysql.lua +++ b/test/testmysql.lua @@ -78,12 +78,13 @@ skynet.start(function() print ( mysql.quote_sql_str([[\mysql escape %string test'test"]]) ) -- bad sql statement - local ok, res = pcall( db.query, db, "select * from notexisttable" ) - print( "ok= ",ok, dump(res) ) + local res = db:query("select * from notexisttable" ) + print( dump(res) ) res = db:query("select * from cats order by id asc") print ( dump( res ) ) + db:disconnect() skynet.exit() end) From 4cce476dfc58d624a609d1b59c99901250fd72c8 Mon Sep 17 00:00:00 2001 From: changfeng <87414772@qq.com> Date: Thu, 26 Jun 2014 00:28:08 +0800 Subject: [PATCH 07/13] no message --- test/testmysql.lua | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/testmysql.lua b/test/testmysql.lua index 992021090..87ffea671 100644 --- a/test/testmysql.lua +++ b/test/testmysql.lua @@ -52,7 +52,8 @@ skynet.start(function() port=3306, database="skynet", user="root", - password="1" + password="1", + max_packet_size = 1024 * 1024 } if not db then print("failed to connect") From fce05f0cfc3d2b8b1e742d4f5fdd99b2afdf8bee Mon Sep 17 00:00:00 2001 From: changfeng <87414772@qq.com> Date: Thu, 26 Jun 2014 01:04:11 +0800 Subject: [PATCH 08/13] no message --- lualib/mysql.lua | 2 ++ test/testmysql.lua | 36 +++++++++++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/lualib/mysql.lua b/lualib/mysql.lua index 1d4d54a53..adaf08e64 100755 --- a/lualib/mysql.lua +++ b/lualib/mysql.lua @@ -481,12 +481,14 @@ local function _recv_auth_resp(self) if not packet then --print("recv auth resp : failed to receive the result packet") error ("failed to receive the result packet"..err) + --return nil,err end --print("receive auth response packet type: ",typ) if typ == 'ERR' then local errno, msg, sqlstate = _parse_err_packet(packet) error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + --return nil, errno,msg, sqlstate end if typ == 'EOF' then diff --git a/test/testmysql.lua b/test/testmysql.lua index 87ffea671..1423adbdd 100644 --- a/test/testmysql.lua +++ b/test/testmysql.lua @@ -45,6 +45,24 @@ local function dump(obj) return dumpObj(obj, 0) end +local function test2( db) + local i=1 + while true do + local res = db:query("select * from cats order by id asc") + print ( "test2 i=" ,i,"\n",dump( res ) ) + skynet.sleep(1000) + i=i+1 + end +end +local function test3( db) + local i=1 + while true do + local res = db:query("select * from cats order by id asc") + print ( "test3 i=" ,i,"\n",dump( res ) ) + skynet.sleep(1000) + i=i+1 + end +end skynet.start(function() local db=mysql.connect{ @@ -71,7 +89,10 @@ skynet.start(function() res = db:query("select * from cats order by id asc") print ( dump( res ) ) - + + -- test in another coroutine + skynet.fork( test2, db) + skynet.fork( test3, db) -- multiresultset test res = db:query("select * from cats order by id asc ; select * from cats") print ( dump( res ) ) @@ -82,10 +103,15 @@ skynet.start(function() local res = db:query("select * from notexisttable" ) print( dump(res) ) - res = db:query("select * from cats order by id asc") - print ( dump( res ) ) + local i=1 + while true do + local res = db:query("select * from cats order by id asc") + print ( "test1 i=" ,i,"\n",dump( res ) ) + skynet.sleep(1000) + i=i+1 + end - db:disconnect() - skynet.exit() + --db:disconnect() + --skynet.exit() end) From e9d4075e435c5564fd88c9d8f21d87380035f02a Mon Sep 17 00:00:00 2001 From: changfeng <87414772@qq.com> Date: Thu, 26 Jun 2014 07:34:16 +0800 Subject: [PATCH 09/13] remove commented out code, improve test code --- lualib/mysql.lua | 81 ++-------------------------------------------- test/testmysql.lua | 22 +++++++++---- 2 files changed, 18 insertions(+), 85 deletions(-) diff --git a/lualib/mysql.lua b/lualib/mysql.lua index adaf08e64..0c5bdbbbb 100755 --- a/lualib/mysql.lua +++ b/lualib/mysql.lua @@ -157,14 +157,11 @@ local function _compute_token(password, scramble) return "" end --_dump(scramble) - --print("password=",password) - --print("password:", password, "scramble: ", _dumphex(scramble) ) + local stage1 = sha1(password) --print("stage1:", _dumphex(stage1) ) local stage2 = sha1(stage1) - --print("stage2:", _dumphex(stage2) ) local stage3 = sha1(scramble .. stage2) - --print("stage3:", _dumphex(stage3) ) local n = #stage1 local bytes = new_tab(n, 0) for i = 1, n do @@ -187,13 +184,9 @@ local function _send_packet(self, req, size) self.packet_no = self.packet_no + 1 - --print("packet no: ", self.packet_no) local packet = _set_byte3(size) .. strchar(self.packet_no) .. req - --print("sending packet...") - - --return sock:send(packet) return socket.write(self.sock,packet) end @@ -205,12 +198,10 @@ local function _recv_packet(self,sock) if not data then return nil, nil, "failed to receive packet header: " end - --print("_recv_packet data type:" ,type(data) ) - --print("packet header: ", _dump(data)) + local len, pos = _get_byte3(data, 1) - --print("recv_packet packet length: ", len) if len == 0 then return nil, nil, "empty packet" @@ -222,24 +213,16 @@ local function _recv_packet(self,sock) local num = strbyte(data, pos) - --print("recv packet: packet no: ", num) - self.packet_no = num - --data, err = sock:receive(len) - data = sock:read(len) - if not data then return nil, nil, "failed to read packet content: " end - --print("packet content: ", _dump(data)) - --print("packet content (ascii): ", data) local field_count = strbyte(data, 1) - --print("field count:",field_count) local typ if field_count == 0x00 then typ = "OK" @@ -251,7 +234,6 @@ local function _recv_packet(self,sock) typ = "DATA" end - --print("recv packet: typ= ", typ) return data, typ end @@ -259,8 +241,6 @@ end local function _from_length_coded_bin(data, pos) local first = strbyte(data, pos) - --print("LCB: first: ", first) - if not first then return nil, pos end @@ -309,26 +289,18 @@ local function _parse_ok_packet(packet) res.affected_rows, pos = _from_length_coded_bin(packet, 2) - --print("affected rows: ", res.affected_rows, ", pos:", pos) - res.insert_id, pos = _from_length_coded_bin(packet, pos) - --print("insert id: ", res.insert_id, ", pos:", pos) - res.server_status, pos = _get_byte2(packet, pos) - --print("server status: ", res.server_status, ", pos:", pos) - res.warning_count, pos = _get_byte2(packet, pos) - --print("warning count: ", res.warning_count, ", pos: ", pos) local message = sub(packet, pos) if message and message ~= "" then res.message = message end - --print("message: ", res.message, ", pos:", pos) return res end @@ -376,7 +348,6 @@ local function _parse_field_packet(data) local pos catalog, pos = _from_length_coded_str(data, 1) - --print("catalog: ", col.catalog, ", pos:", pos) db, pos = _from_length_coded_str(data, pos) table, pos = _from_length_coded_str(data, pos) @@ -427,8 +398,6 @@ local function _parse_row_data_packet(data, cols, compact) local typ = col.type local name = col.name - --print("row field value: ", value, ", type: ", typ) - if value ~= null then local conv = converters[typ] if conv then @@ -476,7 +445,6 @@ end local function _recv_auth_resp(self) return function(sock) - --print("recv auth resp") local packet, typ, err = _recv_packet(self,sock) if not packet then --print("recv auth resp : failed to receive the result packet") @@ -484,7 +452,6 @@ local function _recv_auth_resp(self) --return nil,err end - --print("receive auth response packet type: ",typ) if typ == 'ERR' then local errno, msg, sqlstate = _parse_err_packet(packet) error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) @@ -519,24 +486,17 @@ local function _mysql_login(self,user,password,database) self.protocol_ver = strbyte(packet) - --print("protocol version: ", self.protocol_ver) - local server_ver, pos = _from_cstring(packet, 2) if not server_ver then error "bad handshake initialization packet: bad server version" end - --print("server version: ", server_ver) - self._server_ver = server_ver local thread_id, pos = _get_byte4(packet, pos) - --print("thread id: ", thread_id) - local scramble1 = sub(packet, pos, pos + 8 - 1) - --print("scramble1:",_dump(scramble1), "pos:",pos) if not scramble1 then error "1st part of scramble not found" end @@ -546,31 +506,20 @@ local function _mysql_login(self,user,password,database) -- two lower bytes self._server_capabilities, pos = _get_byte2(packet, pos) - --print("server capabilities: ", self._server_capabilities) - self._server_lang = strbyte(packet, pos) pos = pos + 1 - --print("server lang: ", self._server_lang) - self._server_status, pos = _get_byte2(packet, pos) - --print("server status: ", self._server_status) - local more_capabilities more_capabilities, pos = _get_byte2(packet, pos) self._server_capabilities = bor(self._server_capabilities, lshift(more_capabilities, 16)) - --print("server capabilities: ", self._server_capabilities) - - -- local len = strbyte(packet, pos) local len = 21 - 8 - 1 - --print("scramble len: ", len) - pos = pos + 1 + 10 local scramble_part2 = sub(packet, pos, pos + len - 1) @@ -580,14 +529,10 @@ local function _mysql_login(self,user,password,database) local scramble = scramble1..scramble_part2 - --print("scramble:",_dump(scramble) ) local token = _compute_token(password, scramble) - -- local client_flags = self._server_capabilities local client_flags = 260047; - --print("token: ", _dump(token)) - local req = _set_byte4(client_flags) .. _set_byte4(self._max_packet_size) .. "\0" -- TODO: add support for charset encoding @@ -599,11 +544,7 @@ local function _mysql_login(self,user,password,database) local packet_len = 4 + 4 + 1 + 23 + #user + 1 + #token + 1 + #database + 1 - --print("packet content length: ", packet_len) - --print("packet content: ", _dump(concat(req, ""))) - local authpacket=_compose_packet(self,req,packet_len) - --print("mysql login authpacket len=",#authpacket) return sockchannel:request(authpacket,_recv_auth_resp(self)) end end @@ -617,24 +558,20 @@ local function _compose_query(self, query) local packet_len = 1 + #query local querypacket = _compose_packet(self, cmd_packet, packet_len) - --print("compose query packet, len= ", #querypacket) return querypacket end local function read_result(self, sock) - --print("read_result") local packet, typ, err = _recv_packet(self, sock) if not packet then - --print("read result", err) return nil, err --error( err ) end if typ == "ERR" then local errno, msg, sqlstate = _parse_err_packet(packet) - --print("read result ", msg, errno, sqlstate) return nil, msg, errno, sqlstate --error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) end @@ -642,27 +579,20 @@ local function read_result(self, sock) if typ == 'OK' then local res = _parse_ok_packet(packet) if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then - --print("read result ", res, "again") return res, "again" end - --print("parse ok packet res=",res) return res end if typ ~= 'DATA' then - --print("read result", "packet type " ,typ , " not supported") return nil, "packet type " .. typ .. " not supported" --error( "packet type " .. typ .. " not supported" ) end -- typ == 'DATA' - --print("read the result set header packet") - local field_count, extra = _parse_result_set_header_packet(packet) - --print("field count: ", field_count) - local cols = new_tab(field_count, 0) for i = 1, field_count do local col, err, errno, sqlstate = _recv_field_packet(self, sock) @@ -692,8 +622,6 @@ local function read_result(self, sock) local rows = new_tab( 4, 0) local i = 0 while true do - --print("reading a row") - packet, typ, err = _recv_packet(self, sock) if not packet then --error (err) @@ -703,8 +631,6 @@ local function read_result(self, sock) if typ == 'EOF' then local warning_count, status_flags = _parse_eof_packet(packet) - --print("status flags: ", status_flags) - if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then return rows, "again" end @@ -728,8 +654,6 @@ end local function _query_resp(self) return function(sock) - --return true ,read_result(self,sock) - --local res, more = read_result(self,sock) local res, err, errno, sqlstate = read_result(self,sock) if not res then local badresult ={} @@ -746,7 +670,6 @@ local function _query_resp(self) mulitresultset.mulitresultset = true local i =2 while err =="again" do - --res, more = read_result(self,sock) res, err, errno, sqlstate = read_result(self,sock) if not res then return true, mulitresultset diff --git a/test/testmysql.lua b/test/testmysql.lua index 1423adbdd..510625e0e 100644 --- a/test/testmysql.lua +++ b/test/testmysql.lua @@ -49,7 +49,10 @@ local function test2( db) local i=1 while true do local res = db:query("select * from cats order by id asc") - print ( "test2 i=" ,i,"\n",dump( res ) ) + print ( "test2 loop times=" ,i,"\n","query result=",dump( res ) ) + res = db:query("select * from cats order by id asc") + print ( "test2 loop times=" ,i,"\n","query result=",dump( res ) ) + skynet.sleep(1000) i=i+1 end @@ -58,7 +61,9 @@ local function test3( db) local i=1 while true do local res = db:query("select * from cats order by id asc") - print ( "test3 i=" ,i,"\n",dump( res ) ) + print ( "test3 loop times=" ,i,"\n","query result=",dump( res ) ) + res = db:query("select * from cats order by id asc") + print ( "test3 loop times=" ,i,"\n","query result=",dump( res ) ) skynet.sleep(1000) i=i+1 end @@ -95,18 +100,23 @@ skynet.start(function() skynet.fork( test3, db) -- multiresultset test res = db:query("select * from cats order by id asc ; select * from cats") - print ( dump( res ) ) + print ("multiresultset test result=", dump( res ) ) - print ( mysql.quote_sql_str([[\mysql escape %string test'test"]]) ) + print ("escape string test result=", mysql.quote_sql_str([[\mysql escape %string test'test"]]) ) -- bad sql statement local res = db:query("select * from notexisttable" ) - print( dump(res) ) + print( "bad query test result=" ,dump(res) ) local i=1 while true do local res = db:query("select * from cats order by id asc") - print ( "test1 i=" ,i,"\n",dump( res ) ) + print ( "test1 loop times=" ,i,"\n","query result=",dump( res ) ) + + res = db:query("select * from cats order by id asc") + print ( "test1 loop times=" ,i,"\n","query result=",dump( res ) ) + + skynet.sleep(1000) i=i+1 end From 3b95ecd9b22d736dbb65009c4a0d764f4e163aaf Mon Sep 17 00:00:00 2001 From: Cloud Wu Date: Fri, 17 Oct 2014 11:22:49 +0800 Subject: [PATCH 10/13] bugfix: skynet.queue --- lualib/skynet/queue.lua | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/lualib/skynet/queue.lua b/lualib/skynet/queue.lua index 9f40e2473..b427da6b3 100644 --- a/lualib/skynet/queue.lua +++ b/lualib/skynet/queue.lua @@ -9,21 +9,20 @@ function skynet.queue() local thread_queue = {} return function(f, ...) local thread = coroutine.running() - if ref == 0 then - current_thread = thread - elseif current_thread ~= thread then + if current_thread and current_thread ~= thread then table.insert(thread_queue, thread) skynet.wait() - assert(ref == 0) + assert(ref == 0) -- current_thread == thread end + current_thread = thread + ref = ref + 1 local ok, err = pcall(f, ...) ref = ref - 1 if ref == 0 then - current_thread = nil - local co = table.remove(thread_queue,1) - if co then - skynet.wakeup(co) + current_thread = table.remove(thread_queue,1) + if current_thread then + skynet.wakeup(current_thread) end end assert(ok,err) From 9bdc36526ef5e5094b629d800421509e2638a6ee Mon Sep 17 00:00:00 2001 From: Cloud Wu Date: Wed, 22 Oct 2014 11:29:03 +0800 Subject: [PATCH 11/13] close socket after request --- lualib/http/httpc.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lualib/http/httpc.lua b/lualib/http/httpc.lua index 784b186bb..a73d5710e 100644 --- a/lualib/http/httpc.lua +++ b/lualib/http/httpc.lua @@ -81,10 +81,10 @@ function httpc.request(method, host, url, recvheader, header, content) end local fd = socket.connect(hostname, port) local ok , statuscode, body = pcall(request, fd,method, host, url, recvheader, header, content) + socket.close(fd) if ok then return statuscode, body else - socket.close(fd) error(statuscode) end end From fc8983227d861b341c03c7f0e531dee220cfae49 Mon Sep 17 00:00:00 2001 From: Cloud Wu Date: Wed, 22 Oct 2014 20:32:23 +0800 Subject: [PATCH 12/13] bugfix issue #185 --- lualib-src/lua-seri.c | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lualib-src/lua-seri.c b/lualib-src/lua-seri.c index 1acfb9748..fe4bfa46f 100644 --- a/lualib-src/lua-seri.c +++ b/lualib-src/lua-seri.c @@ -120,37 +120,37 @@ rb_read(struct read_block *rb, void *buffer, int sz) { static inline void wb_nil(struct write_block *wb) { - int n = TYPE_NIL; + uint8_t n = TYPE_NIL; wb_push(wb, &n, 1); } static inline void wb_boolean(struct write_block *wb, int boolean) { - int n = COMBINE_TYPE(TYPE_BOOLEAN , boolean ? 1 : 0); + uint8_t n = COMBINE_TYPE(TYPE_BOOLEAN , boolean ? 1 : 0); wb_push(wb, &n, 1); } static inline void wb_integer(struct write_block *wb, int v, int type) { if (v == 0) { - int n = COMBINE_TYPE(type , 0); + uint8_t n = COMBINE_TYPE(type , 0); wb_push(wb, &n, 1); } else if (v<0) { - int n = COMBINE_TYPE(type , 4); + uint8_t n = COMBINE_TYPE(type , 4); wb_push(wb, &n, 1); wb_push(wb, &v, 4); } else if (v<0x100) { - int n = COMBINE_TYPE(type , 1); + uint8_t n = COMBINE_TYPE(type , 1); wb_push(wb, &n, 1); uint8_t byte = (uint8_t)v; wb_push(wb, &byte, 1); } else if (v<0x10000) { - int n = COMBINE_TYPE(type , 2); + uint8_t n = COMBINE_TYPE(type , 2); wb_push(wb, &n, 1); uint16_t word = (uint16_t)v; wb_push(wb, &word, 2); } else { - int n = COMBINE_TYPE(type , 4); + uint8_t n = COMBINE_TYPE(type , 4); wb_push(wb, &n, 1); wb_push(wb, &v, 4); } @@ -158,14 +158,14 @@ wb_integer(struct write_block *wb, int v, int type) { static inline void wb_number(struct write_block *wb, double v) { - int n = COMBINE_TYPE(TYPE_NUMBER , 8); + uint8_t n = COMBINE_TYPE(TYPE_NUMBER , 8); wb_push(wb, &n, 1); wb_push(wb, &v, 8); } static inline void wb_pointer(struct write_block *wb, void *v) { - int n = TYPE_USERDATA; + uint8_t n = TYPE_USERDATA; wb_push(wb, &n, 1); wb_push(wb, &v, sizeof(v)); } @@ -173,7 +173,7 @@ wb_pointer(struct write_block *wb, void *v) { static inline void wb_string(struct write_block *wb, const char *str, int len) { if (len < MAX_COOKIE) { - int n = COMBINE_TYPE(TYPE_SHORT_STRING, len); + uint8_t n = COMBINE_TYPE(TYPE_SHORT_STRING, len); wb_push(wb, &n, 1); if (len > 0) { wb_push(wb, str, len); @@ -201,11 +201,11 @@ static int wb_table_array(lua_State *L, struct write_block * wb, int index, int depth) { int array_size = lua_rawlen(L,index); if (array_size >= MAX_COOKIE-1) { - int n = COMBINE_TYPE(TYPE_TABLE, MAX_COOKIE-1); + uint8_t n = COMBINE_TYPE(TYPE_TABLE, MAX_COOKIE-1); wb_push(wb, &n, 1); wb_integer(wb, array_size,TYPE_NUMBER); } else { - int n = COMBINE_TYPE(TYPE_TABLE, array_size); + uint8_t n = COMBINE_TYPE(TYPE_TABLE, array_size); wb_push(wb, &n, 1); } From 73bd788a6ca3e8554900077e5abd6dbe4fc104cb Mon Sep 17 00:00:00 2001 From: Cloud Wu Date: Mon, 27 Oct 2014 10:43:54 +0800 Subject: [PATCH 13/13] ready for v0.8.0 --- HISTORY.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index e7771ed53..64c307446 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,8 @@ +v0.8.0 (2014-10-27) +----------- +* Add mysql client driver +* Bugfix : skynet.queue + v0.7.4 (2014-10-13) ----------- * Bugfix : clear coroutine pool when GC