Source code for jam.utils.codec.decorators.dataclasses

from dataclasses import fields, is_dataclass
from typing import Type, TypeVar, Tuple, Union, get_type_hints

from jam.utils.codec.codable import Codable

# Type variable standing for the class passed to ``decodable_dataclass``; it ties
# the decorator's return type (and ``decode_from``'s result type) to that class.
T = TypeVar("T")


def decodable_dataclass(cls: Type[T]) -> Type[T]:
    """
    Decorator that adds Codable support to a dataclass.

    If ``cls`` is not already a ``Codable`` subclass, ``Codable`` is prepended
    to its bases. The decorator then installs ``encode_size``,
    ``encode_into``, ``decode_from`` and ``__eq__`` on the class. Encoding and
    decoding walk ``dataclasses.fields`` in declaration order. Each field
    *value* provides ``encode_size`` / ``encode_into``, and each field *type*
    provides ``decode_from``.

    Args:
        cls: The class to decorate. It must be a dataclass by the time the
            installed methods are called (``encode_into`` checks this
            explicitly; the others rely on ``dataclasses.fields``).

    Returns:
        The same class object, modified in place.

    Raises:
        TypeError: If ``Codable`` cannot be added to ``cls.__bases__``.
    """
    # Make the class inherit from Codable if it doesn't already.
    if not issubclass(cls, Codable):
        try:
            cls.__bases__ = (Codable,) + cls.__bases__
        except TypeError as err:
            # CPython refuses __bases__ reassignment when the old and new base
            # layouts are incompatible (commonly when the only base is
            # ``object``). Surface an actionable message instead of the
            # low-level "deallocator differs" error.
            raise TypeError(
                f"decodable_dataclass could not add Codable to the bases of "
                f"{cls.__name__}; declare it explicitly, e.g. "
                f"'class {cls.__name__}(Codable): ...'"
            ) from err

    def encode_size(self) -> int:
        """Return the sum of ``encode_size()`` over all field values."""
        return sum(getattr(self, f.name).encode_size() for f in fields(self))

    def encode_into(self, buffer: bytearray, offset: int = 0) -> int:
        """
        Encode every field value into ``buffer`` starting at ``offset``.

        Returns:
            The total of the sizes reported by each field's ``encode_into``,
            i.e. how far the write position advanced.

        Raises:
            TypeError: If the instance is not a dataclass.
        """
        if not is_dataclass(self):
            raise TypeError(
                f"{self.__class__.__name__} must be a dataclass to use encoding"
            )
        current_offset = offset
        for f in fields(self):
            # Each field value writes itself and reports how many bytes it used.
            current_offset += getattr(self, f.name).encode_into(
                buffer, current_offset
            )
        return current_offset - offset

    # Per-field decoder types, resolved on the first decode and then reused.
    # Resolution is deferred so that string annotations may refer to names
    # defined after the class body.
    field_types = None

    def _resolved_field_types():
        """Return the field types in order, evaluating string annotations."""
        nonlocal field_types
        if field_types is None:
            hints = None
            resolved = []
            for f in fields(cls):  # type: ignore
                ftype = f.type
                if isinstance(ftype, str):
                    # Postponed annotation (e.g. ``from __future__ import
                    # annotations``): evaluate it in the class's namespace.
                    if hints is None:
                        hints = get_type_hints(cls)
                    ftype = hints[f.name]
                resolved.append(ftype)
            field_types = tuple(resolved)
        return field_types

    @staticmethod
    def decode_from(
        buffer: Union[bytes, bytearray, memoryview], offset: int = 0
    ) -> Tuple[T, int]:
        """
        Decode an instance of the decorated class from ``buffer`` at ``offset``.

        Each field is decoded in declaration order via its type's
        ``decode_from``. The instance is then constructed positionally.

        Returns:
            ``(instance, consumed)`` where ``consumed`` is the sum of the sizes
            reported by each field decoder.
        """
        current_offset = offset
        decoded_values = []
        for field_type in _resolved_field_types():
            value, size = field_type.decode_from(buffer, current_offset)
            decoded_values.append(value)
            current_offset += size
        instance = cls(*decoded_values)
        return instance, current_offset - offset  # type: ignore

    def __eq__(self, other: object) -> bool:
        """
        Field-by-field equality.

        Objects of a different class yield ``NotImplemented``, matching the
        stdlib dataclass ``__eq__``. Without this check, ``x == None`` would
        raise ``AttributeError`` when ``other`` lacks a field.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        return all(
            getattr(self, f.name) == getattr(other, f.name) for f in fields(self)
        )

    setattr(cls, "__eq__", __eq__)
    setattr(cls, "encode_size", encode_size)
    setattr(cls, "encode_into", encode_into)
    setattr(cls, "decode_from", decode_from)
    return cls