Skip to content

Operation

Operation dataclass

Bases: IRNode

A generic operation. Operation definitions inherit this class.

Source code in xdsl/ir/core.py (lines 754–1315)
@dataclass
class Operation(IRNode):
    """
    A generic operation. Operation definitions inherit this class.

    An operation holds its operands, results, successors, properties,
    attributes and regions. Operations inside a block are chained through a
    doubly-linked list (`_prev_op` / `_next_op`).
    """

    name: ClassVar[str] = field(repr=False)
    """The operation name. Should be a static member of the class."""

    _operands: tuple[SSAValue, ...] = field(default=())
    """The operation operands."""

    results: tuple[OpResult, ...] = field(default=())
    """The results created by the operation."""

    _successors: tuple[Block, ...] = field(default=())
    """
    The basic blocks that the operation may give control to.
    This list should be empty for non-terminator operations.
    """

    properties: dict[str, Attribute] = field(default_factory=dict[str, Attribute])
    """
    The properties attached to the operation.
    Properties are inherent to the definition of an operation's semantics, and
    thus cannot be discarded by transformations.
    """

    attributes: dict[str, Attribute] = field(default_factory=dict[str, Attribute])
    """The attributes attached to the operation."""

    regions: tuple[Region, ...] = field(default=())
    """The regions of the operation."""

    parent: Block | None = field(default=None, repr=False)
    """The block containing this operation."""

    _next_op: Operation | None = field(default=None, repr=False)
    """Next operation in block containing this operation."""

    _prev_op: Operation | None = field(default=None, repr=False)
    """Previous operation in block containing this operation."""

    traits: ClassVar[OpTraits]
    """
    Traits attached to an operation definition.
    This is a static field, and is made empty by default by PyRDL if not set
    by the operation definition.
    """

    @property
    def parent_node(self) -> IRNode | None:
        """The IR node directly containing this operation (its parent block)."""
        return self.parent

    @property
    def result_types(self) -> Sequence[Attribute]:
        """The types of the operation results, in result order."""
        return tuple(r.type for r in self.results)

    @property
    def operand_types(self) -> Sequence[Attribute]:
        """The types of the operation operands, in operand order."""
        return tuple(operand.type for operand in self.operands)

    def parent_op(self) -> Operation | None:
        """Return the operation owning the region that contains this op, if any."""
        # Compare against None explicitly (like `parent_region`) rather than
        # relying on the truthiness of the Region object.
        if (p := self.parent_region()) is not None:
            return p.parent
        return None

    def parent_region(self) -> Region | None:
        """Return the region containing this operation's parent block, if any."""
        if (p := self.parent_block()) is not None:
            return p.parent
        return None

    def parent_block(self) -> Block | None:
        """Return the block containing this operation, if any."""
        return self.parent

    @property
    def next_op(self) -> Operation | None:
        """
        Next operation in block containing this operation.
        """
        return self._next_op

    def _insert_next_op(self, new_op: Operation) -> None:
        """
        Link `new_op` directly after `self`.

        Sets `next_op` on `self`, and `prev_op` on the former `self.next_op`.
        Only the linked-list pointers are updated; `new_op.parent` is not set.
        """

        if self._next_op is not None:
            # update next node
            self._next_op._prev_op = new_op

        # set next and previous on new node
        new_op._prev_op = self
        new_op._next_op = self._next_op

        # update self
        self._next_op = new_op

    @property
    def prev_op(self) -> Operation | None:
        """
        Previous operation in block containing this operation.
        """
        return self._prev_op

    def _insert_prev_op(self, new_op: Operation) -> None:
        """
        Link `new_op` directly before `self`.

        Sets `prev_op` on `self`, and `next_op` on the former `self.prev_op`.
        Only the linked-list pointers are updated; `new_op.parent` is not set.
        """

        if self._prev_op is not None:
            # update prev node
            self._prev_op._next_op = new_op

        # set next and previous on new node
        new_op._prev_op = self._prev_op
        new_op._next_op = self

        # update self
        self._prev_op = new_op

    @property
    def operands(self) -> OpOperands:
        """A mutable view of the operands; writes through it update use lists."""
        return OpOperands(self)

    @operands.setter
    def operands(self, new: Sequence[SSAValue]):
        """Replace all operands, unregistering old uses and registering new ones."""
        new = tuple(new)
        for idx, operand in enumerate(self._operands):
            operand.remove_use(Use(self, idx))
        for idx, operand in enumerate(new):
            operand.add_use(Use(self, idx))
        self._operands = new

    @property
    def successors(self) -> OpSuccessors:
        """A view of the successor blocks of the operation."""
        return OpSuccessors(self)

    @successors.setter
    def successors(self, new: Sequence[Block]):
        """Replace all successors, unregistering old uses and registering new ones."""
        new = tuple(new)
        for idx, successor in enumerate(self._successors):
            successor.remove_use(Use(self, idx))
        for idx, successor in enumerate(new):
            successor.add_use(Use(self, idx))
        self._successors = new

    def __post_init__(self):
        """Sanity-check that the operation class defines a non-empty name."""
        assert self.name != ""
        assert isinstance(self.name, str)

    def __init__(
        self,
        *,
        operands: Sequence[SSAValue] = (),
        result_types: Sequence[Attribute] = (),
        properties: Mapping[str, Attribute] = {},
        attributes: Mapping[str, Attribute] = {},
        successors: Sequence[Block] = (),
        regions: Sequence[Region] = (),
    ) -> None:
        """
        Build an operation, creating one result per entry of `result_types`.

        `properties` and `attributes` are copied into fresh dicts, so the
        (shared) default mappings are never mutated. Each region is attached
        through `add_region`, which rejects regions that already have a parent.
        """
        super().__init__()

        # The operands setter reads `self._operands` to unregister old uses; on
        # a fresh instance this resolves to the dataclass class-level default ().
        self.operands = operands

        self.results = tuple(
            OpResult(result_type, self, idx)
            for (idx, result_type) in enumerate(result_types)
        )
        self.properties = dict(properties)
        self.attributes = dict(attributes)
        self.successors = list(successors)
        self.regions = ()
        for region in regions:
            self.add_region(region)

        self.__post_init__()

    @classmethod
    def create(
        cls: type[Self],
        *,
        operands: Sequence[SSAValue] = (),
        result_types: Sequence[Attribute] = (),
        properties: Mapping[str, Attribute] = {},
        attributes: Mapping[str, Attribute] = {},
        successors: Sequence[Block] = (),
        regions: Sequence[Region] = (),
    ) -> Self:
        """
        Create an instance of `cls` through the generic `Operation.__init__`,
        bypassing any `__init__` defined by the subclass.
        """
        op = cls.__new__(cls)
        Operation.__init__(
            op,
            operands=operands,
            result_types=result_types,
            properties=properties,
            attributes=attributes,
            successors=successors,
            regions=regions,
        )
        return op

    def add_region(self, region: Region) -> None:
        """Add an unattached region to the operation."""
        if region.parent:
            raise Exception(
                "Cannot add region that is already attached on an operation."
            )
        self.regions += (region,)
        region.parent = self

    def get_region_index(self, region: Region) -> int:
        """Get the region position in the operation."""
        if region.parent is not self:
            raise Exception("Region is not attached to the operation.")
        return next(
            idx for idx, curr_region in enumerate(self.regions) if curr_region is region
        )

    def detach_region(self, region: int | Region) -> Region:
        """
        Detach a region, given either as a Region or as its index, from the
        operation.
        Returns the detached region.
        """
        if isinstance(region, Region):
            region_idx = self.get_region_index(region)
        else:
            region_idx = region
            region = self.regions[region_idx]
        region.parent = None
        self.regions = self.regions[:region_idx] + self.regions[region_idx + 1 :]
        return region

    def drop_all_references(self) -> None:
        """
        Drop all references to other operations.
        This function is called prior to deleting an operation.

        Clears the parent link, unregisters this operation from the use lists of
        its operands and of its successor blocks, and recursively drops the
        references of the nested regions.
        """
        self.parent = None
        for idx, operand in enumerate(self.operands):
            operand.remove_use(Use(self, idx))
        # Successor uses are registered by the `successors` setter; they must be
        # unregistered too, otherwise the target blocks keep referring to an
        # operation that is being deleted.
        for idx, successor in enumerate(self.successors):
            successor.remove_use(Use(self, idx))
        for region in self.regions:
            region.drop_all_references()

    def walk(
        self, *, reverse: bool = False, region_first: bool = False
    ) -> Iterator[Operation]:
        """
        Iterate all operations contained in the operation (including this one).
        If region_first is set, then the operation regions are iterated before the
        operation. If reverse is set, then the region, block, and operation lists are
        iterated in reverse order.
        """
        if not region_first:
            yield self
        for region in reversed(self.regions) if reverse else self.regions:
            yield from region.walk(reverse=reverse, region_first=region_first)
        if region_first:
            yield self

    def walk_blocks(self, *, reverse: bool = False) -> Iterator[Block]:
        """
        Iterate over all the blocks nested in the operation's regions.
        Iterate in reverse order if reverse is True.
        """
        for region in reversed(self.regions) if reverse else self.regions:
            for block in reversed(region.blocks) if reverse else region.blocks:
                yield from block.walk_blocks(reverse=reverse)

    def get_attr_or_prop(self, name: str) -> Attribute | None:
        """
        Get a named attribute or property.
        It first looks into the property dictionary, then into the attribute
        dictionary, and returns None if the name is in neither.
        """
        if name in self.properties:
            return self.properties[name]
        if name in self.attributes:
            return self.attributes[name]
        return None

    def verify(self, verify_nested_ops: bool = True) -> None:
        """
        Verify the operation.

        Checks that no operand is an erased SSA value, that operations with
        successors terminate their block and only branch within their region,
        and that block terminators carry the IsTerminator trait where required.
        Nested regions are verified when `verify_nested_ops` is set, then the
        custom verifier `verify_` runs; its VerifyException is re-raised through
        `emit_error`.
        """
        for operand in self.operands:
            if isinstance(operand, ErasedSSAValue):
                raise Exception("Erased SSA value is used by the operation")

        parent_block = self.parent
        parent_region = None if parent_block is None else parent_block.parent

        if self.successors:
            if parent_block is None or parent_region is None:
                raise VerifyException(
                    f"Operation {self.name} with block successors does not belong to a block or a region"
                )

            if parent_block.last_op is not self:
                raise VerifyException(
                    f"Operation {self.name} with block successors must terminate its parent block"
                )

            for succ in self.successors:
                if succ.parent != parent_block.parent:
                    raise VerifyException(
                        f"Operation {self.name} is branching to a block of a different region"
                    )

        if parent_block is not None and parent_region is not None:
            if parent_block.last_op == self:
                if len(parent_region.blocks) == 1:
                    # Single-block regions only require a terminator when the
                    # enclosing op does not declare NoTerminator.
                    if (
                        parent_op := parent_region.parent
                    ) is not None and not parent_op.has_trait(NoTerminator):
                        if not self.has_trait(IsTerminator):
                            raise VerifyException(
                                f"Operation {self.name} terminates block in "
                                "single-block region but is not a terminator"
                            )
                elif len(parent_region.blocks) > 1:
                    if not self.has_trait(IsTerminator):
                        raise VerifyException(
                            f"Operation {self.name} terminates block in multi-block "
                            "region but is not a terminator"
                        )

        if verify_nested_ops:
            for region in self.regions:
                region.verify()

        # Custom verifier
        try:
            self.verify_()
        except VerifyException as err:
            self.emit_error(
                "Operation does not verify: " + str(err), underlying_error=err
            )

    def verify_(self) -> None:
        """Custom verification hook for operation definitions; no-op by default."""
        pass

    _OperationType = TypeVar("_OperationType", bound="Operation")

    @classmethod
    def parse(cls: type[_OperationType], parser: Parser) -> _OperationType:
        """
        Parse an operation of this type in its custom format.
        The default implementation reports, via `parser.raise_error`, that the
        operation has no custom format; subclasses with a custom syntax override it.
        """
        parser.raise_error(f"Operation {cls.name} does not have a custom format.")

    def print(self, printer: Printer):
        """Print the operation; delegates to `printer.print_op_with_default_format`."""
        return printer.print_op_with_default_format(self)

    def clone_without_regions(
        self: OpT,
        value_mapper: dict[SSAValue, SSAValue] | None = None,
        block_mapper: dict[Block, Block] | None = None,
        *,
        clone_name_hints: bool = True,
    ) -> OpT:
        """
        Clone an operation, with empty regions instead.

        Operands and successors found in `value_mapper` / `block_mapper` are
        replaced by their mapped values. Each original result is mapped to the
        corresponding cloned result in `value_mapper` (which is mutated), and
        name hints are copied when `clone_name_hints` is set.
        """
        if value_mapper is None:
            value_mapper = {}
        if block_mapper is None:
            block_mapper = {}
        operands = [
            (value_mapper[operand] if operand in value_mapper else operand)
            for operand in self._operands
        ]
        result_types = self.result_types
        attributes = self.attributes.copy()
        properties = self.properties.copy()
        successors = [
            (block_mapper[successor] if successor in block_mapper else successor)
            for successor in self._successors
        ]
        regions = [Region() for _ in self.regions]
        cloned_op = self.create(
            operands=operands,
            result_types=result_types,
            attributes=attributes,
            properties=properties,
            successors=successors,
            regions=regions,
        )
        for self_result, cloned_result in zip(
            self.results, cloned_op.results, strict=True
        ):
            value_mapper[self_result] = cloned_result
            if clone_name_hints:
                cloned_result.name_hint = self_result.name_hint
        return cloned_op

    def clone(
        self: OpT,
        value_mapper: dict[SSAValue, SSAValue] | None = None,
        block_mapper: dict[Block, Block] | None = None,
        *,
        clone_name_hints: bool = True,
    ) -> OpT:
        """
        Clone an operation with all its regions and operations in them.
        `value_mapper` and `block_mapper` are shared with the nested clones so
        that references inside the regions are remapped consistently.
        """
        if value_mapper is None:
            value_mapper = {}
        if block_mapper is None:
            block_mapper = {}
        op = self.clone_without_regions(
            value_mapper, block_mapper, clone_name_hints=clone_name_hints
        )
        for idx, region in enumerate(self.regions):
            region.clone_into(
                op.regions[idx],
                0,
                value_mapper,
                block_mapper,
                clone_name_hints=clone_name_hints,
            )
        return op

    @classmethod
    def has_trait(
        cls,
        trait: type[OpTrait] | OpTrait,
        *,
        value_if_unregistered: bool = True,
    ) -> bool:
        """
        Check if the operation implements a trait with the given parameters.
        If the operation is not registered, return value_if_unregistered instead.
        """

        from xdsl.dialects.builtin import UnregisteredOp

        if issubclass(cls, UnregisteredOp):
            return value_if_unregistered

        return cls.get_trait(trait) is not None

    @classmethod
    def get_trait(cls, trait: type[OpTraitInvT] | OpTraitInvT) -> OpTraitInvT | None:
        """
        Return a trait with the given type and parameters, if it exists.
        A trait type matches any trait instance of that type; a trait instance
        matches only equal trait instances.
        """
        if isinstance(trait, type):
            for t in cls.traits:
                if isinstance(t, cast(type[OpTraitInvT], trait)):
                    return t
        else:
            for t in cls.traits:
                if t == trait:
                    return cast(OpTraitInvT, t)
        return None

    @classmethod
    def get_traits_of_type(cls, trait_type: type[OpTraitInvT]) -> list[OpTraitInvT]:
        """
        Get all the traits of the given type satisfied by this operation.
        """
        return [t for t in cls.traits if isinstance(t, trait_type)]

    def erase(self, safe_erase: bool = True, drop_references: bool = True) -> None:
        """
        Erase the operation, and remove all its references to other operations.
        The operation must already be detached from its parent block.
        If safe_erase is specified, check that the operation results are not used.
        """
        assert self.parent is None, (
            "Operation with parents should first be detached " + "before erasure."
        )
        if drop_references:
            self.drop_all_references()
        for result in self.results:
            result.erase(safe_erase=safe_erase)

    def detach(self):
        """Detach the operation from its parent block."""
        if self.parent is None:
            raise Exception("Cannot detach a toplevel operation.")
        self.parent.detach_op(self)

    def is_structurally_equivalent(
        self,
        other: IRNode,
        context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None,
    ) -> bool:
        """
        Check if two operations are structurally equivalent.
        The context is a mapping of IR nodes to IR nodes that are already known
        to be equivalent. This enables checking whether the use dependencies and
        successors are equivalent. On success, the results of `self` are added
        to the context, mapped to the results of `other`.
        """
        if context is None:
            context = {}
        if not isinstance(other, Operation):
            return False
        if self.name != other.name:
            return False
        if (
            len(self.operands) != len(other.operands)
            or len(self.results) != len(other.results)
            or len(self.regions) != len(other.regions)
            or len(self.successors) != len(other.successors)
            or self.attributes != other.attributes
            or self.properties != other.properties
        ):
            return False
        if (
            self.parent is not None
            and other.parent is not None
            and context.get(self.parent) != other.parent
        ):
            return False
        if not all(
            context.get(operand, operand) == other_operand
            for operand, other_operand in zip(self.operands, other.operands)
        ):
            return False
        if not all(
            context.get(successor, successor) == other_successor
            for successor, other_successor in zip(self.successors, other.successors)
        ):
            return False
        if not all(
            region.is_structurally_equivalent(other_region, context)
            for region, other_region in zip(self.regions, other.regions)
        ):
            return False
        # Add results of this operation to the context
        for result, other_result in zip(self.results, other.results):
            context[result] = other_result

        return True

    def emit_error(
        self,
        message: str,
        exception_type: type[Exception] = VerifyException,
        underlying_error: Exception | None = None,
    ) -> NoReturn:
        """
        Emit an error with the given message.
        Builds a Diagnostic attached to this operation and raises an exception
        of `exception_type` through it; this method never returns.
        """
        from xdsl.utils.diagnostic import Diagnostic

        diagnostic = Diagnostic()
        diagnostic.add_message(self, message)
        diagnostic.raise_exception(message, self, exception_type, underlying_error)

    @classmethod
    def dialect_name(cls) -> str:
        """Return the dialect prefix of the operation name."""
        return Dialect.split_name(cls.name)[0]

    def __eq__(self, other: object) -> bool:
        # Operations compare by identity, not structure; see
        # `is_structurally_equivalent` for structural comparison.
        return self is other

    def __hash__(self) -> int:
        return id(self)

    def __str__(self) -> str:
        from xdsl.printer import Printer

        res = StringIO()
        printer = Printer(stream=res)
        printer.print_op(self)
        return res.getvalue()

    def __format__(self, __format_spec: str) -> str:
        desc = str(self)
        if "\n" in desc:
            # Description is multi-line, indent each line
            desc = "\n".join("\t" + line for line in desc.splitlines())
            # Add newline before and after
            desc = f"\n{desc}\n"
        return f"{self.__class__.__qualname__}({desc})"

name: str = field(repr=False) class-attribute

The operation name. Should be a static member of the class

operands: OpOperands = operands instance-attribute property writable

results: tuple[OpResult, ...] = tuple(OpResult(result_type, self, idx) for (idx, result_type) in enumerate(result_types)) class-attribute instance-attribute

The results created by the operation.

successors: OpSuccessors = list(successors) instance-attribute property writable

properties: dict[str, Attribute] = dict(properties) class-attribute instance-attribute

The properties attached to the operation. Properties are inherent to the definition of an operation's semantics, and thus cannot be discarded by transformations.

attributes: dict[str, Attribute] = dict(attributes) class-attribute instance-attribute

The attributes attached to the operation.

regions: tuple[Region, ...] = () class-attribute instance-attribute

The regions of the operation.

parent: Block | None = field(default=None, repr=False) class-attribute instance-attribute

The block containing this operation.

next_op: Operation | None property

Next operation in block containing this operation.

prev_op: Operation | None property

Previous operation in block containing this operation.

traits: OpTraits class-attribute

Traits attached to an operation definition. This is a static field, and is made empty by default by PyRDL if not set by the operation definition.

OpOperands dataclass

Bases: Sequence[SSAValue]

A view of the operand list of an operation. Any modification to the view is reflected on the operation.

Source code in xdsl/ir/core.py (lines 676–717)
@dataclass
class OpOperands(Sequence[SSAValue]):
    """
    A view of the operand list of an operation.
    Any modification to the view is reflected on the operation.
    """

    _op: Operation
    """The operation owning the operands."""

    @overload
    def __getitem__(self, idx: int) -> SSAValue: ...

    @overload
    def __getitem__(self, idx: slice) -> Sequence[SSAValue]: ...

    def __getitem__(self, idx: int | slice) -> SSAValue | Sequence[SSAValue]:
        """Return the operand at `idx`, or a tuple of operands for a slice."""
        return self._op._operands[idx]  # pyright: ignore[reportPrivateUsage]

    def __setitem__(self, idx: int, operand: SSAValue) -> None:
        """
        Replace the operand at position `idx`.

        The old operand loses its use by the operation and the new operand gains
        one. Negative indices count from the end, as for tuples; an index out of
        range raises IndexError.
        """
        operands = self._op._operands  # pyright: ignore[reportPrivateUsage]
        # Normalize to a non-negative position (raising IndexError when out of
        # range): uses are registered with non-negative indices, and the splice
        # below is only correct for a non-negative `idx`.
        idx = range(len(operands))[idx]
        operands[idx].remove_use(Use(self._op, idx))
        operand.add_use(Use(self._op, idx))
        new_operands = (*operands[:idx], operand, *operands[idx + 1 :])
        self._op._operands = new_operands  # pyright: ignore[reportPrivateUsage]

    def __iter__(self) -> Iterator[SSAValue]:
        """Iterate over the operands in order."""
        return iter(self._op._operands)  # pyright: ignore[reportPrivateUsage]

    def __len__(self) -> int:
        """Return the number of operands."""
        return len(self._op._operands)  # pyright: ignore[reportPrivateUsage]

    def __eq__(self, other: object):
        """Two operand views are equal when their operand tuples are equal."""
        if not isinstance(other, OpOperands):
            return False
        return (
            self._op._operands  # pyright: ignore[reportPrivateUsage]
            == other._op._operands  # pyright: ignore[reportPrivateUsage]
        )

    def __hash__(self):
        """Hash consistently with `__eq__`, from the operand tuple."""
        return hash(self._op._operands)  # pyright: ignore[reportPrivateUsage]

OpTraits

Bases: Iterable[OpTrait]

An operation's traits. Some operations have mutually recursive traits, such as when one operation is always the parent of the other. In this case, the operation's traits can be declared lazily and resolved only at first use.

Source code in xdsl/ir/core.py (lines 720–751)
class OpTraits(Iterable[OpTrait]):
    """
    An operation's traits.
    Some operations have mutually recursive traits, such as when one operation
    is always the parent of the other.
    In this case, the operation's traits can be declared lazily (as a callable
    returning them), and resolved only at the first use.
    """

    _traits: frozenset[OpTrait] | Callable[[], tuple[OpTrait, ...]]

    def __init__(
        self, traits: frozenset[OpTrait] | Callable[[], tuple[OpTrait, ...]]
    ) -> None:
        self._traits = traits

    @property
    def traits(self) -> frozenset[OpTrait]:
        """
        Return this instance's traits as an immutable frozenset.
        Lazily-declared traits are resolved (and cached) on first access.
        """
        if callable(self._traits):
            self._traits = frozenset(self._traits())
        return self._traits

    def add_trait(self, trait: OpTrait):
        """Adds a trait to the class, resolving lazily-declared traits first."""
        self._traits = self.traits.union((trait,))

    def __iter__(self) -> Iterator[OpTrait]:
        """Iterate over the (resolved) traits."""
        return iter(self.traits)

    def __eq__(self, value: object, /) -> bool:
        """
        Two trait sets are equal when they contain the same traits.
        Lazy declarations are resolved before comparing, so a lazy and an eager
        declaration of the same traits compare equal.
        """
        return isinstance(value, OpTraits) and self.traits == value.traits

traits: frozenset[OpTrait] property

Returns this instance's traits as an immutable frozenset, resolving them first if they were declared lazily.

add_trait(trait: OpTrait)

Adds a trait to the class.

Source code in xdsl/ir/core.py (lines 743–745)
def add_trait(self, trait: OpTrait):
    """Add `trait` to the trait set, resolving any lazily-declared traits first."""
    resolved = self.traits
    self._traits = resolved | frozenset((trait,))