diff --git a/docs/syntax/arrays.md b/docs/syntax/arrays.md new file mode 100644 index 0000000..f50e47c --- /dev/null +++ b/docs/syntax/arrays.md @@ -0,0 +1,98 @@ +Multi-dimensional arrays are often used to represented dimensionally structured +data. Packtype's syntax allows any type to be arrayed with an arbitrary number +of dimensions and dimension sizes. The base type can be a simple [scalar](scalar.md), +or can reference a more complex type like a [struct](struct.md) or [union](union.md), +you can even reference another multi-dimensional array! + +## Example + +The Packtype definition can either use a Python dataclass style or the Packtype +custom grammar: + +=== "Python (.py)" + + ```python linenums="1" + import packtype + from packtype import Constant, Scalar + + @packtype.package() + class Package1D: + Scalar1D : Scalar[4] + + @Package1D.struct() + class Struct1D: + field_a : Scalar[2] + field_b : Scalar[3] + + @packtype.package() + class Package3D: + Scalar3D : Package1D.Scalar1D[4][5] + Struct3D : Package1D.Struct1D[3][2] + ``` + +=== "Packtype (.pt)" + + ```sv linenums="1" + package package_1d { + scalar_1d_t : scalar[4] + + struct struct_1d_t { + field_a : scalar[2] + field_b : scalar[3] + } + } + + package package_3d { + scalar_3d_t : package_1d::scalar_1d_t[4][5] + struct_3d_t : package_1d::struct_1d_t[3][2] + } + ``` + +As rendered to SystemVerilog + +```sv linenums="1" +package package_1d; + +typedef logic [3:0] scalar_1d_t; + +typedef struct packed { + logic [2:0] field_b; + logic [1:0] field_a; +} struct_1d_t; + +endpackage : package_1d + +package package_3d; + +import package_1d::scalar_1d_t; +import package_1d::struct_1d_t; + +typedef scalar_1d_t [4:0][3:0] scalar_3d_t; +typedef struct_1d_t [1:0][2:0] struct_3d_t; + +endpackage : package_3d +``` + +!!! warning + + The order of dimensions is _reversed_ when compared to declaring a packed + multi-dimensional array in SystemVerilog. For example `scalar[4][5][6]` + declares a 6x5 array of 4-bit elements, which in SystemVerilog would be + written `logic [5:0][4:0][3:0]`. This is done to make it easier to parse the + syntax, as decisions can be made reading left-to-right. + +## Helper Properties and Methods + +Struct definitions expose a collection of helper functions for properties related +to the type: + + * `._pt_width` - property that returns the bit width of the entire array; + * `._pt_pack()` - packs all values contained within the array into a + singular integer value (can also be achieved by casting to an int, e.g. + `int()`); + * `._pt_unpack(packed: int)` - unpacks an integer value into the entries + of the array; + * `len()` - returns the size of the outermost dimension of the array; + * `[X]` - accesses element X within the array, which may return either + an instance of the base type _or_ another packed array depending on the + number of dimensions. diff --git a/docs/syntax/scalar.md b/docs/syntax/scalar.md index e56bb6e..1ae37d2 100644 --- a/docs/syntax/scalar.md +++ b/docs/syntax/scalar.md @@ -9,21 +9,21 @@ custom grammar: === "Python (.py)" - ```python linenums="1" - import packtype - from packtype import Constant, Scalar - - @packtype.package() - class MyPackage: - # Constants - TYPE_A_W : Constant = 29 - TYPE_B_W : Constant = 13 - - # Typedefs - TypeA : Scalar[TYPE_A_W] - TypeB : Scalar[TYPE_B_W] - TypeC : Scalar[7] - ``` + ```python linenums="1" + import packtype + from packtype import Constant, Scalar + + @packtype.package() + class MyPackage: + # Constants + TYPE_A_W : Constant = 29 + TYPE_B_W : Constant = 13 + + # Typedefs + TypeA : Scalar[TYPE_A_W] + TypeB : Scalar[TYPE_B_W] + TypeC : Scalar[7] + ``` === "Packtype (.pt)" diff --git a/examples/arrays/spec.pt b/examples/arrays/spec.pt new file mode 100644 index 0000000..b63a741 --- /dev/null +++ b/examples/arrays/spec.pt @@ -0,0 +1,33 @@ +package one_dimension { + scalar_1d : scalar[4] + + struct struct_1d { + field_a : scalar[4][3][2] + field_b : scalar[2][3][4] + } + + enum enum_1d { + VAL_A: constant + VAL_B: constant + VAL_C: constant + } + + union union_1d { + member_a: struct_1d[2][3] + member_b: scalar[4][12][2][3] + } +} + +package two_dimension { + scalar_2d : one_dimension::scalar_1d[3] + enum_2d : one_dimension::enum_1d[2] + struct_2d : one_dimension::struct_1d[4] + union_2d : one_dimension::union_1d[2] +} + +package three_dimension { + scalar_3d : two_dimension::scalar_2d[2] + enum_3d : two_dimension::enum_2d[4] + struct_3d : two_dimension::struct_2d[5] + union_3d : two_dimension::union_2d[3] +} diff --git a/examples/arrays/spec.py b/examples/arrays/spec.py new file mode 100644 index 0000000..064f627 --- /dev/null +++ b/examples/arrays/spec.py @@ -0,0 +1,42 @@ +import packtype +from packtype import Constant, Scalar + + +@packtype.package() +class OneDimension: + Scalar1D: Scalar[4] + + +@OneDimension.struct() +class Struct1D: + field_a: Scalar[4][3][2] + field_b: Scalar[2][3][4] + + +@OneDimension.enum() +class Enum1D: + VAL_A: Constant + VAL_B: Constant + VAL_C: Constant + + +@OneDimension.union() +class Union1D: + member_a: Struct1D[2][3] + member_b: Scalar[4][12][2][3] + + +@packtype.package() +class TwoDimension: + Scalar2D: OneDimension.Scalar1D[3] + Enum2D: OneDimension.Enum1D[2] + Struct2D: OneDimension.Struct1D[4] + Union2D: OneDimension.Union1D[2] + + +@packtype.package() +class ThreeDimension: + Scalar3D: TwoDimension.Scalar2D[2] + Enum3D: TwoDimension.Enum2D[4] + Struct3D: TwoDimension.Struct2D[5] + Union3D: TwoDimension.Union2D[3] diff --git a/examples/arrays/test.sh b/examples/arrays/test.sh new file mode 100755 index 0000000..50e33ba --- /dev/null +++ b/examples/arrays/test.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Copyright 2023-2025, Peter Birch, mailto:peter@intuity.io +# SPDX-License-Identifier: Apache-2.0 +# + +# Credit to Dave Dopson: https://stackoverflow.com/questions/59895/how-can-i-get-the-source-directory-of-a-bash-script-from-within-the-script-itsel +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +# Setup PYTHONPATH to get access to packtype +export PYTHONPATH=${this_dir}/../..:$PYTHONPATH + +# Invoke packtype on Python syntax +python3 -m packtype --debug code package sv ${this_dir}/out_py ${this_dir}/spec.py + +# Invoke packtype on Packtype syntax +python3 -m packtype --debug code package sv ${this_dir}/out_pt ${this_dir}/spec.pt diff --git a/examples/axi4l_registers/registers.py b/examples/axi4l_registers/registers.py index 8442403..fa38fc3 100644 --- a/examples/axi4l_registers/registers.py +++ b/examples/axi4l_registers/registers.py @@ -47,7 +47,7 @@ class ResetControl: @packtype.registers.group() class ControlGroup: - core_reset: 4 * ResetControl + core_reset: ResetControl[4] # === Communications === @@ -76,4 +76,4 @@ class CommsGroup: class Control: device: DeviceGroup control: ControlGroup - comms: 2 * CommsGroup + comms: CommsGroup[2] diff --git a/examples/raw_registers/registers.py b/examples/raw_registers/registers.py index 8442403..fa38fc3 100644 --- a/examples/raw_registers/registers.py +++ b/examples/raw_registers/registers.py @@ -47,7 +47,7 @@ class ResetControl: @packtype.registers.group() class ControlGroup: - core_reset: 4 * ResetControl + core_reset: ResetControl[4] # === Communications === @@ -76,4 +76,4 @@ class CommsGroup: class Control: device: DeviceGroup control: ControlGroup - comms: 2 * CommsGroup + comms: CommsGroup[2] diff --git a/mkdocs.yml b/mkdocs.yml index d539fae..d2dfdae 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -35,6 +35,7 @@ nav: - Packtype: index.md - Syntax: - Alias: syntax/alias.md + - Arrays: syntax/arrays.md - Constants: syntax/constant.md - Enumerations: syntax/enum.md - Packages: syntax/package.md diff --git a/packtype/grammar/declarations.py b/packtype/grammar/declarations.py index c10dd4e..e3da28e 100644 --- a/packtype/grammar/declarations.py +++ b/packtype/grammar/declarations.py @@ -8,6 +8,7 @@ from ..common.expression import Expression from ..types.alias import Alias +from ..types.array import ArraySpec from ..types.assembly import Packing from ..types.base import Base from ..types.constant import Constant @@ -50,17 +51,46 @@ class Position: @dataclass() -class DeclImport: - position: Position +class ForeignRef: package: str name: str +@dataclass() +class DeclImport: + position: Position + foreign: ForeignRef + + +@dataclass() +class DeclDimensions: + dimensions: list[int] + + def resolve( + self, + cb_resolve: Callable[ + [ + str, + ], + int | type[Base], + ], + ) -> list[int]: + eval_dims = [] + for raw_dim in self.dimensions: + if isinstance(raw_dim, Expression): + eval_dims.append(raw_dim.evaluate(cb_resolve)) + else: + raise Exception("Unexpected width type in DeclScalar") + return eval_dims + + @dataclass() class DeclAlias: position: Position name: str foreign: str + dimensions: DeclDimensions | None = None + description: Description | None = None def to_class( self, @@ -70,8 +100,14 @@ def to_class( ], int | type[Base], ], - ) -> type[Alias]: - return Alias[cb_resolve(self.foreign)] + ) -> type[Alias] | ArraySpec: + entity = cb_resolve(self.foreign) + if self.dimensions: + for dim in self.dimensions.resolve(cb_resolve): + entity = entity[dim] + return entity + else: + return Alias[entity] @dataclass() @@ -112,25 +148,9 @@ class DeclScalar: position: Position name: str signedness: type[Signed | Unsigned] - width: Expression + dimensions: DeclDimensions description: Description | None = None - def resolve_width( - self, - cb_resolve: Callable[ - [ - str, - ], - int | type[Base], - ], - ) -> int: - if isinstance(self.width, Expression): - return self.width.evaluate(cb_resolve) - elif self.width is None: - return 1 - else: - raise Exception("Unexpected width type in DeclScalar") - def to_field_def( self, cb_resolve: Callable[ @@ -154,12 +174,14 @@ def to_class( int | type[Base], ], ) -> type[Scalar]: - scalar_cls = Scalar[ - self.resolve_width(cb_resolve), - (self.signedness is Signed), - ] - scalar_cls.__doc__ = str(self.description) if self.description else None - return scalar_cls + entity = None + for dim in self.dimensions.resolve(cb_resolve): + if entity is None: + entity = Scalar[dim, (self.signedness is Signed)] + else: + entity = entity[dim] + entity.__doc__ = str(self.description) if self.description else None + return entity @dataclass() @@ -229,6 +251,7 @@ class DeclField: position: Position name: str ref: str + dimensions: DeclDimensions | None = None description: Description | None = None @@ -265,7 +288,11 @@ def to_class( if isinstance(fdecl, DeclScalar): fields[fdecl.name] = fdecl.to_field_def(cb_resolve) elif isinstance(fdecl, DeclField): - fields[fdecl.name] = (cb_resolve(fdecl.ref), None) + ftype = cb_resolve(fdecl.ref) + if fdecl.dimensions: + for dim in fdecl.dimensions.resolve(cb_resolve): + ftype = ftype[dim] + fields[fdecl.name] = (ftype, None) else: raise ValueError(f"Unexpected struct field name: {fdecl}") return build_from_fields( @@ -309,7 +336,11 @@ def to_class( if isinstance(fdecl, DeclScalar): fields[fdecl.name] = fdecl.to_field_def(cb_resolve) elif isinstance(fdecl, DeclField): - fields[fdecl.name] = (cb_resolve(fdecl.ref), None) + ftype = cb_resolve(fdecl.ref) + if fdecl.dimensions: + for dim in fdecl.dimensions.resolve(cb_resolve): + ftype = ftype[dim] + fields[fdecl.name] = (ftype, None) else: raise ValueError(f"Unexpected struct field name: {fdecl}") return build_from_fields( diff --git a/packtype/grammar/grammar.py b/packtype/grammar/grammar.py index dbe5909..6d99409 100644 --- a/packtype/grammar/grammar.py +++ b/packtype/grammar/grammar.py @@ -4,6 +4,7 @@ import functools import inspect +from collections.abc import Iterable from pathlib import Path from lark import Lark @@ -19,9 +20,11 @@ DeclConstant, DeclEnum, DeclImport, + DeclPackage, DeclScalar, DeclStruct, DeclUnion, + ForeignRef, Position, ) from .transformer import PacktypeTransformer @@ -62,7 +65,7 @@ def parse_string( constant_overrides: dict[str, int] | None = None, source: Path | None = None, keep_expression: bool = False, -) -> Package: +) -> Iterable[Package]: """ Parse a Packtype definition from a string producing a Package object. @@ -75,7 +78,7 @@ def parse_string( associating each declaration with its source file. :param keep_expression: If True, expressions will be attached to constants allowing them to be re-evaluated with new inputs. - :return: A Package object representing the parsed definition. + :yields: A Package object representing the parsed definition. """ # If no namespaces are provided, use an empty dict namespaces = namespaces or {} @@ -83,7 +86,7 @@ def parse_string( constant_overrides = constant_overrides or {} # Parse the definition try: - defn = PacktypeTransformer().transform(create_parser().parse(definition)) + definitions = PacktypeTransformer().transform(create_parser().parse(definition)) except UnexpectedToken as exc: raise ParseError( f"Failed to parse {source.name if source else 'input'} on line {exc.line}: " @@ -102,101 +105,117 @@ def _check_collision(name: str) -> None: f"on line {pos.line}" ) - def _resolve(name: str) -> int: + def _resolve(ref: str | ForeignRef) -> int: nonlocal known_entities - if name in known_entities: - return known_entities[name][0] - raise UnknownEntityError(f"Failed to resolve '{name}' to a known constant or type") - - # Create the package - package: Package = build_from_fields( - base=Package, - cname=defn.name, - fields={}, - kwds=defn.get_modifiers(), - doc_str=str(defn.description) if defn.description else None, - source=(source.as_posix() if source else "N/A", defn.position.line), - ) - - # Run through the declarations - for decl in defn.declarations: - match decl: - # Imports - case DeclImport(): - # Resolve the package - if (foreign_pkg := namespaces.get(decl.package, None)) is None: - raise ImportError(f"Unknown package '{decl.package}'") - # Resolve the type - if (foreign_type := getattr(foreign_pkg, decl.name, None)) is None: - raise ImportError(f"'{decl.name}' not declared in package '{decl.package}'") - # Check for name collisions - _check_collision(decl.name) - # Remember this type - if isinstance(foreign_type, Constant): - known_entities[decl.name] = (foreign_type, decl.position) - else: - known_entities[decl.name] = (foreign_type, decl.position) - # Aliases - case DeclAlias(): - package._pt_attach( - scalar := decl.to_class(_resolve), - name=decl.name, + if isinstance(ref, ForeignRef): + if ref.package not in namespaces: + raise UnknownEntityError(f"Failed to resolve package '{ref.package}'") + if not hasattr(namespaces[ref.package], ref.name): + raise UnknownEntityError( + f"Failed to resolve '{ref.name}' in package '{ref.package}'" ) - # Check for name collisions - _check_collision(decl.name) - # Remember this type - known_entities[decl.name] = (scalar, decl.position) - # Build constants - case DeclConstant(): - constant = decl.to_instance(_resolve) - if keep_expression: - constant._PT_EXPRESSION = decl.expr - package._pt_attach_constant(decl.name, constant) - # Check for name collisions - _check_collision(decl.name) - # Check for a constant override - if decl.name in constant_overrides: - get_log().debug( - f"Overriding constant '{decl.name}' with value " - f"{constant_overrides[decl.name]}" + return getattr(namespaces[ref.package], ref.name) + elif ref in known_entities: + return known_entities[ref][0] + raise UnknownEntityError(f"Failed to resolve '{ref}' to a known constant or type") + + for defn in [definitions] if isinstance(definitions, DeclPackage) else definitions: + # Create the package + package: Package = build_from_fields( + base=Package, + cname=defn.name, + fields={}, + kwds=defn.get_modifiers(), + doc_str=str(defn.description) if defn.description else None, + source=(source.as_posix() if source else "N/A", defn.position.line), + ) + + # Run through the declarations + for decl in defn.declarations: + match decl: + # Imports + case DeclImport(): + # Resolve the package + if (foreign_pkg := namespaces.get(decl.foreign.package, None)) is None: + raise ImportError(f"Unknown package '{decl.foreign.package}'") + # Resolve the type + if (foreign_type := getattr(foreign_pkg, decl.foreign.name, None)) is None: + raise ImportError( + f"'{decl.foreign.name}' not declared in package " + f"'{decl.foreign.package}'" + ) + # Check for name collisions + _check_collision(decl.foreign.name) + # Remember this type + if isinstance(foreign_type, Constant): + known_entities[decl.foreign.name] = (foreign_type, decl.position) + else: + known_entities[decl.foreign.name] = (foreign_type, decl.position) + # Aliases + case DeclAlias(): + package._pt_attach( + alias := decl.to_class(_resolve), + name=decl.name, + ) + # Check for name collisions + _check_collision(decl.name) + # Remember this type + known_entities[decl.name] = (alias, decl.position) + # Build constants + case DeclConstant(): + constant = decl.to_instance(_resolve) + if keep_expression: + constant._PT_EXPRESSION = decl.expr + package._pt_attach_constant(decl.name, constant) + # Check for name collisions + _check_collision(decl.name) + # Check for a constant override + if decl.name in constant_overrides: + get_log().debug( + f"Overriding constant '{decl.name}' with value " + f"{constant_overrides[decl.name]}" + ) + constant._pt_set(int(constant_overrides[decl.name])) + # Remember this constant + known_entities[decl.name] = (constant, decl.position) + # Build aliases and scalars + case DeclScalar() | DeclAlias(): + package._pt_attach( + obj := decl.to_class(_resolve), + name=decl.name, ) - constant._pt_set(int(constant_overrides[decl.name])) - # Remember this constant - known_entities[decl.name] = (constant, decl.position) - # Build aliases and scalars - case DeclScalar() | DeclAlias(): - package._pt_attach( - obj := decl.to_class(_resolve), - name=decl.name, + # Check for name collisions + _check_collision(decl.name) + # Remember this type + known_entities[decl.name] = (obj, decl.position) + # Build enums, structs, and unions + case DeclEnum() | DeclStruct() | DeclUnion(): + package._pt_attach(obj := decl.to_class(source, _resolve)) + # Check for name collisions + _check_collision(decl.name) + # Remember this type + known_entities[decl.name] = (obj, decl.position) + case _: + raise Exception(f"Unhandled declaration: {decl}") + + # Check for overrides that don't match up + for name in constant_overrides.keys(): + if not hasattr(package, name): + raise UnknownEntityError( + f"Constant override '{name}' does not match any defined constant " + f"in package '{package.__name__}'" ) - # Check for name collisions - _check_collision(decl.name) - # Remember this type - known_entities[decl.name] = (obj, decl.position) - # Build enums, structs, and unions - case DeclEnum() | DeclStruct() | DeclUnion(): - package._pt_attach(obj := decl.to_class(source, _resolve)) - # Check for name collisions - _check_collision(decl.name) - # Remember this type - known_entities[decl.name] = (obj, decl.position) - case _: - raise Exception(f"Unhandled declaration: {decl}") - - # Check for overrides that don't match up - for name in constant_overrides.keys(): - if not hasattr(package, name): - raise UnknownEntityError( - f"Constant override '{name}' does not match any defined constant " - f"in package '{package.__name__}'" - ) - elif not isinstance(getattr(package, name), Constant): - raise TypeError( - f"Constant override '{name}' does not match a constant in package " - f"'{package.__name__}', found {getattr(package, name).__name__}" - ) + elif not isinstance(getattr(package, name), Constant): + raise TypeError( + f"Constant override '{name}' does not match a constant in package " + f"'{package.__name__}', found {getattr(package, name).__name__}" + ) + + # Register with namespace + namespaces[package.__name__] = package - return package + # Yield the package + yield package def parse( @@ -204,7 +223,7 @@ def parse( namespaces: dict[str, Package] | None = None, constant_overrides: dict[str, int] | None = None, keep_expression: bool = False, -) -> Package: +) -> Iterable[Package]: """ Parse a Packtype definition from a file path producing a Package object. @@ -215,10 +234,10 @@ def parse( the constant's name. :param keep_expression: If True, expressions will be attached to constants allowing them to be re-evaluated with new inputs. - :return: A Package object representing the parsed definition. + :yields: Package objects representing the parsed definition. """ with path.open("r", encoding="utf-8") as fh: - return parse_string( + yield from parse_string( definition=fh.read(), namespaces=namespaces, constant_overrides=constant_overrides, diff --git a/packtype/grammar/packtype.lark b/packtype/grammar/packtype.lark index 149d72c..64a80bb 100644 --- a/packtype/grammar/packtype.lark +++ b/packtype/grammar/packtype.lark @@ -38,7 +38,8 @@ COMMENT: /\/\/[^\n]*/ ?signed: "signed" ?unsigned: "unsigned" -?width: "[" expr "]" +dimension: "[" expr "]" +dimensions: dimension+ ?name: CNAME descr: ESCAPED_STRING @@ -50,7 +51,7 @@ modifier: "@" name "=" (name | ESCAPED_STRING | NUMERIC) // ============================================================================= // Allowable root nodes -?root: decl_package +?root: decl_package* // | decl_regblock // ============================================================================= @@ -73,20 +74,21 @@ decl_package: "package"i name "{" descr? modifier* package_body* "}" // ============================================================================= // Example: import other_pkg::VALUE_A -decl_import: "import" name "::" name +foreign_ref: name "::" name +decl_import: "import" foreign_ref // ============================================================================= // Simple Declarations // ============================================================================= // Example: local_type_t : foreign_type_t -decl_alias: name ":" name +decl_alias: name ":" (name | foreign_ref) dimensions? descr? // Example: MY_CONSTANT : constant[8] = 123 -decl_constant: name ":" "constant"i width? "=" expr descr? +decl_constant: name ":" "constant"i dimension? "=" expr descr? // Example: simple_type_t : scalar[8] -decl_scalar: name ":" (signed|unsigned)? "scalar"i width? descr? +decl_scalar: name ":" (signed|unsigned)? "scalar"i dimensions? descr? // ============================================================================= // Enumerations @@ -101,7 +103,7 @@ decl_scalar: name ":" (signed|unsigned)? "scalar"i width? descr? // // ============================================================================= -decl_enum: "enum"i enum_mode? width? name "{" descr? modifier* enum_body* "}" +decl_enum: "enum"i enum_mode? dimension? name "{" descr? modifier* enum_body* "}" ?enum_mode: enum_mode_indexed | enum_mode_onehot @@ -131,10 +133,10 @@ enum_mode_gray: "gray"i // // ============================================================================= -field: name ":" name descr? +field: name ":" (name | foreign_ref) dimensions? descr? | decl_scalar -decl_struct: "struct"i packing_mode? width? name "{" descr? modifier* field* "}" +decl_struct: "struct"i packing_mode? dimension? name "{" descr? modifier* field* "}" ?packing_mode: packing_mode_msb | packing_mode_lsb diff --git a/packtype/grammar/transformer.py b/packtype/grammar/transformer.py index 988f4ab..f3a89ba 100644 --- a/packtype/grammar/transformer.py +++ b/packtype/grammar/transformer.py @@ -13,6 +13,7 @@ from .declarations import ( DeclAlias, DeclConstant, + DeclDimensions, DeclEnum, DeclField, DeclImport, @@ -21,6 +22,7 @@ DeclStruct, DeclUnion, Description, + ForeignRef, Modifier, Position, Signed, @@ -85,6 +87,15 @@ def descr(self, body): def modifier(self, body): return Modifier(*body) + def dimension(self, body): + return body[0] + + def dimensions(self, body): + return DeclDimensions(dimensions=body) + + def foreign_ref(self, body): + return ForeignRef(*body) + @v_args(meta=True) def decl_import(self, meta, body): return DeclImport(Position(meta.line, meta.column), *body) @@ -115,11 +126,16 @@ def decl_constant(self, meta, body): @v_args(meta=True) def field(self, meta, body): - return ( - body[0] - if isinstance(body[0], DeclScalar) - else DeclField(Position(meta.line, meta.column), *body) - ) + if isinstance(body[0], DeclScalar): + return body[0] + else: + return DeclField( + Position(meta.line, meta.column), + name=body.pop(0), + ref=body.pop(0), + dimensions=body.pop(0) if body and isinstance(body[0], DeclDimensions) else None, + description=body.pop(0) if body and isinstance(body[0], Description) else None, + ) @v_args(meta=True) def enum_body_simple(self, meta, body): @@ -193,17 +209,17 @@ def decl_scalar(self, meta, body): signed, *remainder = remainder else: signed = Unsigned - # Pickup width - if remainder and isinstance(remainder[0], Expression): - width, *remainder = remainder + # Pickup dimensions + if remainder and isinstance(remainder[0], DeclDimensions): + dimensions, *remainder = remainder else: - width = Expression(1) + dimensions = DeclDimensions(dimensions=[Expression(1)]) # Pickup description if remainder and isinstance(remainder[0], Description): descr = remainder[0] else: descr = None - return DeclScalar(Position(meta.line, meta.column), s_type, signed, width, descr) + return DeclScalar(Position(meta.line, meta.column), s_type, signed, dimensions, descr) @v_args(meta=True) def decl_struct(self, meta, body): @@ -260,3 +276,6 @@ def decl_package(self, meta, body): while remainder and isinstance(remainder[0], Modifier): mods.append(remainder.pop(0)) return DeclPackage(Position(meta.line, meta.column), p_name, description, mods, remainder) + + def root(self, body): + return body diff --git a/packtype/registers/registers.py b/packtype/registers/registers.py index 92109ac..89db0c2 100644 --- a/packtype/registers/registers.py +++ b/packtype/registers/registers.py @@ -14,8 +14,8 @@ from ..types.constant import Constant from ..types.enum import Enum from ..types.packing import Packing -from ..types.primitive import NumericPrimitive -from ..types.scalar import Scalar +from ..types.primitive import NumericType +from ..types.scalar import Scalar, ScalarType from ..types.struct import Struct from ..types.wrap import build_from_fields, get_wrapper @@ -97,7 +97,7 @@ class Register(PackedAssembly): """Defines a single register with a behaviour, width, and alignment""" # Allow both constants and scalars to be assigned values - _PT_ALLOW_DEFAULTS: list[type[Base]] = [Constant, Enum, Scalar, Struct] + _PT_ALLOW_DEFAULTS: list[type[Base]] = [Constant, Enum, ScalarType, Struct] # Detail custom attributes that registers offer _PT_ATTRIBUTES: dict[str, tuple[Any, list[Any]]] = { "behaviour": (Behaviour.CONSTANT, list(Behaviour)), @@ -326,7 +326,7 @@ def _pt_construct(cls, parent: Base, width: int | None, align: int | None, spaci parent=fbase, ) # Insert a placeholder entry to offsets - dimension = ftype.dimension if isinstance(ftype, ArraySpec) else 1 + dimension = ftype.dimensions[0] if isinstance(ftype, ArraySpec) else 1 for idx in range(dimension): cls._PT_OFFSETS[fname, idx] = ( fbase._PT_BYTE_SIZE, @@ -384,7 +384,7 @@ def _is_a_type(obj: Any) -> bool: if obj._PT_BASE in (Group, Register): return False # If it's not a primitive, immediately accept - if inspect.isclass(obj) and not issubclass(obj, NumericPrimitive): + if inspect.isclass(obj) and not issubclass(obj, NumericType): return True # If not attached to a different package, accept return obj._PT_ATTACHED_TO is not None and type(obj._PT_ATTACHED_TO) is not cls diff --git a/packtype/start.py b/packtype/start.py index a4a0226..ab4f7ba 100644 --- a/packtype/start.py +++ b/packtype/start.py @@ -24,7 +24,7 @@ from .types.enum import Enum from .types.package import Package from .types.primitive import NumericPrimitive -from .types.scalar import Scalar +from .types.scalar import Scalar, ScalarType from .types.struct import Struct from .types.union import Union from .types.wrap import Registry @@ -76,8 +76,8 @@ def load_specification(spec_files: list[str], keep_expression: bool) -> list[Bas get_log().debug(f"Loading specification: {item}") # Packtype grammar files if item.lower().endswith((".pt", ".packtype", ".ptype")): - package = parse(Path(item), namespaces, keep_expression=keep_expression) - namespaces[package.__name__] = package + for package in parse(Path(item), namespaces, keep_expression=keep_expression): + namespaces[package.__name__] = package # If it ends with `.py` assume it's Python elif item.endswith(".py"): item = Path(item) @@ -345,6 +345,7 @@ def code( Enum, Packing, Scalar, + ScalarType, Struct, Union, NumericPrimitive, diff --git a/packtype/templates/package.sv.mako b/packtype/templates/package.sv.mako index 9e9d4b1..a5a2a80 100644 --- a/packtype/templates/package.sv.mako +++ b/packtype/templates/package.sv.mako @@ -60,6 +60,22 @@ typedef logic [${utils.get_width(obj)-1}:0] ${name | filters.type}; // ${name} typedef ${obj._PT_ALIAS._pt_name() | filters.type} ${name | filters.type}; %endfor +%for name, spec in baseline._pt_arrays: +// ${name} +typedef \ + %if utils.is_scalar(spec.base): +logic \ + %else: +${utils.get_name(spec.base) | filters.type} \ + %endif +%for dim in spec.dimensions: +[${dim-1}:0]\ +%endfor + %if utils.is_scalar(spec.base): +[${utils.get_width(spec.base)-1}:0]\ + %endif + ${name | filters.type}; +%endfor // ============================================================================= // Enumerations @@ -102,16 +118,18 @@ typedef struct packed { <% pad_idx += 1 %>\ %endif <% - array_sfx = f" [{len(field)-1}:0]" if isinstance(field, PackedArray) else "" - field = field[0] if isinstance(field, PackedArray) else field + array_sfx = "" + if utils.array.is_packed_array(field): + array_sfx = " " + "".join(f"[{x-1}:0]" for x in field._pt_spec.dimensions) + field = field._pt_spec.base() %>\ - %if isinstance(field, Scalar): + %if utils.is_scalar(field): %if field._PT_ATTACHED_TO: <% refers_to = field._pt_name() %>\ ${refers_to | filters.type}${array_sfx} ${fname | tc.snake_case}; %else: <% sign_sfx = " signed" if field._pt_signed else "" %>\ - logic${sign_sfx}${array_sfx}${f" [{utils.get_width(field)}:0]" if utils.get_width(field) > 1 else ""} ${fname | tc.snake_case}; + logic${sign_sfx}${array_sfx}${f" [{utils.get_width(field)-1}:0]" if utils.get_width(field) > 1 else ""} ${fname | tc.snake_case}; %endif %elif isinstance(field, Alias | Enum | Struct | Union): ${field._pt_name() | filters.type}${array_sfx} ${fname | tc.snake_case}; @@ -127,16 +145,18 @@ typedef struct packed { typedef union packed { %for field, fname in obj._pt_fields.items(): <% - array_sfx = f" [{len(field)-1}:0]" if isinstance(field, PackedArray) else "" - field = field[0] if isinstance(field, PackedArray) else field + array_sfx = "" + if utils.array.is_packed_array(field): + array_sfx = " " + "".join(f"[{x-1}:0]" for x in field._pt_spec.dimensions) + field = field._pt_spec.base() %>\ - %if isinstance(field, Scalar): + %if isinstance(field, ScalarType): %if field._PT_ATTACHED_TO: <% refers_to = field._pt_name() %>\ ${refers_to | filters.type} ${fname | tc.snake_case}; %else: <% sign_sfx = " signed" if field._pt_signed else "" %>\ - logic${sign_sfx}${f" [{utils.get_width(field)}:0]" if utils.get_width(field) > 1 else ""} ${fname | tc.snake_case}; + logic${sign_sfx}${array_sfx}${f" [{utils.get_width(field)-1}:0]" if utils.get_width(field) > 1 else ""} ${fname | tc.snake_case}; %endif %elif isinstance(field, Enum | Struct | Alias | Union): ${field._pt_name() | filters.type}${array_sfx} ${fname | tc.snake_case}; diff --git a/packtype/types/array.py b/packtype/types/array.py index ec4aca5..834ff2d 100644 --- a/packtype/types/array.py +++ b/packtype/types/array.py @@ -3,25 +3,62 @@ # import functools +import math from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, Self from .bitvector import BitVector, BitVectorWindow from .packing import Packing class ArraySpec: - def __init__(self, base: Any, dimension: int) -> None: + def __init__(self, base: Any, dimensions: int | tuple[int]) -> None: self.base = base - self.dimension = dimension + self.dimensions = dimensions if isinstance(dimensions, list | tuple) else (dimensions,) @property - def _PT_WIDTH(self) -> int: # noqa: N802 - return self.base._PT_WIDTH * self.dimension + def _pt_flat_dimension(self) -> int: + return math.prod(self.dimensions) @property def _pt_width(self) -> int: - return self._PT_WIDTH + return self.base._PT_WIDTH * self._pt_flat_dimension + + @property + def _PT_WIDTH(self) -> int: # noqa: N802 + return self._pt_width + + def _pt_ranges( + self, + packing: Packing = Packing.FROM_LSB, + ) -> dict[tuple[int], tuple[int, int]]: + def _recurse( + remaining: tuple[int], + path: tuple[int], + msb: int, + lsb: int, + ) -> tuple[tuple[int], int, int]: + # If no dimensions left, produce an element + if len(remaining) == 0: + yield path, msb, lsb + # Otherwise, iterate over the dimension + else: + dimension, *remaining = remaining + stepping = math.prod((*remaining, 1)) * self.base._PT_WIDTH + for idx in range(dimension): + if packing is Packing.FROM_LSB: + lsb = lsb + msb = lsb + stepping - 1 + else: + msb = msb + lsb = msb - stepping + 1 + yield from _recurse(remaining, (*path, idx), msb, lsb) + if packing is Packing.FROM_LSB: + lsb += stepping + else: + msb -= stepping + + return {x[0]: (x[1], x[2]) for x in _recurse(self.dimensions, [], self._pt_width - 1, 0)} def _pt_references(self) -> Iterable[Any]: return self.base._pt_references() @@ -35,6 +72,9 @@ def as_unpacked(self, **kwds) -> "PackedArray": def __call__(self, **kwds) -> "PackedArray": return self.as_packed(**kwds) + def __getitem__(self, key: int) -> Self: + return type(self)(self.base, (key, *self.dimensions)) + class PackedArray: def __init__( @@ -45,38 +85,71 @@ def __init__( _pt_per_inst: Callable[[int, list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]] | None = None, packing: Packing = Packing.FROM_LSB, + dimensions: tuple[int] | None = None, + dim_path: tuple[int] | None = None, **kwds, ): + self._pt_spec = spec self._pt_bv = BitVector(width=spec._pt_width) if _pt_bv is None else _pt_bv self._pt_entries = [] - if packing is Packing.FROM_LSB: + self._pt_dimensions = dimensions or spec.dimensions + self._pt_dim_path = dim_path or [] + + # For a single dimension, instance elements + if len(self._pt_dimensions) == 1: + msb = self._pt_dimensions[0] * spec.base._PT_WIDTH - 1 lsb = 0 - for idx in range(spec.dimension): - inst_args, inst_kwds = ( - _pt_per_inst(idx, *args, **kwds) if callable(_pt_per_inst) else (args, kwds) - ) + for idx in range(self._pt_dimensions[0]): + if packing is Packing.FROM_LSB: + msb = lsb + spec.base._PT_WIDTH - 1 + else: + lsb = msb - spec.base._PT_WIDTH + 1 + if callable(_pt_per_inst): + inst_args, inst_kwds = _pt_per_inst((*self._pt_dim_path, idx), *args, **kwds) + else: + inst_args, inst_kwds = args, kwds self._pt_entries.append( - entry := spec.base( + spec.base( *inst_args, - _pt_bv=self._pt_bv.create_window(lsb + spec.base._PT_WIDTH - 1, lsb), + _pt_bv=self._pt_bv.create_window(msb, lsb), **inst_kwds, ) ) - lsb += entry._pt_width + if packing is Packing.FROM_LSB: + lsb += spec.base._PT_WIDTH + else: + msb -= spec.base._PT_WIDTH + # Otherwise, nest another PackedArray else: - msb = spec._pt_width - 1 - for idx in range(spec.dimension): - inst_args, inst_kwds = ( - _pt_per_inst(idx, *args, **kwds) if callable(_pt_per_inst) else (args, kwds) - ) + dimension, *remaining = self._pt_dimensions + stepping = math.prod(remaining) * spec.base._PT_WIDTH + msb = dimension * stepping - 1 + lsb = 0 + for idx in range(dimension): + if packing is Packing.FROM_LSB: + msb = lsb + stepping - 1 + else: + lsb = msb - stepping + 1 + if callable(_pt_per_inst): + inst_args, inst_kwds = _pt_per_inst((*self._pt_dim_path, idx), *args, **kwds) + else: + inst_args, inst_kwds = args, kwds self._pt_entries.append( - entry := spec.base( - *inst_args, - _pt_bv=self._pt_bv.create_window(msb, msb - spec.base._PT_WIDTH + 1), - **inst_kwds, + PackedArray( + spec, + *args, + _pt_bv=self._pt_bv.create_window(msb, lsb), + _pt_per_inst=_pt_per_inst, + packing=packing, + dimensions=remaining, + dim_path=(*self._pt_dim_path, idx), + **kwds, ) ) - msb -= entry._pt_width + if packing is Packing.FROM_LSB: + lsb += stepping + else: + msb -= stepping def __getitem__(self, key: int) -> Any: return self._pt_entries[key] @@ -93,7 +166,22 @@ def __len__(self) -> int: @property @functools.cache # noqa: B019 def _pt_width(self) -> int: - return sum(x._pt_width for x in self._pt_entries) + return self._pt_bv.width + + def _pt_pack(self) -> int: + return int(self._pt_bv) + + @classmethod + def _pt_unpack(cls, packed: int) -> "PackedArray": + inst = cls() + inst._pt_set(packed) + return inst + + def __int__(self) -> int: + return self._pt_pack() + + def _pt_set(self, value: int) -> None: + self._pt_bv.set(value) class UnpackedArray: @@ -105,8 +193,9 @@ def __init__( | None = None, **kwds, ): + self._pt_spec = spec self._pt_entries = [] - for idx in range(spec.dimension): + for idx in range(spec.dimensions[0]): inst_args, inst_kwds = ( _pt_per_inst(idx, *args, **kwds) if callable(_pt_per_inst) else (args, kwds) ) diff --git a/packtype/types/assembly.py b/packtype/types/assembly.py index b9d301c..684eb6b 100644 --- a/packtype/types/assembly.py +++ b/packtype/types/assembly.py @@ -198,11 +198,10 @@ def _pt_construct(cls, parent: Base, packing: Packing, width: int | None): for fname, ftype, _ in cls._pt_definitions(): # For arrays record each component placement separately if isinstance(ftype, ArraySpec): - fwidth = ftype.base()._pt_width - part_lsb = lsb - for idx in range(ftype.dimension): - cls._PT_RANGES[fname, idx] = (part_lsb, part_lsb + fwidth - 1) - part_lsb += fwidth + for dimension, (part_msb, part_lsb) in ftype._pt_ranges( + cls._PT_PACKING + ).items(): + cls._PT_RANGES[fname, dimension] = (lsb + part_msb, lsb + part_lsb) # For every field type (including arrays) record full placement fwidth = ftype()._pt_width cls._PT_RANGES[fname] = (lsb, lsb + fwidth - 1) @@ -217,11 +216,14 @@ def _pt_construct(cls, parent: Base, packing: Packing, width: int | None): for fname, ftype, _ in cls._pt_definitions(): # For arrays record each component placement separately if isinstance(ftype, ArraySpec): - fwidth = ftype.base()._pt_width - part_msb = msb - for idx in range(ftype.dimension): - cls._PT_RANGES[fname, idx] = (part_msb - fwidth + 1, part_msb) - part_msb -= fwidth + root_lsb = msb - ftype._PT_WIDTH + 1 + for dimension, (part_msb, part_lsb) in ftype._pt_ranges( + cls._PT_PACKING + ).items(): + cls._PT_RANGES[fname, dimension] = ( + part_msb + root_lsb, + part_lsb + root_lsb, + ) # For every field type (including arrays) record full placement fwidth = ftype()._pt_width cls._PT_RANGES[fname] = (msb - fwidth + 1, msb) @@ -253,11 +255,7 @@ def _pt_mask(self) -> int: def _pt_fields_lsb_asc(self) -> list[tuple[int, int, tuple[str, Base]]]: pairs = [] for finst, fname in self._pt_fields.items(): - if isinstance(finst, PackedArray): - lsb = min(self._PT_RANGES[(fname, x)][0] for x in range(len(finst))) - msb = max(self._PT_RANGES[(fname, x)][1] for x in range(len(finst))) - else: - lsb, msb = self._PT_RANGES[fname] + lsb, msb = self._PT_RANGES[fname] pairs.append((lsb, msb, (fname, finst))) return sorted(pairs, key=lambda x: x[0]) @@ -266,11 +264,7 @@ def _pt_fields_lsb_asc(self) -> list[tuple[int, int, tuple[str, Base]]]: def _pt_fields_msb_desc(self) -> list[tuple[int, int, tuple[str, Base]]]: pairs = [] for finst, fname in self._pt_fields.items(): - if isinstance(finst, PackedArray): - lsb = min(self._PT_RANGES[(fname, x)][0] for x in range(len(finst))) - msb = max(self._PT_RANGES[(fname, x)][1] for x in range(len(finst))) - else: - lsb, msb = self._PT_RANGES[fname] + lsb, msb = self._PT_RANGES[fname] pairs.append((lsb, msb, (fname, finst))) return sorted(pairs, key=lambda x: x[1], reverse=True) diff --git a/packtype/types/base.py b/packtype/types/base.py index 389c113..0c1056d 100644 --- a/packtype/types/base.py +++ b/packtype/types/base.py @@ -19,16 +19,15 @@ class MetaBase(type): - def __mul__(cls, other: int): - return ArraySpec(cls, other) - - def __rmul__(cls, other: int): - return ArraySpec(cls, other) + def __getitem__(cls, key: int): + return ArraySpec(cls, key) class Base(metaclass=MetaBase): # The base class type _PT_BASE: type["Base"] | None = None + # Substitute type for metaclass + _PT_META_USE_TYPE: type["Base"] | None = None # What contained types are allowed to have a default value (e.g. constants) _PT_ALLOW_DEFAULTS: list[type["Base"]] = [] # Any other types to be attached to this one (e.g. struct to a package) diff --git a/packtype/types/package.py b/packtype/types/package.py index 0cd06b4..ca7c72b 100644 --- a/packtype/types/package.py +++ b/packtype/types/package.py @@ -13,8 +13,8 @@ from .base import Base from .constant import Constant from .enum import Enum -from .primitive import NumericPrimitive -from .scalar import Scalar +from .primitive import NumericType +from .scalar import ScalarType from .struct import Struct from .union import Union from .wrap import get_wrapper @@ -29,7 +29,7 @@ def _pt_construct(cls, parent: Base) -> None: super()._pt_construct(parent) cls._PT_FIELDS = {} for fname, ftype, fval in cls._pt_definitions(): - if issubclass(ftype, Constant): + if inspect.isclass(ftype) and issubclass(ftype, Constant): cls._pt_attach_constant(fname, ftype(default=fval)) else: cls._pt_attach(ftype, name=fname) @@ -83,7 +83,7 @@ def _is_a_type(obj: Any) -> bool: if isinstance(obj, ArraySpec): obj = obj.base # If it's not a primitive, immediately accept - if inspect.isclass(obj) and not issubclass(obj, NumericPrimitive): + if inspect.isclass(obj) and not issubclass(obj, NumericType): return True # If not attached to a different package, accept return obj._PT_ATTACHED_TO is not None and type(obj._PT_ATTACHED_TO) is not cls @@ -98,37 +98,40 @@ def _pt_fields(self) -> dict: def _pt_constants(self) -> Iterable[Constant]: return ((y, x) for x, y in self._pt_fields.items() if isinstance(x, Constant)) - @property - def _pt_scalars(self) -> Iterable[tuple[str, Scalar]]: + def _pt_filter_for_class(self, ctype: type[Base]) -> Iterable[tuple[str, type[Base]]]: return ( (y, x) for x, y in self._pt_fields.items() - if (inspect.isclass(x) and issubclass(x, Scalar)) + if (inspect.isclass(x) and issubclass(x, ctype)) ) + @property + def _pt_scalars(self) -> Iterable[tuple[str, ScalarType]]: + return self._pt_filter_for_class(ScalarType) + + @property + def _pt_arrays(self) -> Iterable[tuple[str, ArraySpec]]: + return ((y, x) for x, y in self._pt_fields.items() if isinstance(x, ArraySpec)) + @property def _pt_aliases(self) -> Iterable[Alias]: - return ( - (y, x) - for x, y in self._pt_fields.items() - if inspect.isclass(x) and issubclass(x, Alias) - ) + return self._pt_filter_for_class(Alias) @property def _pt_enums(self) -> Iterable[tuple[str, Enum]]: - return ((x._pt_name(), x) for x in self._PT_ATTACH if issubclass(x, Enum)) + return self._pt_filter_for_class(Enum) @property def _pt_structs(self) -> Iterable[tuple[str, Struct]]: - return ((x._pt_name(), x) for x in self._PT_ATTACH if issubclass(x, Struct)) + return self._pt_filter_for_class(Struct) @property def _pt_unions(self) -> Iterable[tuple[str, Union]]: - return ((x._pt_name(), x) for x in self._PT_ATTACH if issubclass(x, Union)) + return self._pt_filter_for_class(Union) @property def _pt_structs_and_unions(self) -> Iterable[tuple[str, Struct | Union]]: - return ((x._pt_name(), x) for x in self._PT_ATTACH if issubclass(x, Struct | Union)) + return self._pt_filter_for_class(Struct | Union) @classmethod def _pt_lookup(cls, field: type[Base] | Base) -> str: diff --git a/packtype/types/primitive.py b/packtype/types/primitive.py index 1bde6a5..a4aeca5 100644 --- a/packtype/types/primitive.py +++ b/packtype/types/primitive.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # +import inspect from collections import defaultdict from typing import Any @@ -28,19 +29,31 @@ def __getitem__(self, key: int | tuple[int, bool]): @staticmethod def get_variant(prim: Self, segments: tuple[str], kwargs: dict[str, Any]): + # Use the frame to determine declaration source + frame = inspect.currentframe() + for _ in range(2): + frame = frame.f_back + source = (frame.f_code.co_filename, frame.f_lineno) + # If the primitive provides a base type, use that instead + meta_type = prim._PT_META_USE_TYPE or prim # NOTE: Don't share primitives between creations as this prevents the # parent being distinctly tracked (a problem when they are used as # typedefs on a package) uid = MetaPrimitive.UNIQUE_ID[segments] MetaPrimitive.UNIQUE_ID[segments] += 1 - return type( - prim.__name__ + "_" + "_".join(str(x) for x in segments) + f"_{uid}", - (prim,), - kwargs, + imposter = type( + meta_type.__name__ + "_" + "_".join(str(x) for x in segments) + f"_{uid}", + (meta_type,), + { + **kwargs, + "_PT_SOURCE": source, + "_PT_BASE": meta_type, + }, ) + return imposter -class NumericPrimitive(Base, Numeric, metaclass=MetaPrimitive): +class NumericType(Base, Numeric): _PT_WIDTH: int = -1 _PT_SIGNED: bool = False @@ -51,22 +64,6 @@ def __init__( ) -> None: super().__init__(_pt_bv=_pt_bv, default=default) - @classmethod - def _pt_meta_key(cls, key: int | tuple[int, bool]) -> tuple[tuple[str], dict[str, Any]]: - if isinstance(key, int) or hasattr(key, "__int__"): - key = int(key) - return ((str(key),), {"_PT_WIDTH": key}) - elif ( - isinstance(key, tuple) - and (isinstance(key[0], int) or hasattr(key[0], "__int__")) - and (isinstance(key[1], bool) or hasattr(key[1], "__bool__")) - ): - width, signed = int(key[0]), bool(key[1]) - key = f"{width}{'S' if signed else 'U'}" - return ((key,), {"_PT_WIDTH": width, "_PT_SIGNED": signed}) - else: - raise Exception(f"Unsupported NumericPrimitive key: {key}") - @property def _pt_width(self) -> int: return type(self)._PT_WIDTH @@ -100,3 +97,21 @@ def __int__(self) -> int: def __float__(self) -> float: return float(int(self)) + + +class NumericPrimitive(NumericType, metaclass=MetaPrimitive): + @classmethod + def _pt_meta_key(cls, key: int | tuple[int, bool]) -> tuple[tuple[str], dict[str, Any]]: + if isinstance(key, int) or hasattr(key, "__int__"): + key = int(key) + return ((str(key),), {"_PT_WIDTH": key}) + elif ( + isinstance(key, tuple) + and (isinstance(key[0], int) or hasattr(key[0], "__int__")) + and (isinstance(key[1], bool) or hasattr(key[1], "__bool__")) + ): + width, signed = int(key[0]), bool(key[1]) + key = f"{width}{'S' if signed else 'U'}" + return ((key,), {"_PT_WIDTH": width, "_PT_SIGNED": signed}) + else: + raise Exception(f"Unsupported NumericPrimitive key: {key}") diff --git a/packtype/types/scalar.py b/packtype/types/scalar.py index 250dc29..1d16b65 100644 --- a/packtype/types/scalar.py +++ b/packtype/types/scalar.py @@ -2,10 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 # -from .primitive import NumericPrimitive +from .base import Base +from .primitive import NumericPrimitive, NumericType -class Scalar(NumericPrimitive): + +class ScalarType(NumericType): _PT_WIDTH: int = 1 @classmethod @@ -14,3 +16,7 @@ def _pt_name(cls) -> str: return cls._PT_ATTACHED_TO._pt_lookup(cls) else: return NumericPrimitive._pt_name(cls) + + +class Scalar(NumericPrimitive): + _PT_META_USE_TYPE: type[Base] = ScalarType diff --git a/packtype/utils/__init__.py b/packtype/utils/__init__.py index 407068d..e2f69ea 100644 --- a/packtype/utils/__init__.py +++ b/packtype/utils/__init__.py @@ -2,10 +2,21 @@ # SPDX-License-Identifier: Apache-2.0 # -from . import constant, enum, package, struct, union -from .basic import clog2, get_doc, get_name, get_source, get_width, is_signed, pack, unpack +from . import array, constant, enum, package, struct, union +from .basic import ( + clog2, + get_doc, + get_name, + get_source, + get_width, + is_scalar, + is_signed, + pack, + unpack, +) __all__ = [ + "array", "clog2", "constant", "enum", @@ -13,6 +24,7 @@ "get_name", "get_source", "get_width", + "is_scalar", "is_signed", "pack", "package", diff --git a/packtype/utils/array.py b/packtype/utils/array.py new file mode 100644 index 0000000..f96dd27 --- /dev/null +++ b/packtype/utils/array.py @@ -0,0 +1,16 @@ +# Copyright 2023-2025, Peter Birch, mailto:peter@intuity.io +# SPDX-License-Identifier: Apache-2.0 +# + +from typing import Any + +from ..types.array import PackedArray + + +def is_packed_array(ptype: Any) -> bool: + """ + Determine if a field is a packed array instance. + :param ptype: The field to check + :return: True if the field is a packed array instance, False otherwise + """ + return isinstance(ptype, PackedArray) diff --git a/packtype/utils/basic.py b/packtype/utils/basic.py index 4e8b71a..89c6e3f 100644 --- a/packtype/utils/basic.py +++ b/packtype/utils/basic.py @@ -9,7 +9,8 @@ from ..types.assembly import PackedAssembly from ..types.base import Base from ..types.enum import Enum -from ..types.primitive import NumericPrimitive +from ..types.primitive import NumericType +from ..types.scalar import ScalarType from ..types.union import Union @@ -24,19 +25,16 @@ def clog2(x: int) -> int: def get_width( - ptype: type[PackedAssembly | Enum | NumericPrimitive | Union] - | PackedAssembly - | NumericPrimitive - | Union, + ptype: type[PackedAssembly | Enum | NumericType | Union] | PackedAssembly | NumericType | Union, ) -> int: """ Get the width of a Packtype definition :param ptype: The Packtype definition to inspect :return: The width in bits of the Packtype definition """ - if isinstance(ptype, PackedAssembly | Enum | NumericPrimitive | Union): + if isinstance(ptype, PackedAssembly | Enum | NumericType | Union): return ptype._pt_width - elif issubclass(ptype, PackedAssembly | Enum | NumericPrimitive | Union): + elif issubclass(ptype, PackedAssembly | Enum | NumericType | Union): return ptype._PT_WIDTH elif issubclass(ptype, Alias): return get_width(ptype._PT_ALIAS) @@ -86,15 +84,15 @@ def get_source(ptype: type[Base] | Base) -> tuple[str, int]: raise TypeError(f"{ptype} is not a Packtype definition") -def is_signed(ptype: type[NumericPrimitive] | NumericPrimitive) -> bool: +def is_signed(ptype: type[NumericType] | NumericType) -> bool: """ Check if a Packtype definition is signed :param ptype: The Packtype definition to check :return: True if the definition is signed, False otherwise """ - if isinstance(ptype, NumericPrimitive): + if isinstance(ptype, NumericType): return ptype._pt_signed - elif issubclass(ptype, NumericPrimitive): + elif issubclass(ptype, NumericType): return ptype._PT_SIGNED else: raise TypeError(f"{ptype} is not a Packtype definition") @@ -111,7 +109,7 @@ def unpack(ptype: type[Base], value: int) -> Base: raise TypeError(f"{ptype} is an instance of a Packtype definition") if not issubclass(ptype, Base): raise TypeError(f"{ptype} is not a Packtype definition") - if issubclass(ptype, NumericPrimitive): + if issubclass(ptype, NumericType): return ptype(value) elif issubclass(ptype, Enum): return ptype._pt_cast(value) @@ -128,3 +126,14 @@ def pack(pinst: Base) -> int: if inspect.isclass(pinst): raise TypeError(f"{pinst} is not an instance of a Packtype definition") return int(pinst) + + +def is_scalar(ptype: type[Base] | Base) -> bool: + """ + Check if a Packtype definition is a scalar type + :param ptype: The Packtype definition to check + :return: True if the definition is a scalar type, False otherwise + """ + return isinstance(ptype, ScalarType) or ( + inspect.isclass(ptype) and issubclass(ptype, ScalarType) + ) diff --git a/packtype/utils/struct.py b/packtype/utils/struct.py index cd197aa..120276b 100644 --- a/packtype/utils/struct.py +++ b/packtype/utils/struct.py @@ -3,7 +3,7 @@ # from ..types.base import Base -from ..types.scalar import Scalar +from ..types.scalar import ScalarType from ..types.struct import Struct from .basic import get_name @@ -51,7 +51,7 @@ def is_simple_field(field: Base) -> bool: :param field: The field to check :return: True if the field is a simple field, False otherwise """ - return isinstance(field, Scalar) and not field._PT_ATTACHED_TO + return isinstance(field, ScalarType) and not field._PT_ATTACHED_TO def get_field_type(field: Base) -> str | None: diff --git a/packtype/utils/union.py b/packtype/utils/union.py index e8e35b4..359eae1 100644 --- a/packtype/utils/union.py +++ b/packtype/utils/union.py @@ -5,7 +5,7 @@ from collections.abc import Iterable from ..types.base import Base -from ..types.scalar import Scalar +from ..types.scalar import ScalarType from ..types.union import Union from .basic import get_name @@ -44,7 +44,7 @@ def is_simple_member(member: Base) -> bool: :param member: The member to check :return: True if the member is a simple member, False otherwise """ - return isinstance(member, Scalar) and not member._PT_ATTACHED_TO + return isinstance(member, ScalarType) and not member._PT_ATTACHED_TO def get_member_type(member: Base) -> str | None: diff --git a/pyproject.toml b/pyproject.toml index b0ba077..b5d7d9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "packtype" -version = "3.0.1" +version = "3.0.2" description = "Packed data structure specifications for multi-language hardware projects" authors = ["Peter Birch "] license = "Apache-2.0" diff --git a/tests/grammar/test_alias.py b/tests/grammar/test_alias.py index 32e5873..1ced254 100644 --- a/tests/grammar/test_alias.py +++ b/tests/grammar/test_alias.py @@ -4,8 +4,9 @@ import pytest -from packtype import Alias, Scalar +from packtype import Alias from packtype.grammar import UnknownEntityError, parse_string +from packtype.types.scalar import ScalarType from packtype.utils import get_width from ..fixtures import reset_registry @@ -15,8 +16,9 @@ def test_parse_alias(): """Test parsing an alias definition within a package""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { // Original scalar original: scalar[8] @@ -24,10 +26,11 @@ def test_parse_alias(): alias: original } """ + ) ) assert len(pkg._PT_FIELDS) == 2 # original - assert issubclass(pkg.original, Scalar) + assert issubclass(pkg.original, ScalarType) assert get_width(pkg.original) == 8 # alias assert issubclass(pkg.alias, Alias) @@ -40,10 +43,12 @@ def test_parse_alias_bad_reference(): UnknownEntityError, match="Failed to resolve 'non_existent' to a known constant or type", ): - parse_string( - """ + next( + parse_string( + """ package the_package { alias: non_existent } """ + ) ) diff --git a/tests/grammar/test_constant.py b/tests/grammar/test_constant.py index 9924314..0cb1537 100644 --- a/tests/grammar/test_constant.py +++ b/tests/grammar/test_constant.py @@ -15,8 +15,9 @@ def test_parse_constant(): """Test parsing a constant definition within a package""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { A: constant = 42 "Unsized declaration" @@ -26,6 +27,7 @@ def test_parse_constant(): "Declarations are case insensitive" } """ + ) ) assert len(pkg._PT_FIELDS) == 3 # A @@ -48,26 +50,30 @@ def test_parse_constant(): def test_parse_constant_keep_expression(): """Test keeping the expression when parsing a constant definition""" # Not kept - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { A: constant = 1 B: constant = 2 C: constant = A + B } """ + ) )() assert pkg.C._PT_EXPRESSION is None # Kept - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { A: constant = 1 B: constant = 2 C: constant = A + B } """, - keep_expression=True, + keep_expression=True, + ) )() assert pkg.C._PT_EXPRESSION is not None assert pkg.C._PT_EXPRESSION.evaluate({"A": 4, "B": 5}.get) == 4 + 5 @@ -76,14 +82,16 @@ def test_parse_constant_keep_expression(): def test_parse_constant_override(): """Test parsing a constant definition within a package""" # Parse without overrides - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { A: constant = 42 B: constant = 39 C: constant = A + B } """ + ) ) assert len(pkg._PT_FIELDS) == 3 # A @@ -96,18 +104,20 @@ def test_parse_constant_override(): assert isinstance(pkg.C, Constant) assert pkg.C.value == 42 + 39 # Parse with overrides - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { A: constant = 42 B: constant = 39 C: constant = A + B } """, - constant_overrides={ - "A": 123, - "B": 456, - }, + constant_overrides={ + "A": 123, + "B": 456, + }, + ) ) assert len(pkg._PT_FIELDS) == 3 # A @@ -127,13 +137,15 @@ def test_parse_constant_override_unknown(): UnknownEntityError, match="Constant override 'UNKNOWN' does not match any defined constant", ): - parse_string( - """ + next( + parse_string( + """ package the_package { A: constant = 42 } """, - constant_overrides={"UNKNOWN": 123}, + constant_overrides={"UNKNOWN": 123}, + ) ) @@ -143,29 +155,33 @@ def test_parse_constant_override_type_mismatch(): TypeError, match=( "Constant override 'b' does not match a constant in package " - "'the_package', found Scalar_42U_0" + "'the_package', found ScalarType_42U_0" ), ): - parse_string( - """ + next( + parse_string( + """ package the_package { A: constant = 42 b: scalar[A] } """, - constant_overrides={"b": 123}, + constant_overrides={"b": 123}, + ) ) def test_parse_constant_no_value(): """Test parsing a constant definition without a value.""" with pytest.raises(ParseError, match="Failed to parse input"): - parse_string( - """ + next( + parse_string( + """ package the_package { A: CONSTANT[12] } """ + ) ) @@ -174,19 +190,22 @@ def test_parse_constant_bad_reference(): with pytest.raises( UnknownEntityError, match="Failed to resolve 'NON_EXISTENT' to a known constant" ): - parse_string( - """ + next( + parse_string( + """ package the_package { A: CONSTANT[12] = NON_EXISTENT + 1 } """ + ) ) def test_parse_constant_expression(): """Check that a complex expression is evaluated correctly""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { A: Constant = 32 B: Constant = 9 @@ -196,5 +215,6 @@ def test_parse_constant_expression(): F: Constant = ((A * B) ** C) / D + E } """ + ) ) assert int(pkg.F) == (32 * 9) ** 2 // -4 + 43 diff --git a/tests/grammar/test_enum.py b/tests/grammar/test_enum.py index 8eb3c76..70b65ca 100644 --- a/tests/grammar/test_enum.py +++ b/tests/grammar/test_enum.py @@ -16,8 +16,9 @@ def test_parse_enum(): """Test parsing an enum definition.""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { // Default behaviours (implicit width, indexed) enum a { @@ -84,6 +85,7 @@ def test_parse_enum(): } } """ + ) ) assert len(pkg._PT_FIELDS) == 9 # a @@ -135,8 +137,9 @@ def test_parse_enum(): def test_parse_enum_description(): """Test parsing an enum definition with a description.""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { // Default behaviours (implicit width, indexed) enum a { @@ -148,6 +151,7 @@ def test_parse_enum_description(): } } """ + ) ) assert len(pkg._PT_FIELDS) == 1 assert issubclass(pkg.a, Enum) @@ -162,8 +166,9 @@ def test_parse_enum_description(): def test_parse_enum_modifiers(): """Test parsing an enum definition with modifiers.""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { // Default behaviours (implicit width, indexed) enum a { @@ -176,6 +181,7 @@ def test_parse_enum_modifiers(): } } """ + ) ) assert len(pkg._PT_FIELDS) == 1 assert issubclass(pkg.a, Enum) @@ -191,8 +197,9 @@ def test_parse_enum_modifiers(): def test_parse_enum_descriptions(): """Test parsing an enum definition with descriptions.""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { // Default behaviours (implicit width, indexed) enum a { @@ -208,6 +215,7 @@ def test_parse_enum_descriptions(): } } """ + ) ) assert len(pkg._PT_FIELDS) == 1 assert issubclass(pkg.a, Enum) @@ -227,14 +235,16 @@ def test_parse_enum_descriptions(): def test_parse_enum_bad_field(): """Test parsing an enum definition with a bad field.""" with pytest.raises(ParseError, match="Failed to parse input"): - parse_string( - """ + next( + parse_string( + """ package the_package { enum a { A : scalar[4] } } """ + ) ) @@ -243,8 +253,9 @@ def test_parse_enum_bad_width(): with pytest.raises( EnumError, match="Enum entry E has value 4 that cannot be encoded in a bit width of 2" ): - parse_string( - """ + next( + parse_string( + """ package the_package { enum [2] a { A @@ -255,14 +266,16 @@ def test_parse_enum_bad_width(): } } """ + ) ) def test_parse_enum_bad_modifier(): """Test parsing an enum where an unrecognised modifier is used.""" with pytest.raises(BadAttributeError, match="Unsupported attribute 'blargh' for Enum"): - parse_string( - """ + next( + parse_string( + """ package the_package { enum [2] a { @blargh=123 @@ -273,4 +286,5 @@ def test_parse_enum_bad_modifier(): } } """ + ) ) diff --git a/tests/grammar/test_import.py b/tests/grammar/test_import.py index 4661dfb..bdf20f6 100644 --- a/tests/grammar/test_import.py +++ b/tests/grammar/test_import.py @@ -4,8 +4,9 @@ import pytest -from packtype import Constant, Scalar +from packtype import Constant from packtype.grammar import ParseError, parse_string +from packtype.types.scalar import ScalarType from packtype.types.struct import Struct from packtype.utils import get_width @@ -17,17 +18,20 @@ def test_parse_import(): """Test parsing an import statement""" # First package - pkg_a = parse_string( - """ + pkg_a = next( + parse_string( + """ package pkg_a { A: constant = 42 a_sclr: scalar[A] } """ + ) ) # Second package with import - pkg_b = parse_string( - """ + pkg_b = next( + parse_string( + """ package another_package { import pkg_a::A import pkg_a::a_sclr @@ -38,7 +42,8 @@ def test_parse_import(): } } """, - namespaces={"pkg_a": pkg_a}, + namespaces={"pkg_a": pkg_a}, + ) ) # Package A @@ -48,7 +53,7 @@ def test_parse_import(): assert pkg_a.A.value == 42 assert get_width(pkg_a.A) == -1 - assert issubclass(pkg_a.a_sclr, Scalar) + assert issubclass(pkg_a.a_sclr, ScalarType) assert get_width(pkg_a.a_sclr) == 42 # Package B @@ -65,32 +70,38 @@ def test_parse_import(): def test_parse_import_bad(): """Test parsing an import statement with a bad import statement""" with pytest.raises(ParseError, match="Failed to parse input"): - parse_string( - """ + next( + parse_string( + """ package the_package { import pkg_a:: } """ + ) ) def test_parse_import_unknown(): """Test parsing an import statement with an unknwown package""" with pytest.raises(ImportError, match="Unknown package 'pkg_a'"): - parse_string( - """ + next( + parse_string( + """ package the_package { import pkg_a::B } """ + ) ) with pytest.raises(ImportError, match="'B' not declared in package 'pkg_a'"): - pkg_a = parse_string(r"package pkg_a {}") - parse_string( - """ + pkg_a = next(parse_string(r"package pkg_a {}")) + next( + parse_string( + """ package the_package { import pkg_a::B } """, - namespaces={"pkg_a": pkg_a}, + namespaces={"pkg_a": pkg_a}, + ) ) diff --git a/tests/grammar/test_package.py b/tests/grammar/test_package.py index d34c6bb..153ce16 100644 --- a/tests/grammar/test_package.py +++ b/tests/grammar/test_package.py @@ -13,12 +13,14 @@ def test_parse_package(): """Test parsing a package definition.""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { "This describes the package" } """ + ) ) assert pkg.__name__ == "the_package" assert pkg._pt_name() == "the_package" @@ -29,31 +31,37 @@ def test_parse_package(): def test_parse_package_unclosed(): """Test parsing a package definition that is not closed.""" with pytest.raises(ParseError, match="Failed to parse input"): - parse_string( - """ + next( + parse_string( + """ package the_package { "This describes the package" """ + ) ) def test_parse_package_collision(): """Check that multiple definitions within a package of the same name raises an error.""" with pytest.raises(RedefinitionError, match="'the_name' is already defined as a Scalar"): - parse_string( - """ + next( + parse_string( + """ package the_package { the_name : scalar[3] the_name : constant = 42 } """ + ) ) with pytest.raises(RedefinitionError, match="'the_name' is already defined as a Constant"): - parse_string( - """ + next( + parse_string( + """ package the_package { the_name : constant = 42 the_name : scalar[3] } """ + ) ) diff --git a/tests/grammar/test_scalar.py b/tests/grammar/test_scalar.py index b3f4ba6..93d2398 100644 --- a/tests/grammar/test_scalar.py +++ b/tests/grammar/test_scalar.py @@ -4,8 +4,8 @@ import pytest -from packtype import Scalar from packtype.grammar import ParseError, parse_string +from packtype.types.scalar import ScalarType from packtype.utils import get_width from ..fixtures import reset_registry @@ -15,8 +15,9 @@ def test_parse_scalar(): """Test parsing a scalar definition within a package""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { single_bit: scalar "Single bit scalar" @@ -26,18 +27,19 @@ def test_parse_scalar(): "Declarations are case insensitive" } """ + ) ) assert len(pkg._PT_FIELDS) == 3 # single_bit - assert issubclass(pkg.single_bit, Scalar) + assert issubclass(pkg.single_bit, ScalarType) assert get_width(pkg.single_bit) == 1 assert pkg.single_bit.__doc__ == "Single bit scalar" # multi_bit_8 - assert issubclass(pkg.multi_bit_8, Scalar) + assert issubclass(pkg.multi_bit_8, ScalarType) assert get_width(pkg.multi_bit_8) == 8 assert pkg.multi_bit_8.__doc__ == "Multi-bit scalar" # multi_bit_12 - assert issubclass(pkg.multi_bit_12, Scalar) + assert issubclass(pkg.multi_bit_12, ScalarType) assert get_width(pkg.multi_bit_12) == 12 assert pkg.multi_bit_12.__doc__ == "Declarations are case insensitive" @@ -45,10 +47,12 @@ def test_parse_scalar(): def test_parse_scalar_bad_assign(): """Test parsing a scalar definition with an invalid assignment.""" with pytest.raises(ParseError, match="Failed to parse input"): - parse_string( - """ + next( + parse_string( + """ package the_package { A: scalar[8] = 42 } """ + ) ) diff --git a/tests/grammar/test_struct.py b/tests/grammar/test_struct.py index c3ce894..1dc7748 100644 --- a/tests/grammar/test_struct.py +++ b/tests/grammar/test_struct.py @@ -16,8 +16,9 @@ def test_parse_struct(): """Test parsing a struct definition.""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { // Implicit width, implicitly packed from LSB struct a { @@ -69,6 +70,7 @@ def test_parse_struct(): } } """ + ) ) assert len(pkg._PT_FIELDS) == 8 # a @@ -115,8 +117,9 @@ def test_parse_struct(): def test_parse_struct_description(): """Check that a struct can have a description""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { simple_type_t : scalar[3] @@ -129,6 +132,7 @@ def test_parse_struct_description(): } } """ + ) ) assert len(pkg._PT_FIELDS) == 2 assert issubclass(pkg.simple_struct, Struct) @@ -138,8 +142,9 @@ def test_parse_struct_description(): def test_parse_struct_reference(): """Check that a struct can reference other known types""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { single_bit: scalar multi_bit: scalar[8] @@ -154,6 +159,7 @@ def test_parse_struct_reference(): } } """ + ) ) assert len(pkg._PT_FIELDS) == 4 assert issubclass(pkg.compound_struct, Struct) @@ -172,8 +178,9 @@ def test_parse_struct_oversized(): match="Fields of oversized_struct total 56 bits which does not fit " "within the specified width of 8 bits", ): - parse_string( - """ + next( + parse_string( + """ package the_package { struct [8] oversized_struct { a: scalar[8] @@ -182,14 +189,16 @@ def test_parse_struct_oversized(): } } """ + ) ) def test_parse_struct_bad_decl(): """Check that an error is raised if packing order and width are mixed up""" with pytest.raises(ParseError, match="Failed to parse"): - parse_string( - """ + next( + parse_string( + """ package the_package { struct [60] msb bad_struct { a: scalar[8] @@ -198,6 +207,7 @@ def test_parse_struct_bad_decl(): } } """ + ) ) @@ -206,12 +216,14 @@ def test_parse_struct_bad_field_ref(): with pytest.raises( UnknownEntityError, match="Failed to resolve 'non_existent' to a known constant or type" ): - parse_string( - """ + next( + parse_string( + """ package the_package { struct bad_struct { a: non_existent } } """ + ) ) diff --git a/tests/grammar/test_union.py b/tests/grammar/test_union.py index 10809bd..b05658d 100644 --- a/tests/grammar/test_union.py +++ b/tests/grammar/test_union.py @@ -15,8 +15,9 @@ def test_parse_union(): """Parse simple union definitions.""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { union with_descr { "This is a simple union" @@ -29,6 +30,7 @@ def test_parse_union(): } } """ + ) ) assert len(pkg._PT_FIELDS) == 2 # with_descr @@ -43,8 +45,9 @@ def test_parse_union(): def test_parse_union_complex(): """Check that unions can refer to other types.""" - pkg = parse_string( - """ + pkg = next( + parse_string( + """ package the_package { a_scalar: scalar[8] @@ -71,6 +74,7 @@ def test_parse_union_complex(): } } """ + ) ) assert len(pkg._PT_FIELDS) == 4 assert issubclass(pkg.complex, Union) @@ -87,8 +91,9 @@ def test_parse_union_mismatched_sizes(): UnionError, match="Union member b has a width of 4 that differs from the expected width of 2", ): - parse_string( - """ + next( + parse_string( + """ package the_package { union mismatched { a: scalar[2] @@ -96,6 +101,7 @@ def test_parse_union_mismatched_sizes(): } } """ + ) ) @@ -104,12 +110,14 @@ def test_parse_union_bad_field_ref(): with pytest.raises( UnknownEntityError, match="Failed to resolve 'non_existent' to a known constant or type" ): - parse_string( - """ + next( + parse_string( + """ package the_package { union bad_union { a: non_existent } } """ + ) ) diff --git a/tests/pysyntax/test_array.py b/tests/pysyntax/test_array.py index 4b3eedf..49fc691 100644 --- a/tests/pysyntax/test_array.py +++ b/tests/pysyntax/test_array.py @@ -2,6 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # +import itertools +from random import choice, getrandbits + import packtype from packtype import Constant, Packing, Scalar @@ -18,8 +21,8 @@ class TestPkg: @TestPkg.struct() class TestStruct: ab: Scalar[12] - cd: 3 * Scalar[3] - ef: TestPkg.EF_NUM * Scalar[9] + cd: Scalar[3][3] + ef: Scalar[9][TestPkg.EF_NUM] inst = TestStruct() assert inst._pt_width == 12 + (3 * 3) + (2 * 9) @@ -39,7 +42,7 @@ class TestPkg: @TestPkg.struct() class TestStruct: ab: Scalar[12] - cd: 3 * Scalar[3] + cd: Scalar[3][3] ef: Scalar[9] inst = TestStruct() @@ -60,7 +63,7 @@ class TestPkg: @TestPkg.struct(packing=Packing.FROM_MSB) class TestStruct: ab: Scalar[12] - cd: 3 * Scalar[3] + cd: Scalar[3][3] ef: Scalar[9] inst = TestStruct() @@ -81,7 +84,7 @@ class TestPkg: @TestPkg.struct() class TestStruct: ab: Scalar[12] - cd: 3 * Scalar[3] + cd: Scalar[3][3] ef: Scalar[9] inst = TestStruct._pt_unpack((53 << 21) | (3 << 18) | (2 << 15) | (1 << 12) | 123) @@ -100,7 +103,7 @@ class TestPkg: @TestPkg.struct(packing=Packing.FROM_MSB) class TestStruct: ab: Scalar[12] - cd: 3 * Scalar[3] + cd: Scalar[3][3] ef: Scalar[9] inst = TestStruct._pt_unpack((123 << 18) | (1 << 15) | (2 << 12) | (3 << 9) | 53) @@ -109,3 +112,180 @@ class TestStruct: assert int(inst.cd[1]) == 2 assert int(inst.cd[2]) == 3 assert int(inst.ef) == 53 + + +def test_array_multidimensional_scalar(): + """Basic test that a multi-dimensional scalar value can be declared""" + + @packtype.package() + class TestPkg: + # This will declare a Scalar[4] with dimensions 5x6x7 + multi: Scalar[4][5][6][7] + + inst = TestPkg.multi() + # Check size and dimensions + assert inst._pt_width == 4 * 5 * 6 * 7 + assert len(inst) == 7 + assert len(inst[0]) == 6 + assert len(inst[0][0]) == 5 + # Write in data + ref = {} + raw = 0 + for x, y, z in itertools.product(range(7), range(6), range(5)): + ref[x, y, z] = getrandbits(4) + raw |= ref[x, y, z] << ((x * 6 * 5 * 4) + (y * 5 * 4) + (z * 4)) + inst[x][y][z] = ref[x, y, z] + # Check persistance + for x, y, z in itertools.product(range(7), range(6), range(5)): + assert inst[x][y][z] == ref[x, y, z] + # Check overall value + assert int(inst) == raw + + +def test_array_multidimensional_rich(): + """Test that multi-dimensional structs, enums, and unions can be declared""" + + @packtype.package() + class Pkg1D: + pass + + @Pkg1D.struct() + class Struct1D: + field_a: Scalar[1] + field_b: Scalar[2] + + @Pkg1D.enum() + class Enum1D: + VAL_A: Constant + VAL_B: Constant + VAL_C: Constant + + @Pkg1D.union() + class Union1D: + raw: Scalar[3] + struct: Struct1D + + @packtype.package() + class Pkg2D: + Struct2D: Struct1D[4] + Enum2D: Enum1D[5] + Union2D: Union1D[6] + + @packtype.package() + class Pkg3D: + Struct3D: Pkg2D.Struct2D[2] + Enum3D: Pkg2D.Enum2D[3] + Union3D: Pkg2D.Union2D[4] + + # === Check struct === + inst_struct = Pkg3D.Struct3D() + assert inst_struct._pt_width == (1 + 2) * 4 * 2 + assert len(inst_struct) == 2 + assert len(inst_struct[0]) == 4 + + # Write in data + ref = {} + raw = 0 + for x, y in itertools.product(range(2), range(4)): + ref[x, y] = (a := getrandbits(1)), (b := getrandbits(2)) + raw |= (a | (b << 1)) << ((x * 4 * 3) + (y * 3)) + inst_struct[x][y].field_a = a + inst_struct[x][y].field_b = b + + # Check persistance + for x, y in itertools.product(range(2), range(4)): + assert inst_struct[x][y].field_a == ref[x, y][0] + assert inst_struct[x][y].field_b == ref[x, y][1] + + # Check overall value + assert int(inst_struct) == raw + + # === Check enum === + inst_enum = Pkg3D.Enum3D() + assert inst_enum._pt_width == 2 * 5 * 3 + assert len(inst_enum) == 3 + assert len(inst_enum[0]) == 5 + + # Write in data + ref = {} + raw = 0 + for x, y in itertools.product(range(3), range(5)): + ref[x, y] = choice((Enum1D.VAL_A, Enum1D.VAL_B, Enum1D.VAL_C)) + raw |= ref[x, y] << ((x * 5 * 2) + (y * 2)) + inst_enum[x][y] = ref[x, y] + + # Check persistance + for x, y in itertools.product(range(3), range(5)): + assert inst_enum[x][y] == ref[x, y] + + # Check overall value + assert int(inst_enum) == raw + + # === Check union === + inst_union = Pkg3D.Union3D() + assert inst_union._pt_width == 3 * 6 * 4 + assert len(inst_union) == 4 + assert len(inst_union[0]) == 6 + + # Write in data + ref = {} + raw = 0 + for x, y in itertools.product(range(4), range(6)): + ref[x, y] = getrandbits(3) + raw |= ref[x, y] << ((x * 6 * 3) + (y * 3)) + inst_union[x][y].raw = ref[x, y] + + # Check persistance + for x, y in itertools.product(range(4), range(6)): + assert inst_union[x][y].raw == ref[x, y] + assert inst_union[x][y].struct == ref[x, y] + + # Check overall value + assert int(inst_union) == raw + + +def test_array_multidimensional_struct_field(): + """Test that structs can have multi-dimensional fields""" + + @packtype.package() + class TestPkg: + Scalar3D: Scalar[2][3][4] + + @TestPkg.struct() + class TestStruct: + field_a: TestPkg.Scalar3D + field_b: Scalar[3][4][5] + + inst = TestStruct() + assert inst._pt_width == (2 * 3 * 4) + (3 * 4 * 5) + inst.field_a = (data_a := getrandbits(2 * 3 * 4)) + inst.field_b = (data_b := getrandbits(3 * 4 * 5)) + assert int(inst.field_a) == data_a + assert int(inst.field_b) == data_b + assert int(inst) == data_a | (data_b << (2 * 3 * 4)) + for x, y in itertools.product(range(4), range(3)): + assert inst.field_a[x][y] == (data_a >> ((x * 3 * 2) + (y * 2))) & 0b11 + for x, y in itertools.product(range(5), range(4)): + assert inst.field_b[x][y] == (data_b >> ((x * 4 * 3) + (y * 3))) & 0b111 + + +def test_array_multidimensional_union_member(): + """Test that unions can have multi-dimensional field members""" + + @packtype.package() + class TestPkg: + Scalar3D: Scalar[2][3][4] + + @TestPkg.union() + class TestUnion: + member_a: TestPkg.Scalar3D + member_b: Scalar[2 * 3 * 4] + + inst = TestUnion() + assert inst._pt_width == 2 * 3 * 4 + inst.member_a = (data_a := getrandbits(2 * 3 * 4)) + assert int(inst.member_a) == data_a + assert int(inst.member_b) == data_a + assert int(inst) == data_a + for x, y in itertools.product(range(4), range(3)): + assert inst.member_a[x][y] == (data_a >> ((x * 3 * 2) + (y * 2))) & 0b11 diff --git a/tests/pysyntax/test_package.py b/tests/pysyntax/test_package.py index f34ee3f..662465e 100644 --- a/tests/pysyntax/test_package.py +++ b/tests/pysyntax/test_package.py @@ -82,7 +82,7 @@ class OuterPkg: @OuterPkg.struct() class OuterStruct: ref_td: InnerPkg.InnerType - ref_st: 2 * InnerStruct + ref_st: InnerStruct[2] assert OuterStruct assert InnerPkg._pt_foreign() == set() diff --git a/tests/pysyntax/test_struct.py b/tests/pysyntax/test_struct.py index df9cff8..35aba2e 100644 --- a/tests/pysyntax/test_struct.py +++ b/tests/pysyntax/test_struct.py @@ -323,7 +323,7 @@ class InnerUnion: @TestPkg.struct() class TestStruct: ab: Scalar[12] - cd: 3 * Scalar[3] + cd: Scalar[3][3] ef: Scalar[9] gh: InnerUnion @@ -350,7 +350,7 @@ class TestPkg: @TestPkg.struct() class TestStruct: ab: Scalar[12] - cd: 3 * Scalar[3] + cd: Scalar[3][3] ef: Scalar[9] # Test a single value being assigned to an array diff --git a/tests/utils/test_utils_basic.py b/tests/utils/test_utils_basic.py index cb5b730..2ca9e51 100644 --- a/tests/utils/test_utils_basic.py +++ b/tests/utils/test_utils_basic.py @@ -8,6 +8,7 @@ import packtype from packtype import Constant, Scalar, utils +from packtype.types.scalar import ScalarType from ..fixtures import reset_registry @@ -76,7 +77,7 @@ class TestUnion: # Unpack a scalar inst_sc = utils.unpack(TestPkg.sc_unsigned, 123) - assert isinstance(inst_sc, Scalar) + assert isinstance(inst_sc, ScalarType) assert inst_sc == 123 assert isinstance(utils.pack(inst_sc), int) assert utils.pack(inst_sc) == 123 diff --git a/tests/utils/test_utils_constant.py b/tests/utils/test_utils_constant.py index 8a947b0..d4d059e 100644 --- a/tests/utils/test_utils_constant.py +++ b/tests/utils/test_utils_constant.py @@ -11,15 +11,17 @@ def test_utils_enum_get_entries(): - PackageA = parse_string( # noqa: N806 - """ + PackageA = next( # noqa: N806 + parse_string( + """ package PackageA { A: constant = 1 B: constant = 2 C: constant = A + B } """, - keep_expression=True, + keep_expression=True, + ) ) assert utils.constant.get_expression(PackageA.C) is not None assert utils.constant.get_expression(PackageA.C).evaluate({"A": 3, "B": 4}.get) == 3 + 4