Skip to content

Commit

Permalink
feat: Support parameters
Browse files Browse the repository at this point in the history
This commit allows users to pass some types of Python values as
parameters. Supported types are as follows.

* Basic types
  * string
  * integer
  * `float`
  * `bool`(`True`/`False`)
  * `None`
* Collection
  * `dict` -- only basic types can be key
  * `list`/`tuple`
  • Loading branch information
protodef committed Mar 27, 2018
1 parent 64c8b45 commit 99cd99e
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 2 deletions.
3 changes: 2 additions & 1 deletion agensgraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from agensgraph._vertex import Vertex, cast_vertex as _cast_vertex
from agensgraph._edge import Edge, cast_edge as _cast_edge
from agensgraph._graphpath import Path, cast_graphpath as _cast_graphpath
from agensgraph._property import Property

_GRAPHID_OID = 7002
_VERTEX_OID = 7012
Expand All @@ -40,5 +41,5 @@
PATH = _ext.new_type((_GRAPHPATH_OID,), 'PATH', _cast_graphpath)
_ext.register_type(PATH)

__all__ = ['GraphId', 'Vertex', 'Edge', 'Path',
__all__ = ['GraphId', 'Vertex', 'Edge', 'Path', 'Property',
'GRAPHID', 'VERTEX', 'EDGE', 'PATH']
148 changes: 148 additions & 0 deletions agensgraph/_property.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
'''
Copyright (c) 2014-2018, Bitnine Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''

import sys

from psycopg2.extensions import ISQLQuote
from psycopg2.extras import json

# borrowed from simplejson's compat.py
if sys.version_info[0] < 3:
string_types = (basestring,)
integer_types = (int, long)
def dict_items(o):
return o.iteritems()
else:
string_types = (str,)
integer_types = (int,)
def dict_items(o):
return o.items()

def quote_string(s):
s = s[1:-1]
s = "'" + s.replace("'", "''") + "'"
return s

class PropertyEncoder(object):
def encode(self, o):
chunks = self.iterencode(o)
if not isinstance(chunks, (list, tuple)):
chunks = list(chunks)
return ''.join(chunks)

def iterencode(self, o):
markers = {}
_iterencode = _make_iterencode(markers, json.dumps, quote_string)
return _iterencode(o)

def _make_iterencode(markers, _encoder, _quote_string,
dict=dict,
float=float,
id=id,
isinstance=isinstance,
list=list,
tuple=tuple,
string_types=string_types,
integer_types=integer_types,
dict_items=dict_items):
def _iterencode_list(o):
if not o:
yield '[]'
return

markerid = id(o)
if markerid in markers:
raise ValueError('Circular reference detected')
markers[markerid] = o

yield '['
first = True
for e in o:
if first:
first = False
else:
yield ','

for chunk in _iterencode(e):
yield chunk
yield ']'

del markers[markerid]

def _iterencode_dict(o):
if not o:
yield '{}'
return

markerid = id(o)
if markerid in markers:
raise ValueError('Circular reference detected')
markers[markerid] = o

yield '{'
first = True
for k, v in dict_items(o):
if isinstance(k, string_types):
pass
elif (k is True or k is False or k is None or
isinstance(k, integer_types) or isinstance(k, float)):
k = _encoder(k)
else:
raise TypeError('keys must be str, int, float, bool or None, '
'not %s' % k.__class__.__name__)

if first:
first = False
else:
yield ','

yield _quote_string(_encoder(k))
yield ':'
for chunk in _iterencode(v):
yield chunk
yield '}'

del markers[markerid]

def _iterencode(o):
if isinstance(o, string_types):
yield _quote_string(_encoder(o))
elif isinstance(o, (list, tuple)):
for chunk in _iterencode_list(o):
yield chunk
elif isinstance(o, dict):
for chunk in _iterencode_dict(o):
yield chunk
else:
yield _encoder(o)

return _iterencode

_default_encoder = PropertyEncoder()

class Property(object):
def __init__(self, value):
self.value = value

def __conform__(self, proto):
if proto is ISQLQuote:
return self

def prepare(self, conn):
self._conn = conn

def getquoted(self):
return _default_encoder.encode(self.value)
33 changes: 32 additions & 1 deletion tests/test_agensgraph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
'''
Copyright (c) 2014-2017, Bitnine Inc.
Copyright (c) 2014-2018, Bitnine Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -124,5 +124,36 @@ def test_path(self):
(e.eid,))
self.assertEqual(1, self.cur.fetchone()[0])

class TestParam(TestConnection):
def setUp(self):
super(TestParam, self).setUp()
self.name = "'Agens\"Graph'"

def test_param_dict(self):
d = {'name': self.name, 'since': 2016}
p = agensgraph.Property(d)
self.cur.execute('CREATE (n %s) RETURN n', (p,))
self.conn.commit()

v = self.cur.fetchone()[0]
self.assertEqual(self.name, v.props['name'])
self.assertEqual(2016, v.props['since'])

def test_param_list_and_tuple(self):
a = [self.name, 2016]
t = (self.name, 2016)
pa = agensgraph.Property(a)
pt = agensgraph.Property(t)
self.cur.execute('CREATE (n {a: %s, t: %s}) RETURN n', (pa, pt))
self.conn.commit()

v = self.cur.fetchone()[0]
va = v.props['a']
self.assertEqual(self.name, va[0])
self.assertEqual(2016, va[1])
vt = v.props['t']
self.assertEqual(self.name, vt[0])
self.assertEqual(2016, vt[1])

if __name__ == '__main__':
unittest.main()
53 changes: 53 additions & 0 deletions tests/test_property.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
'''
Copyright (c) 2014-2018, Bitnine Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''

import unittest

from agensgraph._property import Property

from psycopg2.extensions import QuotedString, adapt

class TestProperty(unittest.TestCase):
def test_string(self):
self.assertEqual(r"'\"'", Property('"').getquoted())
self.assertEqual(r"''''", Property("'").getquoted())

def test_number(self):
self.assertEqual('0', Property(0).getquoted())
self.assertEqual('-1', Property(-1).getquoted())
self.assertEqual('3.14159', Property(3.14159).getquoted())

def test_boolean(self):
self.assertEqual('true', Property(True).getquoted())
self.assertEqual('false', Property(False).getquoted())

def test_null(self):
self.assertEqual('null', Property(None).getquoted())

def test_array(self):
a = ["'\\\"'", 3.14159, True, None, (), {}]
e = "['''\\\\\\\"''',3.14159,true,null,[],{}]"
self.assertEqual(e, Property(a).getquoted())

def test_object(self):
self.assertEqual("{'\\\"':'\\\"'}", Property({'"': '"'}).getquoted())
self.assertEqual("{'3.14159':3.14159}",
Property({3.14159: 3.14159}).getquoted())
self.assertEqual("{'true':false}", Property({True: False}).getquoted())
self.assertEqual("{'null':null}", Property({None: None}).getquoted())
self.assertEqual("{'a':[]}", Property({'a': []}).getquoted())
self.assertEqual("{'o':{}}", Property({'o': {}}).getquoted())
self.assertRaises(TypeError, Property({(): None}).getquoted)

0 comments on commit 99cd99e

Please sign in to comment.