Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions jwcrypto/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,19 +1356,18 @@ def import_keyset(self, keyset):
"""
try:
jwkset = json_decode(keyset)
if 'keys' not in jwkset:
raise ValueError("'keys' not in set")

for k, v in jwkset.items():
if k == 'keys':
for jwk in v:
self['keys'].add(JWK(**jwk))
else:
self[k] = v
Comment on lines +1359 to +1367
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're changing this around, can't we also improve it?
Something like:

Suggested change
if 'keys' not in jwkset:
raise ValueError("'keys' not in set")
for k, v in jwkset.items():
if k == 'keys':
for jwk in v:
self['keys'].add(JWK(**jwk))
else:
self[k] = v
self["keys"].update(JWK(**jwk) for jwk in jwkset.pop("keys"))
self.update(jwkset)

We don't need to raise a ValueError, because we'll get a KeyError in pop("keys") that'll be caught in the except block and re-raised.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to keep things readable, and I can't drop the 'keys' check as then I would miss the case when jwkset is an empty dict, which is also invalid.

except Exception as e: # pylint: disable=broad-except
raise InvalidJWKValue from e

if 'keys' not in jwkset:
raise InvalidJWKValue

for k, v in jwkset.items():
if k == 'keys':
for jwk in v:
self['keys'].add(JWK(**jwk))
else:
self[k] = v

@classmethod
def from_json(cls, keyset):
"""Creates a RFC 7517 key set from the standard JSON format.
Expand Down
15 changes: 15 additions & 0 deletions jwcrypto/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,21 @@ def test_jwkset_issue_208(self):
self.assertEqual(len(ks['keys']), 2)
self.assertEqual(len(ks['keys']), len(ks2['keys']))

def test_import_keyset_invalid(self):
ks = jwk.JWKSet()
invalid_inputs = [
'',
'null',
'[]',
'{}',
'{"keys": 1}',
'{"keys": [1]}',
'{"keys": [{"kty": "invalid"}]}'
]
for inp in invalid_inputs:
with self.assertRaises(jwk.InvalidJWKValue):
ks.import_keyset(inp)

def test_thumbprint(self):
for i in range(0, len(PublicKeys['keys'])):
k = jwk.JWK(**PublicKeys['keys'][i])
Expand Down
Loading