-
Notifications
You must be signed in to change notification settings - Fork 11
Fix performance of reading trace parameters #59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,16 +37,17 @@ def serialize(self) -> bytes: | |
| def _has_expected_type(value: Any) -> bool: | ||
| pass | ||
|
|
||
| def __init__(self, value): | ||
| if type(value) is ndarray and len(value.shape) > 1: | ||
| warnings.warn("Flatting multi-dimensional ndarray before adding it to trace parameter.\n" | ||
| "Information about dimensions of this ndarray will be lost.") | ||
| value = value.flatten() | ||
| if value is None or ((type(value) is list or type(value) is ndarray) and len(value) <= 0): | ||
| raise ValueError('The value for a TraceParameter cannot be empty') | ||
| if not type(self)._has_expected_type(value): | ||
| raise TypeError(f'A {type(self).__name__} must have a value of type "{type(self)._expected_type_string}"' | ||
| f', but it has a type of {type(value)}') | ||
| def __init__(self, value, skip_validation=False): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The main change is here. I added an additional parameter
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The rest of the changes are whitespace changes because we don't have a formatter set up in this project. |
||
| if not skip_validation: | ||
| if type(value) is ndarray and len(value.shape) > 1: | ||
| warnings.warn("Flatting multi-dimensional ndarray before adding it to trace parameter.\n" | ||
| "Information about dimensions of this ndarray will be lost.") | ||
| value = value.flatten() | ||
| if value is None or ((type(value) is list or type(value) is ndarray) and len(value) <= 0): | ||
| raise ValueError('The value for a TraceParameter cannot be empty') | ||
| if not type(self)._has_expected_type(value): | ||
| raise TypeError(f'A {type(self).__name__} must have a value of type "{type(self)._expected_type_string}"' | ||
| f', but it has a type of {type(value)}') | ||
| self.value = value | ||
|
|
||
| def __len__(self): | ||
|
|
@@ -84,7 +85,7 @@ def __len__(self): | |
| def deserialize(io_bytes: BytesIO, param_length: int) -> BooleanArrayParameter: | ||
| raw_values = io_bytes.read(ParameterType.BOOL.byte_size * param_length) | ||
| param_value = [bool(x) for x in list(raw_values)] | ||
| return BooleanArrayParameter(param_value) | ||
| return BooleanArrayParameter(param_value, skip_validation=True) | ||
|
|
||
| def serialize(self) -> bytes: | ||
| out = bytearray() | ||
|
|
@@ -112,7 +113,7 @@ def __eq__(self, other): | |
| @staticmethod | ||
| def deserialize(io_bytes: BytesIO, param_length: int): | ||
| param_value = list(io_bytes.read(ParameterType.BYTE.byte_size * param_length)) | ||
| return ByteArrayParameter(param_value) | ||
| return ByteArrayParameter(param_value, skip_validation=True) | ||
|
|
||
| def __str__(self): | ||
| return '0x' + bytes(self.value).hex().upper() if self.value else '' | ||
|
|
@@ -139,7 +140,7 @@ class DoubleArrayParameter(TraceParameter): | |
| @staticmethod | ||
| def deserialize(io_bytes: BytesIO, param_length: int) -> DoubleArrayParameter: | ||
| param_value = [struct.unpack('<d', io_bytes.read(ParameterType.DOUBLE.byte_size))[0] for i in range(param_length)] | ||
| return DoubleArrayParameter(param_value) | ||
| return DoubleArrayParameter(param_value, skip_validation=True) | ||
|
|
||
| def serialize(self) -> bytes: | ||
| out = bytearray() | ||
|
|
@@ -162,7 +163,7 @@ class FloatArrayParameter(TraceParameter): | |
| @staticmethod | ||
| def deserialize(io_bytes: BytesIO, param_length: int) -> FloatArrayParameter: | ||
| param_value = [struct.unpack('<f', io_bytes.read(ParameterType.FLOAT.byte_size))[0] for i in range(param_length)] | ||
| return FloatArrayParameter(param_value) | ||
| return FloatArrayParameter(param_value, skip_validation=True) | ||
|
|
||
| def serialize(self) -> bytes: | ||
| out = bytearray() | ||
|
|
@@ -185,7 +186,7 @@ class IntegerArrayParameter(TraceParameter): | |
| @staticmethod | ||
| def deserialize(io_bytes: BytesIO, param_length: int) -> IntegerArrayParameter: | ||
| param_value = [struct.unpack('<i', io_bytes.read(ParameterType.INT.byte_size))[0] for i in range(param_length)] | ||
| return IntegerArrayParameter(param_value) | ||
| return IntegerArrayParameter(param_value, skip_validation=True) | ||
|
|
||
| def serialize(self) -> bytes: | ||
| out = bytearray() | ||
|
|
@@ -208,7 +209,7 @@ class LongArrayParameter(TraceParameter): | |
| @staticmethod | ||
| def deserialize(io_bytes: BytesIO, param_length: int) -> LongArrayParameter: | ||
| param_value = [struct.unpack('<q', io_bytes.read(ParameterType.LONG.byte_size))[0] for i in range(param_length)] | ||
| return LongArrayParameter(param_value) | ||
| return LongArrayParameter(param_value, skip_validation=True) | ||
|
|
||
| def serialize(self) -> bytes: | ||
| out = bytearray() | ||
|
|
@@ -231,7 +232,7 @@ class ShortArrayParameter(TraceParameter): | |
| @staticmethod | ||
| def deserialize(io_bytes: BytesIO, param_length: int) -> ShortArrayParameter: | ||
| param_value = [struct.unpack('<h', io_bytes.read(ParameterType.SHORT.byte_size))[0] for i in range(param_length)] | ||
| return ShortArrayParameter(param_value) | ||
| return ShortArrayParameter(param_value, skip_validation=True) | ||
|
|
||
| def serialize(self) -> bytes: | ||
| out = bytearray() | ||
|
|
@@ -261,7 +262,7 @@ def __eq__(self, other): | |
| def deserialize(io_bytes: BytesIO, param_length: int) -> StringParameter: | ||
| bytes_read = io_bytes.read(ParameterType.STRING.byte_size * param_length) | ||
| param_value = bytes_read.decode(UTF_8) | ||
| return StringParameter(param_value) | ||
| return StringParameter(param_value, skip_validation=True) | ||
|
|
||
| def serialize(self) -> bytes: | ||
| out = bytearray() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why were these updated to their less readable counterparts? That doesn't seem idiomatic to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The class
TraceParameterMapinherits fromStringKeyOrderedDict.TraceParameterMapoverrides the__setitem__method ofStringKeyOrderedDict. In the overridden method, additional type checking is performed.This type checking is not needed when deserializing, so we call the base class' method.