From 99cd99e8eb125ffd61c67898090d5b640608394d Mon Sep 17 00:00:00 2001 From: Junseok Yang Date: Tue, 27 Mar 2018 14:55:33 -0700 Subject: [PATCH] feat: Support parameters This commit allows users to pass some types of Python values as parameters. Supported types are as follows. * Basic types * string * integer * `float` * `bool`(`True`/`False`) * `None` * Collection * `dict` -- only basic types can be key * `list`/`tuple` --- agensgraph/__init__.py | 3 +- agensgraph/_property.py | 148 +++++++++++++++++++++++++++++++++++++++ tests/test_agensgraph.py | 33 ++++++++- tests/test_property.py | 53 ++++++++++++++ 4 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 agensgraph/_property.py create mode 100644 tests/test_property.py diff --git a/agensgraph/__init__.py b/agensgraph/__init__.py index b67ccca..7083912 100644 --- a/agensgraph/__init__.py +++ b/agensgraph/__init__.py @@ -21,6 +21,7 @@ from agensgraph._vertex import Vertex, cast_vertex as _cast_vertex from agensgraph._edge import Edge, cast_edge as _cast_edge from agensgraph._graphpath import Path, cast_graphpath as _cast_graphpath +from agensgraph._property import Property _GRAPHID_OID = 7002 _VERTEX_OID = 7012 @@ -40,5 +41,5 @@ PATH = _ext.new_type((_GRAPHPATH_OID,), 'PATH', _cast_graphpath) _ext.register_type(PATH) -__all__ = ['GraphId', 'Vertex', 'Edge', 'Path', +__all__ = ['GraphId', 'Vertex', 'Edge', 'Path', 'Property', 'GRAPHID', 'VERTEX', 'EDGE', 'PATH'] diff --git a/agensgraph/_property.py b/agensgraph/_property.py new file mode 100644 index 0000000..898cf8a --- /dev/null +++ b/agensgraph/_property.py @@ -0,0 +1,148 @@ +''' +Copyright (c) 2014-2018, Bitnine Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +''' + +import sys + +from psycopg2.extensions import ISQLQuote +from psycopg2.extras import json + +# borrowed from simplejson's compat.py +if sys.version_info[0] < 3: + string_types = (basestring,) + integer_types = (int, long) + def dict_items(o): + return o.iteritems() +else: + string_types = (str,) + integer_types = (int,) + def dict_items(o): + return o.items() + +def quote_string(s): + s = s[1:-1] + s = "'" + s.replace("'", "''") + "'" + return s + +class PropertyEncoder(object): + def encode(self, o): + chunks = self.iterencode(o) + if not isinstance(chunks, (list, tuple)): + chunks = list(chunks) + return ''.join(chunks) + + def iterencode(self, o): + markers = {} + _iterencode = _make_iterencode(markers, json.dumps, quote_string) + return _iterencode(o) + +def _make_iterencode(markers, _encoder, _quote_string, + dict=dict, + float=float, + id=id, + isinstance=isinstance, + list=list, + tuple=tuple, + string_types=string_types, + integer_types=integer_types, + dict_items=dict_items): + def _iterencode_list(o): + if not o: + yield '[]' + return + + markerid = id(o) + if markerid in markers: + raise ValueError('Circular reference detected') + markers[markerid] = o + + yield '[' + first = True + for e in o: + if first: + first = False + else: + yield ',' + + for chunk in _iterencode(e): + yield chunk + yield ']' + + del markers[markerid] + + def _iterencode_dict(o): + if not o: + yield '{}' + return + + markerid = id(o) + if markerid in markers: + raise ValueError('Circular reference detected') + markers[markerid] = o + + yield '{' + first = True + for k, v in dict_items(o): + if isinstance(k, string_types): + pass + elif (k is True or k is False or k is None or + isinstance(k, integer_types) or isinstance(k, float)): + k = _encoder(k) + else: + raise TypeError('keys must be str, int, float, bool or None, ' + 'not %s' % k.__class__.__name__) + + if first: + first = False + else: + yield ',' + + yield _quote_string(_encoder(k)) + yield ':' + for chunk in _iterencode(v): + yield chunk + yield '}' + + del markers[markerid] + + def _iterencode(o): + if isinstance(o, string_types): + yield _quote_string(_encoder(o)) + elif isinstance(o, (list, tuple)): + for chunk in _iterencode_list(o): + yield chunk + elif isinstance(o, dict): + for chunk in _iterencode_dict(o): + yield chunk + else: + yield _encoder(o) + + return _iterencode + +_default_encoder = PropertyEncoder() + +class Property(object): + def __init__(self, value): + self.value = value + + def __conform__(self, proto): + if proto is ISQLQuote: + return self + + def prepare(self, conn): + self._conn = conn + + def getquoted(self): + return _default_encoder.encode(self.value) diff --git a/tests/test_agensgraph.py b/tests/test_agensgraph.py index ff64eb4..6e8688a 100644 --- a/tests/test_agensgraph.py +++ b/tests/test_agensgraph.py @@ -1,5 +1,5 @@ ''' -Copyright (c) 2014-2017, Bitnine Inc. +Copyright (c) 2014-2018, Bitnine Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -124,5 +124,36 @@ def test_path(self): (e.eid,)) self.assertEqual(1, self.cur.fetchone()[0]) +class TestParam(TestConnection): + def setUp(self): + super(TestParam, self).setUp() + self.name = "'Agens\"Graph'" + + def test_param_dict(self): + d = {'name': self.name, 'since': 2016} + p = agensgraph.Property(d) + self.cur.execute('CREATE (n %s) RETURN n', (p,)) + self.conn.commit() + + v = self.cur.fetchone()[0] + self.assertEqual(self.name, v.props['name']) + self.assertEqual(2016, v.props['since']) + + def test_param_list_and_tuple(self): + a = [self.name, 2016] + t = (self.name, 2016) + pa = agensgraph.Property(a) + pt = agensgraph.Property(t) + self.cur.execute('CREATE (n {a: %s, t: %s}) RETURN n', (pa, pt)) + self.conn.commit() + + v = self.cur.fetchone()[0] + va = v.props['a'] + self.assertEqual(self.name, va[0]) + self.assertEqual(2016, va[1]) + vt = v.props['t'] + self.assertEqual(self.name, vt[0]) + self.assertEqual(2016, vt[1]) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_property.py b/tests/test_property.py new file mode 100644 index 0000000..9cab65b --- /dev/null +++ b/tests/test_property.py @@ -0,0 +1,53 @@ +''' +Copyright (c) 2014-2018, Bitnine Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +''' + +import unittest + +from agensgraph._property import Property + +from psycopg2.extensions import QuotedString, adapt + +class TestProperty(unittest.TestCase): + def test_string(self): + self.assertEqual(r"'\"'", Property('"').getquoted()) + self.assertEqual(r"''''", Property("'").getquoted()) + + def test_number(self): + self.assertEqual('0', Property(0).getquoted()) + self.assertEqual('-1', Property(-1).getquoted()) + self.assertEqual('3.14159', Property(3.14159).getquoted()) + + def test_boolean(self): + self.assertEqual('true', Property(True).getquoted()) + self.assertEqual('false', Property(False).getquoted()) + + def test_null(self): + self.assertEqual('null', Property(None).getquoted()) + + def test_array(self): + a = ["'\\\"'", 3.14159, True, None, (), {}] + e = "['''\\\\\\\"''',3.14159,true,null,[],{}]" + self.assertEqual(e, Property(a).getquoted()) + + def test_object(self): + self.assertEqual("{'\\\"':'\\\"'}", Property({'"': '"'}).getquoted()) + self.assertEqual("{'3.14159':3.14159}", + Property({3.14159: 3.14159}).getquoted()) + self.assertEqual("{'true':false}", Property({True: False}).getquoted()) + self.assertEqual("{'null':null}", Property({None: None}).getquoted()) + self.assertEqual("{'a':[]}", Property({'a': []}).getquoted()) + self.assertEqual("{'o':{}}", Property({'o': {}}).getquoted()) + self.assertRaises(TypeError, Property({(): None}).getquoted)