[tor-commits] [stem/master] Convert Ed25519Extension to a field

atagar at torproject.org atagar at torproject.org
Sun Nov 17 23:40:39 UTC 2019


commit 064336ee318cd1e8201184400f04ad1af27ebe92
Author: Damian Johnson <atagar at torproject.org>
Date:   Wed Oct 23 15:23:42 2019 -0700

    Convert Ed25519Extension to a field
    
    So much cleaner! Moving unpacking into this class simplifies
    Ed25519CertificateV1's from_base64(), and lays the groundwork for packing.
---
 stem/descriptor/certificate.py      | 66 ++++++++++++++++++++++---------------
 test/unit/descriptor/certificate.py | 10 +++---
 2 files changed, 44 insertions(+), 32 deletions(-)

diff --git a/stem/descriptor/certificate.py b/stem/descriptor/certificate.py
index 92f75719..dcf0c227 100644
--- a/stem/descriptor/certificate.py
+++ b/stem/descriptor/certificate.py
@@ -72,7 +72,6 @@ used to for a variety of purposes...
 
 import base64
 import binascii
-import collections
 import datetime
 import hashlib
 import re
@@ -83,7 +82,7 @@ import stem.prereq
 import stem.util.enum
 import stem.util.str_tools
 
-from stem.client.datatype import Size, split
+from stem.client.datatype import Field, Size, split
 
 # TODO: Importing under an alternate name until we can deprecate our redundant
 # CertType enum in Stem 2.x.
@@ -107,16 +106,50 @@ ExtensionType = stem.util.enum.Enum(('HAS_SIGNING_KEY', 4),)
 ExtensionFlag = stem.util.enum.UppercaseEnum('AFFECTS_VALIDATION', 'UNKNOWN')
 
 
-class Ed25519Extension(collections.namedtuple('Ed25519Extension', ['type', 'flags', 'flag_int', 'data'])):
+class Ed25519Extension(Field):
   """
   Extension within an Ed25519 certificate.
 
-  :var int type: extension type
+  :var stem.descriptor.certificate.ExtensionType type: extension type
   :var list flags: extension attribute flags
   :var int flag_int: integer encoding of the extension attribute flags
   :var bytes data: data the extension concerns
   """
 
+  def __init__(self, ext_type, flag_val, data):
+    self.type = ext_type
+    self.flags = []
+    self.flag_int = flag_val
+    self.data = data
+
+    if flag_val % 2 == 1:
+      self.flags.append(ExtensionFlag.AFFECTS_VALIDATION)
+      flag_val -= 1
+
+    if flag_val:
+      self.flags.append(ExtensionFlag.UNKNOWN)
+
+    if ext_type == ExtensionType.HAS_SIGNING_KEY and len(data) != 32:
+      raise ValueError('Ed25519 HAS_SIGNING_KEY extension must be 32 bytes, but was %i.' % len(data))
+
+  @staticmethod
+  def pop(content):
+    if len(content) < 4:
+      raise ValueError('Ed25519 extension is missing header fields')
+
+    data_size, content = Size.SHORT.pop(content)
+    ext_type, content = Size.CHAR.pop(content)
+    flags, content = Size.CHAR.pop(content)
+    data, content = split(content, data_size)
+
+    if len(data) != data_size:
+      raise ValueError("Ed25519 extension is truncated. It should have %i bytes of data but there's only %i." % (data_size, len(data)))
+
+    return Ed25519Extension(ext_type, flags, data), content
+
+  def __hash__(self):
+    return stem.util._hash_attr(self, 'type', 'flag_int', 'data', cache = True)
+
 
 class Ed25519Certificate(object):
   """
@@ -270,29 +303,8 @@ class Ed25519CertificateV1(Ed25519Certificate):
     extensions = []
 
     for i in range(extension_count):
-      if len(extension_data) < 4:
-        raise ValueError('Ed25519 extension is missing header field data')
-
-      extension_length, extension_data = Size.SHORT.pop(extension_data)
-      extension_type, extension_data = Size.CHAR.pop(extension_data)
-      extension_flags, extension_data = Size.CHAR.pop(extension_data)
-      extension_value, extension_data = split(extension_data, extension_length)
-
-      if extension_length != len(extension_value):
-        raise ValueError("Ed25519 extension is truncated. It should have %i bytes of data but there's only %i." % (extension_length, len(extension_value)))
-      elif extension_type == ExtensionType.HAS_SIGNING_KEY and len(extension_value) != 32:
-        raise ValueError('Ed25519 HAS_SIGNING_KEY extension must be 32 bytes, but was %i.' % len(extension_value))
-
-      flags, remaining_flags = [], extension_flags
-
-      if remaining_flags % 2 == 1:
-        flags.append(ExtensionFlag.AFFECTS_VALIDATION)
-        remaining_flags -= 1
-
-      if remaining_flags:
-        flags.append(ExtensionFlag.UNKNOWN)
-
-      extensions.append(Ed25519Extension(extension_type, flags, extension_flags, extension_value))
+      extension, extension_data = Ed25519Extension.pop(extension_data)
+      extensions.append(extension)
 
     if extension_data:
       raise ValueError('Ed25519 certificate had %i bytes of unused extension data' % len(extension_data))
diff --git a/test/unit/descriptor/certificate.py b/test/unit/descriptor/certificate.py
index 7fd5e731..dc7ff184 100644
--- a/test/unit/descriptor/certificate.py
+++ b/test/unit/descriptor/certificate.py
@@ -13,7 +13,7 @@ import stem.prereq
 import test.require
 
 from stem.client.datatype import CertType
-from stem.descriptor.certificate import ED25519_SIGNATURE_LENGTH, ExtensionType, ExtensionFlag, Ed25519Certificate, Ed25519CertificateV1, Ed25519Extension
+from stem.descriptor.certificate import ED25519_SIGNATURE_LENGTH, ExtensionType, Ed25519Certificate, Ed25519CertificateV1, Ed25519Extension
 from test.unit.descriptor import get_resource
 
 from cryptography.hazmat.primitives import serialization
@@ -69,8 +69,8 @@ class TestEd25519Certificate(unittest.TestCase):
     self.assertEqual(b'\x01' * ED25519_SIGNATURE_LENGTH, cert.signature)
 
     self.assertEqual([
-      Ed25519Extension(type = ExtensionType.HAS_SIGNING_KEY, flags = [ExtensionFlag.AFFECTS_VALIDATION, ExtensionFlag.UNKNOWN], flag_int = 7, data = signing_key),
-      Ed25519Extension(type = 5, flags = [ExtensionFlag.UNKNOWN], flag_int = 4, data = b''),
+      Ed25519Extension(ExtensionType.HAS_SIGNING_KEY, 7, signing_key),
+      Ed25519Extension(5, 4, b''),
     ], cert.extensions)
 
     self.assertEqual(ExtensionType.HAS_SIGNING_KEY, cert.extensions[0].type)
@@ -90,7 +90,7 @@ class TestEd25519Certificate(unittest.TestCase):
     self.assertEqual(datetime.datetime(2015, 8, 28, 17, 0), cert.expiration)
     self.assertEqual(1, cert.key_type)
     self.assertEqual(EXPECTED_CERT_KEY, cert.key)
-    self.assertEqual([Ed25519Extension(type = 4, flags = [], flag_int = 0, data = EXPECTED_EXTENSION_DATA)], cert.extensions)
+    self.assertEqual([Ed25519Extension(4, 0, EXPECTED_EXTENSION_DATA)], cert.extensions)
     self.assertEqual(EXPECTED_SIGNATURE, cert.signature)
 
   def test_non_base64(self):
@@ -141,7 +141,7 @@ class TestEd25519Certificate(unittest.TestCase):
     Include an extension without as much data as it specifies.
     """
 
-    exc_msg = 'Ed25519 extension is missing header field data'
+    exc_msg = 'Ed25519 extension is missing header fields'
     self.assertRaisesWith(ValueError, exc_msg, Ed25519Certificate.from_base64, certificate(extension_data = [b'']))
 
     exc_msg = "Ed25519 extension is truncated. It should have 20480 bytes of data but there's only 2."





More information about the tor-commits mailing list