From a787a6e626b22d9c1fafe7d851c75dd1d90954f4 Mon Sep 17 00:00:00 2001 From: Rebecka Gulliksson Date: Mon, 18 Jan 2016 16:01:50 +0100 Subject: [PATCH 1/3] Use 'alg' from key as signing algorithm if there is only one key. --- src/jwkest/jws.py | 5 +++++ tests/test_3_jws.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/src/jwkest/jws.py b/src/jwkest/jws.py index a8c957d..7886fd9 100644 --- a/src/jwkest/jws.py +++ b/src/jwkest/jws.py @@ -428,6 +428,11 @@ def alg_keys(self, keys, use, protected=None): if not _alg: self["alg"] = _alg = "none" + if keys is not None and len(keys) == 1: + key = next(iter(keys)) # first element from either list or dict + if key.alg: + _alg = key.alg + if keys: keys = self._pick_keys(keys, use=use, alg=_alg) else: diff --git a/tests/test_3_jws.py b/tests/test_3_jws.py index 11fbe09..51d0acf 100644 --- a/tests/test_3_jws.py +++ b/tests/test_3_jws.py @@ -505,6 +505,12 @@ def test_rs256_rm_signature(): else: assert False +def test_alg_keys_assume_alg_from_single_key(): + expected_alg = "HS256" + keys = [SYMKey(k="foobar", alg=expected_alg)] + + _, _, alg = JWS().alg_keys(keys, "sig") + assert alg == expected_alg if __name__ == "__main__": test_rs256_rm_signature() From c285715e30eb25bcdeac4510e05b8ef86f48f7cd Mon Sep 17 00:00:00 2001 From: Rebecka Gulliksson Date: Tue, 19 Jan 2016 08:31:10 +0100 Subject: [PATCH 2/3] Only use 'alg' from key if it is not already specified. --- src/jwkest/jws.py | 29 +++++++++++++++++------------ tests/test_3_jws.py | 11 +++++++++-- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/jwkest/jws.py b/src/jwkest/jws.py index 7886fd9..1b2464e 100644 --- a/src/jwkest/jws.py +++ b/src/jwkest/jws.py @@ -406,6 +406,22 @@ def _pick_keys(self, keys, use="", alg=""): return pkey + def _pick_alg(self, keys): + alg = None + try: + alg = self["alg"] + except KeyError: + # try to get alg from key if there is only one + if keys is not None and len(keys) == 1: + key = next(iter(keys)) # first element from either list or dict + if key.alg: + self["alg"] = alg = key.alg + + if not alg: + self["alg"] = alg = "none" + + return alg + def _decode(self, payload): _msg = b64d(bytes(payload)) if "cty" in self: @@ -420,18 +436,7 @@ def dump_header(self): class JWS(JWx): def alg_keys(self, keys, use, protected=None): - try: - _alg = self["alg"] - except KeyError: - self["alg"] = _alg = "none" - else: - if not _alg: - self["alg"] = _alg = "none" - - if keys is not None and len(keys) == 1: - key = next(iter(keys)) # first element from either list or dict - if key.alg: - _alg = key.alg + _alg = self._pick_alg(keys) if keys: keys = self._pick_keys(keys, use=use, alg=_alg) diff --git a/tests/test_3_jws.py b/tests/test_3_jws.py index 51d0acf..316a43e 100644 --- a/tests/test_3_jws.py +++ b/tests/test_3_jws.py @@ -505,11 +505,18 @@ def test_rs256_rm_signature(): else: assert False -def test_alg_keys_assume_alg_from_single_key(): +def test_pick_alg_assume_alg_from_single_key(): expected_alg = "HS256" keys = [SYMKey(k="foobar", alg=expected_alg)] - _, _, alg = JWS().alg_keys(keys, "sig") + alg = JWS()._pick_alg(keys) + assert alg == expected_alg + +def test_pick_alg_dont_get_alg_from_single_key_if_already_specified(): + expected_alg = "RS512" + keys = [RSAKey(key=import_rsa_key_from_file(KEY), alg="RS256")] + + alg = JWS(alg=expected_alg)._pick_alg(keys) assert alg == expected_alg if __name__ == "__main__": From d2cb76991bce4d3a91a2a2618dab78533a345a20 Mon Sep 17 00:00:00 2001 From: Rebecka Gulliksson Date: Wed, 20 Jan 2016 14:47:40 +0100 Subject: [PATCH 3/3] Raise exception if item can't be converted to avoid silently signing None. --- src/jwkest/jwt.py | 4 +++- tests/test_1_jwt.py | 9 ++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/jwkest/jwt.py b/src/jwkest/jwt.py index 7a54ac0..5dc80c8 100644 --- a/src/jwkest/jwt.py +++ b/src/jwkest/jwt.py @@ -17,13 +17,15 @@ def split_token(token): def b2s_conv(item): if isinstance(item, bytes): return item.decode("utf-8") - elif isinstance(item, (six.string_types, int, bool)): + elif item is None or isinstance(item, (six.string_types, int, bool)): return item elif isinstance(item, list): return [b2s_conv(i) for i in item] elif isinstance(item, dict): return dict([(k, b2s_conv(v)) for k, v in item.items()]) + raise ValueError("Can't convert {}.".format(repr(item))) + def b64encode_item(item): if isinstance(item, bytes): diff --git a/tests/test_1_jwt.py b/tests/test_1_jwt.py index 0c8785d..72e1057 100644 --- a/tests/test_1_jwt.py +++ b/tests/test_1_jwt.py @@ -1,5 +1,8 @@ import json -from jwkest.jwt import JWT + +import pytest + +from jwkest.jwt import JWT, b2s_conv __author__ = 'roland' @@ -44,6 +47,10 @@ def test_unpack_str(): assert _jwt2 out_payload = _jwt2.payload() +def test_b2s_conv_raise_exception_on_bad_value(): + with pytest.raises(ValueError): + b2s_conv(object()) + if __name__ == "__main__": test_unpack_str()