diff --git a/src/jwkest/jws.py b/src/jwkest/jws.py index a8c957d..1b2464e 100644 --- a/src/jwkest/jws.py +++ b/src/jwkest/jws.py @@ -406,6 +406,22 @@ def _pick_keys(self, keys, use="", alg=""): return pkey + def _pick_alg(self, keys): + alg = None + try: + alg = self["alg"] + except KeyError: + # try to get alg from key if there is only one + if keys is not None and len(keys) == 1: + key = next(iter(keys)) # first element from either list or dict + if key.alg: + self["alg"] = alg = key.alg + + if not alg: + self["alg"] = alg = "none" + + return alg + def _decode(self, payload): _msg = b64d(bytes(payload)) if "cty" in self: @@ -420,13 +436,7 @@ def dump_header(self): class JWS(JWx): def alg_keys(self, keys, use, protected=None): - try: - _alg = self["alg"] - except KeyError: - self["alg"] = _alg = "none" - else: - if not _alg: - self["alg"] = _alg = "none" + _alg = self._pick_alg(keys) if keys: keys = self._pick_keys(keys, use=use, alg=_alg) diff --git a/src/jwkest/jwt.py b/src/jwkest/jwt.py index 7a54ac0..5dc80c8 100644 --- a/src/jwkest/jwt.py +++ b/src/jwkest/jwt.py @@ -17,13 +17,15 @@ def split_token(token): def b2s_conv(item): if isinstance(item, bytes): return item.decode("utf-8") - elif isinstance(item, (six.string_types, int, bool)): + elif item is None or isinstance(item, (six.string_types, int, bool)): return item elif isinstance(item, list): return [b2s_conv(i) for i in item] elif isinstance(item, dict): return dict([(k, b2s_conv(v)) for k, v in item.items()]) + raise ValueError("Can't convert {}.".format(repr(item))) + def b64encode_item(item): if isinstance(item, bytes): diff --git a/tests/test_1_jwt.py b/tests/test_1_jwt.py index 0c8785d..72e1057 100644 --- a/tests/test_1_jwt.py +++ b/tests/test_1_jwt.py @@ -1,5 +1,8 @@ import json -from jwkest.jwt import JWT + +import pytest + +from jwkest.jwt import JWT, b2s_conv __author__ = 'roland' @@ -44,6 +47,10 @@ def test_unpack_str(): assert _jwt2 out_payload = _jwt2.payload() +def test_b2s_conv_raise_exception_on_bad_value(): + with pytest.raises(ValueError): + b2s_conv(object()) + if __name__ == "__main__": test_unpack_str() diff --git a/tests/test_3_jws.py b/tests/test_3_jws.py index 11fbe09..316a43e 100644 --- a/tests/test_3_jws.py +++ b/tests/test_3_jws.py @@ -505,6 +505,19 @@ def test_rs256_rm_signature(): else: assert False +def test_pick_alg_assume_alg_from_single_key(): + expected_alg = "HS256" + keys = [SYMKey(k="foobar", alg=expected_alg)] + + alg = JWS()._pick_alg(keys) + assert alg == expected_alg + +def test_pick_alg_dont_get_alg_from_single_key_if_already_specified(): + expected_alg = "RS512" + keys = [RSAKey(key=import_rsa_key_from_file(KEY), alg="RS256")] + + alg = JWS(alg=expected_alg)._pick_alg(keys) + assert alg == expected_alg if __name__ == "__main__": test_rs256_rm_signature()