MLIR 22.0.0git
MatchInterfaces.h
Go to the documentation of this file.
1//===- MatchInterfaces.h - Transform Dialect Interfaces ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
10#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
11
12#include <optional>
13#include <type_traits>
14
17#include "llvm/ADT/STLExtras.h"
18
19namespace mlir {
20namespace transform {
21class MatchOpInterface;
22
23namespace detail {
24/// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
25/// first operand.
26template <typename OpTy>
28 TransformResults &results,
29 TransformState &state) {
30 if constexpr (std::is_same_v<
31 typename llvm::function_traits<
32 decltype(&OpTy::matchOperation)>::template arg_t<0>,
33 Operation *>) {
34 return op.matchOperation(nullptr, results, state);
35 } else {
36 return op.matchOperation(std::nullopt, results, state);
37 }
38}
39} // namespace detail
40
41template <typename OpTy>
43 : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
44 template <typename T>
45 using has_get_operand_handle =
46 decltype(std::declval<T &>().getOperandHandle());
47 template <typename T>
48 using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
49 std::declval<Operation *>(), std::declval<TransformResults &>(),
50 std::declval<TransformState &>()));
51 template <typename T>
52 using has_match_operation_optional =
53 decltype(std::declval<T &>().matchOperation(
54 std::declval<std::optional<Operation *>>(),
55 std::declval<TransformResults &>(),
56 std::declval<TransformState &>()));
57
58public:
59 static LogicalResult verifyTrait(Operation *op) {
60 static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
61 "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
62 "operation type to have the getOperandHandle() method");
63 static_assert(
64 llvm::is_detected<has_match_operation_ptr, OpTy>::value ||
65 llvm::is_detected<has_match_operation_optional, OpTy>::value,
66 "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
67 "type to have either the matchOperation(Operation *, TransformResults "
68 "&, TransformState &) or the matchOperation(std::optional<Operation*>, "
69 "TransformResults &, TransformState &) method");
70
71 // This must be a dynamic assert because interface registration is dynamic.
72 assert(
73 isa<MatchOpInterface>(op) &&
74 "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
75 "operations with MatchOpInterface");
76 Value operandHandle = cast<OpTy>(op).getOperandHandle();
77 if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
78 return op->emitError() << "AtMostOneOpMatcherOpTrait/"
79 "SingleOpMatchOpTrait requires the op handle "
80 "to be of TransformHandleTypeInterface";
81 }
82
83 return success();
84 }
85
87 TransformResults &results,
88 TransformState &state) {
89 Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
90 auto payload = state.getPayloadOps(operandHandle);
91 if (!llvm::hasNItemsOrLess(payload, 1)) {
92 return emitDefiniteFailure(this->getOperation()->getLoc())
93 << "AtMostOneOpMatcherOpTrait requires the operand handle to "
94 "point to at most one payload op";
95 }
96 if (payload.empty()) {
97 return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()),
98 results, state);
99 }
100 return cast<OpTy>(this->getOperation())
101 .matchOperation(*payload.begin(), results, state);
102 }
103
105 onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
106 producesHandle(this->getOperation()->getOpResults(), effects);
107 onlyReadsPayload(effects);
108 }
109};
110
111template <typename OpTy>
113
114public:
116 TransformResults &results,
117 TransformState &state) {
118 Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
119 auto payload = state.getPayloadOps(operandHandle);
120 if (!llvm::hasSingleElement(payload)) {
121 return emitDefiniteFailure(this->getOperation()->getLoc())
122 << "SingleOpMatchOpTrait requires the operand handle to point to "
123 "a single payload op";
124 }
125 return static_cast<AtMostOneOpMatcherOpTrait<OpTy> *>(this)->apply(
126 rewriter, results, state);
127 }
128};
129
130template <typename OpTy>
132 : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
133public:
134 static LogicalResult verifyTrait(Operation *op) {
135 // This must be a dynamic assert because interface registration is
136 // dynamic.
137 assert(isa<MatchOpInterface>(op) &&
138 "SingleValueMatchOpTrait is only available on operations with "
139 "MatchOpInterface");
140
141 Value operandHandle = cast<OpTy>(op).getOperandHandle();
142 if (!isa<TransformValueHandleTypeInterface>(operandHandle.getType())) {
143 return op->emitError() << "SingleValueMatchOpTrait requires an operand "
144 "of TransformValueHandleTypeInterface";
145 }
146
147 return success();
148 }
149
151 TransformResults &results,
152 TransformState &state) {
153 Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
154 auto payload = state.getPayloadValues(operandHandle);
155 if (!llvm::hasSingleElement(payload)) {
156 return emitDefiniteFailure(this->getOperation()->getLoc())
157 << "SingleValueMatchOpTrait requires the value handle to point "
158 "to a single payload value";
159 }
160
161 return cast<OpTy>(this->getOperation())
162 .matchValue(*payload.begin(), results, state);
163 }
164
166 onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
167 producesHandle(this->getOperation()->getOpResults(), effects);
168 onlyReadsPayload(effects);
169 }
170};
171
172//===----------------------------------------------------------------------===//
173// Printing/parsing for positional specification matchers
174//===----------------------------------------------------------------------===//
175
176/// Parses a positional index specification for transform match operations.
177/// The following forms are accepted:
178///
179/// - `all`: sets `isAll` and returns;
180/// - comma-separated-integer-list: populates `rawDimList` with the values;
181/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
182/// with the values and sets `isInverted`.
183ParseResult parseTransformMatchDims(OpAsmParser &parser,
184 DenseI64ArrayAttr &rawDimList,
185 UnitAttr &isInverted, UnitAttr &isAll);
186
187/// Prints a positional index specification for transform match operations.
189 DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
190 UnitAttr isAll);
191
192//===----------------------------------------------------------------------===//
193// Utilities for positional specification matchers
194//===----------------------------------------------------------------------===//
195
196/// Checks if the positional specification defined is valid and reports errors
197/// otherwise.
199 bool inverted, bool all);
200
201/// Populates `result` with the positional identifiers relative to `maxNumber`.
202/// If `isAll` is set, the result will contain all numbers from `0` to
203/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
204/// values from `rawList` are are interpreted as counting backwards from
205/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
206/// numbers remain as is. If `isInverted` is set, populates `result` with those
207/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
208/// `rawList`. If `rawList` contains values that are greater than or equal to
209/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
210/// given location. `maxNumber` must be positive. If `rawList` contains
211/// duplicate numbers or numbers that become duplicate after negative value
212/// remapping, emits a silenceable error.
214expandTargetSpecification(Location loc, bool isAll, bool isInverted,
215 ArrayRef<int64_t> rawList, int64_t maxNumber,
217
218} // namespace transform
219} // namespace mlir
220
221#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h.inc"
222
223#endif // MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
return success()
The result of a transform IR operation application.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Helper class for implementing traits.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
static LogicalResult verifyTrait(Operation *op)
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)
static LogicalResult verifyTrait(Operation *op)
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
DiagnosedSilenceableFailure matchOptionalOperation(OpTy op, TransformResults &results, TransformState &state)
Dispatch matchOperation based on Operation* or std::optional<Operation*> first operand.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op, DenseI64ArrayAttr rawDimList, UnitAttr isInverted, UnitAttr isAll)
Prints a positional index specification for transform match operations.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
ParseResult parseTransformMatchDims(OpAsmParser &parser, DenseI64ArrayAttr &rawDimList, UnitAttr &isInverted, UnitAttr &isAll)
Parses a positional index specification for transform match operations.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.