Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions trsfile/parametermap.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,11 @@ def deserialize(raw: BytesIO) -> TraceSetParameterMap:
for _ in range(number_of_entries):
name = read_parameter_name(raw)
value = TraceSetParameter.deserialize(raw)
result[name] = value
# Writing `result[name] = value` would cause the overridden `__setitem__`
# method in the `TraceParameterMap` to be called. That overridden method
# does additional type checking. There is no need to do type checking
# when deserializing. So invoke the base class method explicitly.
StringKeyOrderedDict.__setitem__(result, name, value)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why were these updated to their less readable counterparts? That doesn't seem idiomatic to me

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class TraceParameterMap inherits from StringKeyOrderedDict.

TraceParameterMap overrides the __setitem__ method of StringKeyOrderedDict. In the overridden method, additional type checking is performed.

This type checking is not needed when deserializing, so we call the base class' method.

return result

def serialize(self) -> bytes:
Expand Down Expand Up @@ -474,7 +478,11 @@ def deserialize(raw: bytes, definitions: TraceParameterDefinitionMap) -> TracePa
for key, val in definitions.items():
io_bytes.seek(val.offset)
param = val.param_type.param_class.deserialize(io_bytes, val.length)
result[key] = param
# Writing `result[name] = value` would cause the overridden `__setitem__`
# method in the `TraceParameterMap` to be called. That overridden method
# does additional type checking. There is no need to do type checking
# when deserializing. So invoke the base class method explicitly.
StringKeyOrderedDict.__setitem__(result, key, param)
return result

def serialize(self) -> bytearray:
Expand Down
37 changes: 19 additions & 18 deletions trsfile/traceparameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,17 @@ def serialize(self) -> bytes:
def _has_expected_type(value: Any) -> bool:
pass

def __init__(self, value):
if type(value) is ndarray and len(value.shape) > 1:
warnings.warn("Flatting multi-dimensional ndarray before adding it to trace parameter.\n"
"Information about dimensions of this ndarray will be lost.")
value = value.flatten()
if value is None or ((type(value) is list or type(value) is ndarray) and len(value) <= 0):
raise ValueError('The value for a TraceParameter cannot be empty')
if not type(self)._has_expected_type(value):
raise TypeError(f'A {type(self).__name__} must have a value of type "{type(self)._expected_type_string}"'
f', but it has a type of {type(value)}')
def __init__(self, value, skip_validation=False):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main change is here. I added an additional parameter skip_validation that by default is set to False, but that is set to True in the deserialization methods.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rest of the changes are whitespace changes because we don't have a formatter set up in this project.

if not skip_validation:
if type(value) is ndarray and len(value.shape) > 1:
warnings.warn("Flatting multi-dimensional ndarray before adding it to trace parameter.\n"
"Information about dimensions of this ndarray will be lost.")
value = value.flatten()
if value is None or ((type(value) is list or type(value) is ndarray) and len(value) <= 0):
raise ValueError('The value for a TraceParameter cannot be empty')
if not type(self)._has_expected_type(value):
raise TypeError(f'A {type(self).__name__} must have a value of type "{type(self)._expected_type_string}"'
f', but it has a type of {type(value)}')
self.value = value

def __len__(self):
Expand Down Expand Up @@ -84,7 +85,7 @@ def __len__(self):
def deserialize(io_bytes: BytesIO, param_length: int) -> BooleanArrayParameter:
raw_values = io_bytes.read(ParameterType.BOOL.byte_size * param_length)
param_value = [bool(x) for x in list(raw_values)]
return BooleanArrayParameter(param_value)
return BooleanArrayParameter(param_value, skip_validation=True)

def serialize(self) -> bytes:
out = bytearray()
Expand Down Expand Up @@ -112,7 +113,7 @@ def __eq__(self, other):
@staticmethod
def deserialize(io_bytes: BytesIO, param_length: int):
param_value = list(io_bytes.read(ParameterType.BYTE.byte_size * param_length))
return ByteArrayParameter(param_value)
return ByteArrayParameter(param_value, skip_validation=True)

def __str__(self):
return '0x' + bytes(self.value).hex().upper() if self.value else ''
Expand All @@ -139,7 +140,7 @@ class DoubleArrayParameter(TraceParameter):
@staticmethod
def deserialize(io_bytes: BytesIO, param_length: int) -> DoubleArrayParameter:
param_value = [struct.unpack('<d', io_bytes.read(ParameterType.DOUBLE.byte_size))[0] for i in range(param_length)]
return DoubleArrayParameter(param_value)
return DoubleArrayParameter(param_value, skip_validation=True)

def serialize(self) -> bytes:
out = bytearray()
Expand All @@ -162,7 +163,7 @@ class FloatArrayParameter(TraceParameter):
@staticmethod
def deserialize(io_bytes: BytesIO, param_length: int) -> FloatArrayParameter:
param_value = [struct.unpack('<f', io_bytes.read(ParameterType.FLOAT.byte_size))[0] for i in range(param_length)]
return FloatArrayParameter(param_value)
return FloatArrayParameter(param_value, skip_validation=True)

def serialize(self) -> bytes:
out = bytearray()
Expand All @@ -185,7 +186,7 @@ class IntegerArrayParameter(TraceParameter):
@staticmethod
def deserialize(io_bytes: BytesIO, param_length: int) -> IntegerArrayParameter:
param_value = [struct.unpack('<i', io_bytes.read(ParameterType.INT.byte_size))[0] for i in range(param_length)]
return IntegerArrayParameter(param_value)
return IntegerArrayParameter(param_value, skip_validation=True)

def serialize(self) -> bytes:
out = bytearray()
Expand All @@ -208,7 +209,7 @@ class LongArrayParameter(TraceParameter):
@staticmethod
def deserialize(io_bytes: BytesIO, param_length: int) -> LongArrayParameter:
param_value = [struct.unpack('<q', io_bytes.read(ParameterType.LONG.byte_size))[0] for i in range(param_length)]
return LongArrayParameter(param_value)
return LongArrayParameter(param_value, skip_validation=True)

def serialize(self) -> bytes:
out = bytearray()
Expand All @@ -231,7 +232,7 @@ class ShortArrayParameter(TraceParameter):
@staticmethod
def deserialize(io_bytes: BytesIO, param_length: int) -> ShortArrayParameter:
param_value = [struct.unpack('<h', io_bytes.read(ParameterType.SHORT.byte_size))[0] for i in range(param_length)]
return ShortArrayParameter(param_value)
return ShortArrayParameter(param_value, skip_validation=True)

def serialize(self) -> bytes:
out = bytearray()
Expand Down Expand Up @@ -261,7 +262,7 @@ def __eq__(self, other):
def deserialize(io_bytes: BytesIO, param_length: int) -> StringParameter:
bytes_read = io_bytes.read(ParameterType.STRING.byte_size * param_length)
param_value = bytes_read.decode(UTF_8)
return StringParameter(param_value)
return StringParameter(param_value, skip_validation=True)

def serialize(self) -> bytes:
out = bytearray()
Expand Down