diff --git a/3rd/lua-mysqlaux/lua_mysqlaux.c b/3rd/lua-mysqlaux/lua_mysqlaux.c new file mode 100755 index 000000000..ba170ddb2 --- /dev/null +++ b/3rd/lua-mysqlaux/lua_mysqlaux.c @@ -0,0 +1,372 @@ +// +// lua_mysqlaux.c +// +// Created by changfeng on 6/17/14. +// Copyright (c) 2014 changfeng. All rights reserved. +// +#include +#include +#include + +#include +#include + +#define SHA1SIZE 20 +#define ROTL(bits,word) (((word) << (bits)) | ((word) >> (32-(bits)))) +typedef unsigned int uint32_t; + +struct sha +{ + uint32_t digest[5]; + uint32_t w[80]; + uint32_t a,b,c,d,e,f; + int err; +}; + + +static uint32_t padded_length_in_bits(uint32_t len) +{ + if(len%64 == 56) + { + len++; + } + while((len%64)!=56) + { + len++; + } + return len*8; +} + + +static int calculate_sha1(struct sha *sha1, const unsigned char *text, uint32_t length) +{ + unsigned int i,j; + unsigned char *buffer=NULL, *pbuffer=NULL; + uint32_t bits=0; + uint32_t temp=0,k=0; + uint32_t lb = length*8; + + if (!sha1) + { + return 0; + } + // initialize the default digest values + sha1->digest[0] = 0x67452301; + sha1->digest[1] = 0xEFCDAB89; + sha1->digest[2] = 0x98BADCFE; + sha1->digest[3] = 0x10325476; + sha1->digest[4] = 0xC3D2E1F0; + sha1->a=sha1->b=sha1->c=sha1->d=sha1->e=sha1->f=0; + if (!text || !length) + { + return 0; + } + + bits = padded_length_in_bits(length); + buffer = (unsigned char *) malloc((bits/8)+8); + memset(buffer,0,(bits/8)+8); + if(buffer == NULL) + { + return 1; + } + pbuffer = buffer; + memcpy(buffer, text, length); + + + //add 1 on the last of the message.. + *(buffer+length) = 0x80; + for(i=length+1; i<(bits/8); i++) + { + *(buffer+i) = 0x00; + } + + *(buffer +(bits/8)+4+0) = (lb>>24) & 0xFF; + *(buffer +(bits/8)+4+1) = (lb>>16) & 0xFF; + *(buffer +(bits/8)+4+2) = (lb>>8) & 0xFF; + *(buffer +(bits/8)+4+3) = (lb>>0) & 0xFF; + + + //main loop + for(i=0; i<((bits+64)/512); i++) + { + //first empty the block for each pass.. + for(j=0; j<80; j++) + { + sha1->w[j] = 0x00; + } + + + //fill the first 16 words with the characters read directly from the buffer. + for(j=0; j<16; j++) + { + sha1->w[j] =buffer[j*4+0]; + sha1->w[j] = sha1->w[j]<<8; + sha1->w[j] |= buffer[j*4+1]; + sha1->w[j] = sha1->w[j]<<8; + sha1->w[j] |= buffer[j*4+2]; + sha1->w[j] = sha1->w[j]<<8; + sha1->w[j] |= buffer[j*4+3]; + } + + //fill the rest 64 words using the formula + for(j=16; j<80; j++) + { + sha1->w[j] = (ROTL(1,(sha1->w[j-3] ^ sha1->w[j-8] ^ sha1->w[j-14] ^ sha1->w[j-16]))); + } + + + //initialize hash for this chunck reading that has been stored in the structure digest + sha1->a = sha1->digest[0]; + sha1->b = sha1->digest[1]; + sha1->c = sha1->digest[2]; + sha1->d = sha1->digest[3]; + sha1->e = sha1->digest[4]; + + //for all the 80 32bit blocks calculate f and use k accordingly per specification. + for(j=0; j<80; j++) + { + if((j>=0) && (j<20)) + { + sha1->f = ((sha1->b)&(sha1->c)) | ((~(sha1->b))&(sha1->d)); + k = 0x5A827999; + + } + else if((j>=20) && (j<40)) + { + sha1->f = (sha1->b)^(sha1->c)^(sha1->d); + k = 0x6ED9EBA1; + } + else if((j>=40) && (j<60)) + { + sha1->f = ((sha1->b)&(sha1->c)) | ((sha1->b)&(sha1->d)) | ((sha1->c)&(sha1->d)); + k = 0x8F1BBCDC; + } + else if((j>=60) && (j<80)) + { + sha1->f = (sha1->b)^(sha1->c)^(sha1->d); + k = 0xCA62C1D6; + } + + temp = ROTL(5,(sha1->a)) + (sha1->f) + (sha1->e) + k + sha1->w[j]; + sha1->e = (sha1->d); + sha1->d = (sha1->c); + sha1->c = ROTL(30,(sha1->b)); + sha1->b = (sha1->a); + sha1->a = temp; + + //reset temp to 0 to be in safe side only, not mandatory. + temp =0x00; + + + } + + // append to total hash. + sha1->digest[0] += sha1->a; + sha1->digest[1] += sha1->b; + sha1->digest[2] += sha1->c; + sha1->digest[3] += sha1->d; + sha1->digest[4] += sha1->e; + + + //since we used 512bit size block per each pass, let us update the buffer pointer accordingly. + buffer = buffer+64; + + } + free(pbuffer); + return 0; +} + +static void int2ch4(int intVal,unsigned char *result) +{ + result[0]= (unsigned char)((intVal>>24) & 0x000000ff); + result[1]= (unsigned char)((intVal>>16) & 0x000000ff); + result[2]= (unsigned char)((intVal>> 8) & 0x000000ff); + result[3]= (unsigned char)((intVal>> 0) & 0x000000ff); +} + + +static int sha1_bin (lua_State *L) { + const void * msg = NULL; + size_t len =0; + + if( lua_gettop(L) != 1 ){ + return 0; + } + if( lua_isnil(L,1) ) { + msg = NULL; + len =0; + }else{ + msg=luaL_checklstring(L,1,&len); + } + struct sha tmpsha; + calculate_sha1( &tmpsha, msg, (uint32_t)len); + unsigned char tmpret[SHA1SIZE+8]; + memset(tmpret,0,SHA1SIZE+8); + int i=0; + for ( i=0; i<5; i++) + { + int2ch4(tmpsha.digest[i], tmpret+i*4); + } + + lua_pushlstring(L, (char *)tmpret, SHA1SIZE); + return 1; +} + +static unsigned int num_escape_sql_str(unsigned char *dst, unsigned char *src, size_t size) +{ + unsigned int n =0; + while (size) { + /* the highest bit of all the UTF-8 chars + * is always 1 */ + if ((*src & 0x80) == 0) { + switch (*src) { + case '\0': + case '\b': + case '\n': + case '\r': + case '\t': + case 26: /* \z */ + case '\\': + case '\'': + case '"': + n++; + break; + default: + break; + } + } + src++; + size--; + } + return n; +} +static unsigned char* +escape_sql_str(unsigned char *dst, unsigned char *src, size_t size) +{ + + while (size) { + if ((*src & 0x80) == 0) { + switch (*src) { + case '\0': + *dst++ = '\\'; + *dst++ = '0'; + break; + + case '\b': + *dst++ = '\\'; + *dst++ = 'b'; + break; + + case '\n': + *dst++ = '\\'; + *dst++ = 'n'; + break; + + case '\r': + *dst++ = '\\'; + *dst++ = 'r'; + break; + + case '\t': + *dst++ = '\\'; + *dst++ = 't'; + break; + + case 26: + *dst++ = '\\'; + *dst++ = 'z'; + break; + + case '\\': + *dst++ = '\\'; + *dst++ = '\\'; + break; + + case '\'': + *dst++ = '\\'; + *dst++ = '\''; + break; + + case '"': + *dst++ = '\\'; + *dst++ = '"'; + break; + + default: + *dst++ = *src; + break; + } + } else { + *dst++ = *src; + } + src++; + size--; + } /* while (size) */ + + return dst; +} + + + + +static int +quote_sql_str(lua_State *L) +{ + size_t len, dlen, escape; + unsigned char *p; + unsigned char *src, *dst; + + if (lua_gettop(L) != 1) { + return luaL_error(L, "expecting one argument"); + } + + src = (unsigned char *) luaL_checklstring(L, 1, &len); + + if (len == 0) { + dst = (unsigned char *) "''"; + dlen = sizeof("''") - 1; + lua_pushlstring(L, (char *) dst, dlen); + return 1; + } + + escape = num_escape_sql_str(NULL, src, len); + + dlen = sizeof("''") - 1 + len + escape; + p = lua_newuserdata(L, dlen); + + dst = p; + + *p++ = '\''; + + if (escape == 0) { + memcpy(p, src, len); + p+=len; + } else { + p = (unsigned char *) escape_sql_str(p, src, len); + } + + *p++ = '\''; + + if (p != dst + dlen) { + return luaL_error(L, "quote sql string error"); + } + + lua_pushlstring(L, (char *) dst, p - dst); + + return 1; +} + + +static struct luaL_Reg mysqlauxlib[] = { + {"sha1_bin", sha1_bin}, + {"quote_sql_str",quote_sql_str}, + {NULL, NULL} +}; + + +int luaopen_mysqlaux_c (lua_State *L) { + lua_newtable(L); + luaL_setfuncs(L, mysqlauxlib, 0); + return 1; +} + diff --git a/HISTORY.md b/HISTORY.md index e7771ed53..64c307446 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,8 @@ +v0.8.0 (2014-10-27) +----------- +* Add mysql client driver +* Bugfix : skynet.queue + v0.7.4 (2014-10-13) ----------- * Bugfix : clear coroutine pool when GC diff --git a/Makefile b/Makefile index eb29a6294..ec52fcd39 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,8 @@ jemalloc : $(MALLOC_STATICLIB) CSERVICE = snlua logger gate harbor LUA_CLIB = skynet socketdriver int64 bson mongo md5 netpack \ clientsocket memory profile multicast \ - cluster crypt sharedata stm sproto lpeg + cluster crypt sharedata stm sproto lpeg \ + mysqlaux SKYNET_SRC = skynet_main.c skynet_handle.c skynet_module.c skynet_mq.c \ skynet_server.c skynet_start.c skynet_timer.c skynet_error.c \ @@ -122,6 +123,9 @@ $(LUA_CLIB_PATH)/sproto.so : lualib-src/sproto/sproto.c lualib-src/sproto/lsprot $(LUA_CLIB_PATH)/lpeg.so : 3rd/lpeg/lpcap.c 3rd/lpeg/lpcode.c 3rd/lpeg/lpprint.c 3rd/lpeg/lptree.c 3rd/lpeg/lpvm.c | $(LUA_CLIB_PATH) $(CC) $(CFLAGS) $(SHARED) -I3rd/lpeg $^ -o $@ +$(LUA_CLIB_PATH)/mysqlaux.so : 3rd/lua-mysqlaux/lua_mysqlaux.c | $(LUA_CLIB_PATH) + $(CC) $(CFLAGS) $(SHARED) $^ -o $@ + clean : rm -f $(SKYNET_BUILD_PATH)/skynet $(CSERVICE_PATH)/*.so $(LUA_CLIB_PATH)/*.so diff --git a/examples/config.mysql b/examples/config.mysql new file mode 100644 index 000000000..8f23a02d6 --- /dev/null +++ b/examples/config.mysql @@ -0,0 +1,11 @@ +root = "./" +thread = 8 +logger = nil +harbor = 0 +start = "main_mysql" -- main script +bootstrap = "snlua bootstrap" -- The service for bootstrap +luaservice = root.."service/?.lua;"..root.."test/?.lua;"..root.."examples/?.lua" +lualoader = "lualib/loader.lua" +snax = root.."examples/?.lua;"..root.."test/?.lua" +cpath = root.."cservice/?.so" +-- daemon = "./skynet.pid" diff --git a/examples/main_mysql.lua b/examples/main_mysql.lua new file mode 100644 index 000000000..8a92b2cd3 --- /dev/null +++ b/examples/main_mysql.lua @@ -0,0 +1,10 @@ +local skynet = require "skynet" + + +skynet.start(function() + print("Main Server start") + local console = skynet.newservice("testmysql") + + print("Main Server exit") + skynet.exit() +end) diff --git a/lualib-src/lua-seri.c b/lualib-src/lua-seri.c index 1acfb9748..fe4bfa46f 100644 --- a/lualib-src/lua-seri.c +++ b/lualib-src/lua-seri.c @@ -120,37 +120,37 @@ rb_read(struct read_block *rb, void *buffer, int sz) { static inline void wb_nil(struct write_block *wb) { - int n = TYPE_NIL; + uint8_t n = TYPE_NIL; wb_push(wb, &n, 1); } static inline void wb_boolean(struct write_block *wb, int boolean) { - int n = COMBINE_TYPE(TYPE_BOOLEAN , boolean ? 1 : 0); + uint8_t n = COMBINE_TYPE(TYPE_BOOLEAN , boolean ? 1 : 0); wb_push(wb, &n, 1); } static inline void wb_integer(struct write_block *wb, int v, int type) { if (v == 0) { - int n = COMBINE_TYPE(type , 0); + uint8_t n = COMBINE_TYPE(type , 0); wb_push(wb, &n, 1); } else if (v<0) { - int n = COMBINE_TYPE(type , 4); + uint8_t n = COMBINE_TYPE(type , 4); wb_push(wb, &n, 1); wb_push(wb, &v, 4); } else if (v<0x100) { - int n = COMBINE_TYPE(type , 1); + uint8_t n = COMBINE_TYPE(type , 1); wb_push(wb, &n, 1); uint8_t byte = (uint8_t)v; wb_push(wb, &byte, 1); } else if (v<0x10000) { - int n = COMBINE_TYPE(type , 2); + uint8_t n = COMBINE_TYPE(type , 2); wb_push(wb, &n, 1); uint16_t word = (uint16_t)v; wb_push(wb, &word, 2); } else { - int n = COMBINE_TYPE(type , 4); + uint8_t n = COMBINE_TYPE(type , 4); wb_push(wb, &n, 1); wb_push(wb, &v, 4); } @@ -158,14 +158,14 @@ wb_integer(struct write_block *wb, int v, int type) { static inline void wb_number(struct write_block *wb, double v) { - int n = COMBINE_TYPE(TYPE_NUMBER , 8); + uint8_t n = COMBINE_TYPE(TYPE_NUMBER , 8); wb_push(wb, &n, 1); wb_push(wb, &v, 8); } static inline void wb_pointer(struct write_block *wb, void *v) { - int n = TYPE_USERDATA; + uint8_t n = TYPE_USERDATA; wb_push(wb, &n, 1); wb_push(wb, &v, sizeof(v)); } @@ -173,7 +173,7 @@ wb_pointer(struct write_block *wb, void *v) { static inline void wb_string(struct write_block *wb, const char *str, int len) { if (len < MAX_COOKIE) { - int n = COMBINE_TYPE(TYPE_SHORT_STRING, len); + uint8_t n = COMBINE_TYPE(TYPE_SHORT_STRING, len); wb_push(wb, &n, 1); if (len > 0) { wb_push(wb, str, len); @@ -201,11 +201,11 @@ static int wb_table_array(lua_State *L, struct write_block * wb, int index, int depth) { int array_size = lua_rawlen(L,index); if (array_size >= MAX_COOKIE-1) { - int n = COMBINE_TYPE(TYPE_TABLE, MAX_COOKIE-1); + uint8_t n = COMBINE_TYPE(TYPE_TABLE, MAX_COOKIE-1); wb_push(wb, &n, 1); wb_integer(wb, array_size,TYPE_NUMBER); } else { - int n = COMBINE_TYPE(TYPE_TABLE, array_size); + uint8_t n = COMBINE_TYPE(TYPE_TABLE, array_size); wb_push(wb, &n, 1); } diff --git a/lualib/http/httpc.lua b/lualib/http/httpc.lua index 784b186bb..a73d5710e 100644 --- a/lualib/http/httpc.lua +++ b/lualib/http/httpc.lua @@ -81,10 +81,10 @@ function httpc.request(method, host, url, recvheader, header, content) end local fd = socket.connect(hostname, port) local ok , statuscode, body = pcall(request, fd,method, host, url, recvheader, header, content) + socket.close(fd) if ok then return statuscode, body else - socket.close(fd) error(statuscode) end end diff --git a/lualib/mysql.lua b/lualib/mysql.lua new file mode 100755 index 000000000..b81e2ecca --- /dev/null +++ b/lualib/mysql.lua @@ -0,0 +1,747 @@ +-- Copyright (C) 2012 Yichun Zhang (agentzh) +-- Copyright (C) 2014 Chang Feng +-- This file is modified version from https://github.com/openresty/lua-resty-mysql +-- The license is under the BSD license. + +local socketchannel = require "socketchannel" +local bit = require "bit32" +local mysqlaux = require "mysqlaux.c" + + + +local sub = string.sub +local strbyte = string.byte +local strchar = string.char +local strfind = string.find +local strrep = string.rep +local null = nil +local band = bit.band +local bxor = bit.bxor +local bor = bit.bor +local lshift = bit.lshift +local rshift = bit.rshift +local sha1= mysqlaux.sha1_bin +local concat = table.concat +local unpack = unpack +local setmetatable = setmetatable +local error = error +local tonumber = tonumber +local new_tab = function (narr, nrec) return {} end + + +local _M = { _VERSION = '0.13' } +-- constants + +local STATE_CONNECTED = 1 +local STATE_COMMAND_SENT = 2 + +local COM_QUERY = 0x03 + +local SERVER_MORE_RESULTS_EXISTS = 8 + +-- 16MB - 1, the default max allowed packet size used by libmysqlclient +local FULL_PACKET_SIZE = 16777215 + + +local mt = { __index = _M } + + +-- mysql field value type converters +local converters = new_tab(0, 8) + +for i = 0x01, 0x05 do + -- tiny, short, long, float, double + converters[i] = tonumber +end +-- converters[0x08] = tonumber -- long long +converters[0x09] = tonumber -- int24 +converters[0x0d] = tonumber -- year +converters[0xf6] = tonumber -- newdecimal + + +local function _get_byte2(data, i) + local a, b = strbyte(data, i, i + 1) + return bor(a, lshift(b, 8)), i + 2 +end + + +local function _get_byte3(data, i) + local a, b, c = strbyte(data, i, i + 2) + return bor(a, lshift(b, 8), lshift(c, 16)), i + 3 +end + + +local function _get_byte4(data, i) + local a, b, c, d = strbyte(data, i, i + 3) + return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)), i + 4 +end + + +local function _get_byte8(data, i) + local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7) + + -- XXX workaround for the lack of 64-bit support in bitop: + local lo = bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)) + local hi = bor(e, lshift(f, 8), lshift(g, 16), lshift(h, 24)) + return lo + hi * 4294967296, i + 8 + + -- return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24), lshift(e, 32), + -- lshift(f, 40), lshift(g, 48), lshift(h, 56)), i + 8 +end + + +local function _set_byte2(n) + return strchar(band(n, 0xff), band(rshift(n, 8), 0xff)) +end + + +local function _set_byte3(n) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff)) +end + + +local function _set_byte4(n) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff), + band(rshift(n, 24), 0xff)) +end + + +local function _from_cstring(data, i) + local last = strfind(data, "\0", i, true) + if not last then + return nil, nil + end + + return sub(data, i, last), last + 1 +end + + +local function _to_cstring(data) + return data .. "\0" +end + + +local function _to_binary_coded_string(data) + return strchar(#data) .. data +end + + +local function _dump(data) + local len = #data + local bytes = new_tab(len, 0) + for i = 1, len do + bytes[i] = strbyte(data, i) + end + return concat(bytes, " ") +end + + + +local function _dumphex(bytes) + local result ={} + + for i = 1, string.len(bytes) do + local charcode = tonumber(strbyte(bytes, i, i)) + local hexstr = string.format("%02X", charcode) + result[i]=hexstr + end + + local res=table.concat(result, " ") + return res +end + + +local function _compute_token(password, scramble) + if password == "" then + return "" + end + --_dump(scramble) + + local stage1 = sha1(password) + --print("stage1:", _dumphex(stage1) ) + local stage2 = sha1(stage1) + local stage3 = sha1(scramble .. stage2) + local n = #stage1 + local bytes = new_tab(n, 0) + for i = 1, n do + bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i))) + end + + return concat(bytes) +end + +local function _compose_packet(self, req, size) + self.packet_no = self.packet_no + 1 + + local packet = _set_byte3(size) .. strchar(self.packet_no) .. req + return packet +end + + +local function _send_packet(self, req, size) + local sock = self.sock + + self.packet_no = self.packet_no + 1 + + + local packet = _set_byte3(size) .. strchar(self.packet_no) .. req + + return socket.write(self.sock,packet) +end + + +local function _recv_packet(self,sock) + + + local data = sock:read( 4) + if not data then + return nil, nil, "failed to receive packet header: " + end + + + local len, pos = _get_byte3(data, 1) + + + if len == 0 then + return nil, nil, "empty packet" + end + + if len > self._max_packet_size then + return nil, nil, "packet size too big: " .. len + end + + local num = strbyte(data, pos) + + self.packet_no = num + + data = sock:read(len) + + if not data then + return nil, nil, "failed to read packet content: " + end + + + local field_count = strbyte(data, 1) + local typ + if field_count == 0x00 then + typ = "OK" + elseif field_count == 0xff then + typ = "ERR" + elseif field_count == 0xfe then + typ = "EOF" + elseif field_count <= 250 then + typ = "DATA" + end + + return data, typ +end + + +local function _from_length_coded_bin(data, pos) + local first = strbyte(data, pos) + + if not first then + return nil, pos + end + + if first >= 0 and first <= 250 then + return first, pos + 1 + end + + if first == 251 then + return null, pos + 1 + end + + if first == 252 then + pos = pos + 1 + return _get_byte2(data, pos) + end + + if first == 253 then + pos = pos + 1 + return _get_byte3(data, pos) + end + + if first == 254 then + pos = pos + 1 + return _get_byte8(data, pos) + end + + return false, pos + 1 +end + + +local function _from_length_coded_str(data, pos) + local len + len, pos = _from_length_coded_bin(data, pos) + if len == nil or len == null then + return null, pos + end + + return sub(data, pos, pos + len - 1), pos + len +end + + +local function _parse_ok_packet(packet) + local res = new_tab(0, 5) + local pos + + res.affected_rows, pos = _from_length_coded_bin(packet, 2) + + res.insert_id, pos = _from_length_coded_bin(packet, pos) + + res.server_status, pos = _get_byte2(packet, pos) + + res.warning_count, pos = _get_byte2(packet, pos) + + + local message = sub(packet, pos) + if message and message ~= "" then + res.message = message + end + + + return res +end + + +local function _parse_eof_packet(packet) + local pos = 2 + + local warning_count, pos = _get_byte2(packet, pos) + local status_flags = _get_byte2(packet, pos) + + return warning_count, status_flags +end + + +local function _parse_err_packet(packet) + local errno, pos = _get_byte2(packet, 2) + local marker = sub(packet, pos, pos) + local sqlstate + if marker == '#' then + -- with sqlstate + pos = pos + 1 + sqlstate = sub(packet, pos, pos + 5 - 1) + pos = pos + 5 + end + + local message = sub(packet, pos) + return errno, message, sqlstate +end + + +local function _parse_result_set_header_packet(packet) + local field_count, pos = _from_length_coded_bin(packet, 1) + + local extra + extra = _from_length_coded_bin(packet, pos) + + return field_count, extra +end + + +local function _parse_field_packet(data) + local col = new_tab(0, 2) + local catalog, db, table, orig_table, orig_name, charsetnr, length + local pos + catalog, pos = _from_length_coded_str(data, 1) + + + db, pos = _from_length_coded_str(data, pos) + table, pos = _from_length_coded_str(data, pos) + orig_table, pos = _from_length_coded_str(data, pos) + col.name, pos = _from_length_coded_str(data, pos) + + orig_name, pos = _from_length_coded_str(data, pos) + + pos = pos + 1 -- ignore the filler + + charsetnr, pos = _get_byte2(data, pos) + + length, pos = _get_byte4(data, pos) + + col.type = strbyte(data, pos) + + --[[ + pos = pos + 1 + + col.flags, pos = _get_byte2(data, pos) + + col.decimals = strbyte(data, pos) + pos = pos + 1 + + local default = sub(data, pos + 2) + if default and default ~= "" then + col.default = default + end + --]] + + return col +end + + +local function _parse_row_data_packet(data, cols, compact) + local pos = 1 + local ncols = #cols + local row + if compact then + row = new_tab(ncols, 0) + else + row = new_tab(0, ncols) + end + for i = 1, ncols do + local value + value, pos = _from_length_coded_str(data, pos) + local col = cols[i] + local typ = col.type + local name = col.name + + if value ~= null then + local conv = converters[typ] + if conv then + value = conv(value) + end + end + + if compact then + row[i] = value + + else + row[name] = value + end + end + + return row +end + + +local function _recv_field_packet(self, sock) + local packet, typ, err = _recv_packet(self, sock) + if not packet then + return nil, err + end + + if typ == "ERR" then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end + + if typ ~= 'DATA' then + return nil, "bad field packet type: " .. typ + end + + -- typ == 'DATA' + + return _parse_field_packet(packet) +end + +local function _recv_decode_packet_resp(self) + return function(sock) + return true, _recv_packet(self,sock) + end +end + +local function _recv_auth_resp(self) + return function(sock) + local packet, typ, err = _recv_packet(self,sock) + if not packet then + --print("recv auth resp : failed to receive the result packet") + error ("failed to receive the result packet"..err) + --return nil,err + end + + if typ == 'ERR' then + local errno, msg, sqlstate = _parse_err_packet(packet) + error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + --return nil, errno,msg, sqlstate + end + + if typ == 'EOF' then + error "old pre-4.1 authentication protocol not supported" + end + + if typ ~= 'OK' then + error "bad packet type: " + end + return true, true + end +end + + +local function _mysql_login(self,user,password,database) + + return function(sockchannel) + local packet, typ, err = sockchannel:response( _recv_decode_packet_resp(self) ) + --local aat={} + if not packet then + error( err ) + end + + if typ == "ERR" then + local errno, msg, sqlstate = _parse_err_packet(packet) + error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + end + + self.protocol_ver = strbyte(packet) + + local server_ver, pos = _from_cstring(packet, 2) + if not server_ver then + error "bad handshake initialization packet: bad server version" + end + + self._server_ver = server_ver + + + local thread_id, pos = _get_byte4(packet, pos) + + local scramble1 = sub(packet, pos, pos + 8 - 1) + if not scramble1 then + error "1st part of scramble not found" + end + + pos = pos + 9 -- skip filler + + -- two lower bytes + self._server_capabilities, pos = _get_byte2(packet, pos) + + self._server_lang = strbyte(packet, pos) + pos = pos + 1 + + self._server_status, pos = _get_byte2(packet, pos) + + local more_capabilities + more_capabilities, pos = _get_byte2(packet, pos) + + self._server_capabilities = bor(self._server_capabilities, + lshift(more_capabilities, 16)) + + + local len = 21 - 8 - 1 + + pos = pos + 1 + 10 + + local scramble_part2 = sub(packet, pos, pos + len - 1) + if not scramble_part2 then + error "2nd part of scramble not found" + end + + + local scramble = scramble1..scramble_part2 + local token = _compute_token(password, scramble) + + local client_flags = 260047; + + local req = _set_byte4(client_flags) + .. _set_byte4(self._max_packet_size) + .. "\0" -- TODO: add support for charset encoding + .. strrep("\0", 23) + .. _to_cstring(user) + .. _to_binary_coded_string(token) + .. _to_cstring(database) + + local packet_len = 4 + 4 + 1 + 23 + #user + 1 + + #token + 1 + #database + 1 + + local authpacket=_compose_packet(self,req,packet_len) + return sockchannel:request(authpacket,_recv_auth_resp(self)) + end +end + + +local function _compose_query(self, query) + + self.packet_no = -1 + + local cmd_packet = strchar(COM_QUERY) .. query + local packet_len = 1 + #query + + local querypacket = _compose_packet(self, cmd_packet, packet_len) + return querypacket +end + + + +local function read_result(self, sock) + local packet, typ, err = _recv_packet(self, sock) + if not packet then + return nil, err + --error( err ) + end + + if typ == "ERR" then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + --error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + end + + if typ == 'OK' then + local res = _parse_ok_packet(packet) + if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return res, "again" + end + return res + end + + if typ ~= 'DATA' then + return nil, "packet type " .. typ .. " not supported" + --error( "packet type " .. typ .. " not supported" ) + end + + -- typ == 'DATA' + + local field_count, extra = _parse_result_set_header_packet(packet) + + local cols = new_tab(field_count, 0) + for i = 1, field_count do + local col, err, errno, sqlstate = _recv_field_packet(self, sock) + if not col then + return nil, err, errno, sqlstate + --error( string.format("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate)) + end + + cols[i] = col + end + + local packet, typ, err = _recv_packet(self, sock) + if not packet then + --error( err) + return nil, err + end + + if typ ~= 'EOF' then + --error ( "unexpected packet type " .. typ .. " while eof packet is ".. "expected" ) + return nil, "unexpected packet type " .. typ .. " while eof packet is ".. "expected" + end + + -- typ == 'EOF' + + local compact = self.compact + + local rows = new_tab( 4, 0) + local i = 0 + while true do + packet, typ, err = _recv_packet(self, sock) + if not packet then + --error (err) + return nil, err + end + + if typ == 'EOF' then + local warning_count, status_flags = _parse_eof_packet(packet) + + if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return rows, "again" + end + + break + end + + -- if typ ~= 'DATA' then + -- return nil, 'bad row packet type: ' .. typ + -- end + + -- typ == 'DATA' + + local row = _parse_row_data_packet(packet, cols, compact) + i = i + 1 + rows[i] = row + end + + return rows +end + +local function _query_resp(self) + return function(sock) + local res, err, errno, sqlstate = read_result(self,sock) + if not res then + local badresult ={} + badresult.badresult = true + badresult.err = err + badresult.errno = errno + badresult.sqlstate = sqlstate + return true , badresult + end + if err ~= "again" then + return true, res + end + local mulitresultset = {res} + mulitresultset.mulitresultset = true + local i =2 + while err =="again" do + res, err, errno, sqlstate = read_result(self,sock) + if not res then + return true, mulitresultset + end + mulitresultset[i]=res + i=i+1 + end + return true, mulitresultset + end +end + +function _M.connect( opts) + + local self = setmetatable( {}, mt) + + local max_packet_size = opts.max_packet_size + if not max_packet_size then + max_packet_size = 1024 * 1024 -- default 1 MB + end + self._max_packet_size = max_packet_size + self.compact = opts.compact_arrays + + + local database = opts.database or "" + local user = opts.user or "" + local password = opts.password or "" + + local channel = socketchannel.channel { + host = opts.host, + port = opts.port or 3306, + auth = _mysql_login(self,user,password,database ), + } + -- try connect first only once + channel:connect(true) + self.sockchannel = channel + + + return self +end + + + +function _M.disconnect(self) + self.sockchannel:close() + setmetatable(self, nil) +end + + +function _M.query(self, query) + local querypacket = _compose_query(self, query) + local sockchannel = self.sockchannel + if not self.query_resp then + self.query_resp = _query_resp(self) + end + return sockchannel:request( querypacket, self.query_resp ) +end + +function _M.server_ver(self) + return self._server_ver +end + + +function _M.quote_sql_str( str) + return mysqlaux.quote_sql_str(str) +end + +function _M.set_compact_arrays(self, value) + self.compact = value +end + + +return _M diff --git a/lualib/skynet/queue.lua b/lualib/skynet/queue.lua index 1b6799026..3db244e3a 100644 --- a/lualib/skynet/queue.lua +++ b/lualib/skynet/queue.lua @@ -10,21 +10,20 @@ function skynet.queue() local thread_queue = {} return function(f, ...) local thread = coroutine.running() - if ref == 0 then - current_thread = thread - elseif current_thread ~= thread then + if current_thread and current_thread ~= thread then table.insert(thread_queue, thread) skynet.wait() - assert(ref == 0) + assert(ref == 0) -- current_thread == thread end + current_thread = thread + ref = ref + 1 local ok, err = xpcall(f, traceback, ...) ref = ref - 1 if ref == 0 then - current_thread = nil - local co = table.remove(thread_queue,1) - if co then - skynet.wakeup(co) + current_thread = table.remove(thread_queue,1) + if current_thread then + skynet.wakeup(current_thread) end end assert(ok,err) diff --git a/lualib/socketchannel.lua b/lualib/socketchannel.lua index e82074890..4b90193ba 100644 --- a/lualib/socketchannel.lua +++ b/lualib/socketchannel.lua @@ -135,7 +135,7 @@ local function dispatch_by_order(self) else close_channel_socket(self) local errmsg - if result ~= socket_error then + if result_ok ~= socket_error then errmsg = result_ok end self.__result[co] = socket_error diff --git a/test/testmysql.lua b/test/testmysql.lua new file mode 100644 index 000000000..510625e0e --- /dev/null +++ b/test/testmysql.lua @@ -0,0 +1,127 @@ +local skynet = require "skynet" +local mysql = require "mysql" + +local function dump(obj) + local getIndent, quoteStr, wrapKey, wrapVal, dumpObj + getIndent = function(level) + return string.rep("\t", level) + end + quoteStr = function(str) + return '"' .. string.gsub(str, '"', '\\"') .. '"' + end + wrapKey = function(val) + if type(val) == "number" then + return "[" .. val .. "]" + elseif type(val) == "string" then + return "[" .. quoteStr(val) .. "]" + else + return "[" .. tostring(val) .. "]" + end + end + wrapVal = function(val, level) + if type(val) == "table" then + return dumpObj(val, level) + elseif type(val) == "number" then + return val + elseif type(val) == "string" then + return quoteStr(val) + else + return tostring(val) + end + end + dumpObj = function(obj, level) + if type(obj) ~= "table" then + return wrapVal(obj) + end + level = level + 1 + local tokens = {} + tokens[#tokens + 1] = "{" + for k, v in pairs(obj) do + tokens[#tokens + 1] = getIndent(level) .. wrapKey(k) .. " = " .. wrapVal(v, level) .. "," + end + tokens[#tokens + 1] = getIndent(level - 1) .. "}" + return table.concat(tokens, "\n") + end + return dumpObj(obj, 0) +end + +local function test2( db) + local i=1 + while true do + local res = db:query("select * from cats order by id asc") + print ( "test2 loop times=" ,i,"\n","query result=",dump( res ) ) + res = db:query("select * from cats order by id asc") + print ( "test2 loop times=" ,i,"\n","query result=",dump( res ) ) + + skynet.sleep(1000) + i=i+1 + end +end +local function test3( db) + local i=1 + while true do + local res = db:query("select * from cats order by id asc") + print ( "test3 loop times=" ,i,"\n","query result=",dump( res ) ) + res = db:query("select * from cats order by id asc") + print ( "test3 loop times=" ,i,"\n","query result=",dump( res ) ) + skynet.sleep(1000) + i=i+1 + end +end +skynet.start(function() + + local db=mysql.connect{ + host="127.0.0.1", + port=3306, + database="skynet", + user="root", + password="1", + max_packet_size = 1024 * 1024 + } + if not db then + print("failed to connect") + end + print("testmysql success to connect to mysql server") + + local res = db:query("drop table if exists cats") + res = db:query("create table cats " + .."(id serial primary key, ".. "name varchar(5))") + print( dump( res ) ) + + res = db:query("insert into cats (name) " + .. "values (\'Bob\'),(\'\'),(null)") + print ( dump( res ) ) + + res = db:query("select * from cats order by id asc") + print ( dump( res ) ) + + -- test in another coroutine + skynet.fork( test2, db) + skynet.fork( test3, db) + -- multiresultset test + res = db:query("select * from cats order by id asc ; select * from cats") + print ("multiresultset test result=", dump( res ) ) + + print ("escape string test result=", mysql.quote_sql_str([[\mysql escape %string test'test"]]) ) + + -- bad sql statement + local res = db:query("select * from notexisttable" ) + print( "bad query test result=" ,dump(res) ) + + local i=1 + while true do + local res = db:query("select * from cats order by id asc") + print ( "test1 loop times=" ,i,"\n","query result=",dump( res ) ) + + res = db:query("select * from cats order by id asc") + print ( "test1 loop times=" ,i,"\n","query result=",dump( res ) ) + + + skynet.sleep(1000) + i=i+1 + end + + --db:disconnect() + --skynet.exit() +end) +