MLIR  20.0.0git
MatchInterfaces.h
Go to the documentation of this file.
1 //===- MatchInterfaces.h - Transform Dialect Interfaces ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
11 
12 #include <optional>
13 #include <type_traits>
14 
16 #include "mlir/IR/OpDefinition.h"
17 #include "llvm/ADT/STLExtras.h"
18 
19 namespace mlir {
20 namespace transform {
21 class MatchOpInterface;
22 
23 namespace detail {
24 /// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
25 /// first operand.
26 template <typename OpTy>
28  TransformResults &results,
29  TransformState &state) {
30  if constexpr (std::is_same_v<
31  typename llvm::function_traits<
32  decltype(&OpTy::matchOperation)>::template arg_t<0>,
33  Operation *>) {
34  return op.matchOperation(nullptr, results, state);
35  } else {
36  return op.matchOperation(std::nullopt, results, state);
37  }
38 }
39 } // namespace detail
40 
41 template <typename OpTy>
43  : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
44  template <typename T>
45  using has_get_operand_handle =
46  decltype(std::declval<T &>().getOperandHandle());
47  template <typename T>
48  using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
49  std::declval<Operation *>(), std::declval<TransformResults &>(),
50  std::declval<TransformState &>()));
51  template <typename T>
52  using has_match_operation_optional =
53  decltype(std::declval<T &>().matchOperation(
54  std::declval<std::optional<Operation *>>(),
55  std::declval<TransformResults &>(),
56  std::declval<TransformState &>()));
57 
58 public:
59  static LogicalResult verifyTrait(Operation *op) {
60  static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
61  "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
62  "operation type to have the getOperandHandle() method");
63  static_assert(
64  llvm::is_detected<has_match_operation_ptr, OpTy>::value ||
65  llvm::is_detected<has_match_operation_optional, OpTy>::value,
66  "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
67  "type to have either the matchOperation(Operation *, TransformResults "
68  "&, TransformState &) or the matchOperation(std::optional<Operation*>, "
69  "TransformResults &, TransformState &) method");
70 
71  // This must be a dynamic assert because interface registration is dynamic.
72  assert(
73  isa<MatchOpInterface>(op) &&
74  "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
75  "operations with MatchOpInterface");
76  Value operandHandle = cast<OpTy>(op).getOperandHandle();
77  if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
78  return op->emitError() << "AtMostOneOpMatcherOpTrait/"
79  "SingleOpMatchOpTrait requires the op handle "
80  "to be of TransformHandleTypeInterface";
81  }
82 
83  return success();
84  }
85 
87  TransformResults &results,
88  TransformState &state) {
89  Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
90  auto payload = state.getPayloadOps(operandHandle);
91  if (!llvm::hasNItemsOrLess(payload, 1)) {
92  return emitDefiniteFailure(this->getOperation()->getLoc())
93  << "AtMostOneOpMatcherOpTrait requires the operand handle to "
94  "point to at most one payload op";
95  }
96  if (payload.empty()) {
97  return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()),
98  results, state);
99  }
100  return cast<OpTy>(this->getOperation())
101  .matchOperation(*payload.begin(), results, state);
102  }
103 
105  onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
106  producesHandle(this->getOperation()->getOpResults(), effects);
107  onlyReadsPayload(effects);
108  }
109 };
110 
111 template <typename OpTy>
113 
114 public:
116  TransformResults &results,
117  TransformState &state) {
118  Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
119  auto payload = state.getPayloadOps(operandHandle);
120  if (!llvm::hasSingleElement(payload)) {
121  return emitDefiniteFailure(this->getOperation()->getLoc())
122  << "SingleOpMatchOpTrait requires the operand handle to point to "
123  "a single payload op";
124  }
125  return static_cast<AtMostOneOpMatcherOpTrait<OpTy> *>(this)->apply(
126  rewriter, results, state);
127  }
128 };
129 
130 template <typename OpTy>
132  : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
133 public:
134  static LogicalResult verifyTrait(Operation *op) {
135  // This must be a dynamic assert because interface registration is
136  // dynamic.
137  assert(isa<MatchOpInterface>(op) &&
138  "SingleValueMatchOpTrait is only available on operations with "
139  "MatchOpInterface");
140 
141  Value operandHandle = cast<OpTy>(op).getOperandHandle();
142  if (!isa<TransformValueHandleTypeInterface>(operandHandle.getType())) {
143  return op->emitError() << "SingleValueMatchOpTrait requires an operand "
144  "of TransformValueHandleTypeInterface";
145  }
146 
147  return success();
148  }
149 
151  TransformResults &results,
152  TransformState &state) {
153  Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
154  auto payload = state.getPayloadValues(operandHandle);
155  if (!llvm::hasSingleElement(payload)) {
156  return emitDefiniteFailure(this->getOperation()->getLoc())
157  << "SingleValueMatchOpTrait requires the value handle to point "
158  "to a single payload value";
159  }
160 
161  return cast<OpTy>(this->getOperation())
162  .matchValue(*payload.begin(), results, state);
163  }
164 
166  onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
167  producesHandle(this->getOperation()->getOpResults(), effects);
168  onlyReadsPayload(effects);
169  }
170 };
171 
172 //===----------------------------------------------------------------------===//
173 // Printing/parsing for positional specification matchers
174 //===----------------------------------------------------------------------===//
175 
176 /// Parses a positional index specification for transform match operations.
177 /// The following forms are accepted:
178 ///
179 /// - `all`: sets `isAll` and returns;
180 /// - comma-separated-integer-list: populates `rawDimList` with the values;
181 /// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
182 /// with the values and sets `isInverted`.
183 ParseResult parseTransformMatchDims(OpAsmParser &parser,
184  DenseI64ArrayAttr &rawDimList,
185  UnitAttr &isInverted, UnitAttr &isAll);
186 
187 /// Prints a positional index specification for transform match operations.
189  DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
190  UnitAttr isAll);
191 
192 //===----------------------------------------------------------------------===//
193 // Utilities for positional specification matchers
194 //===----------------------------------------------------------------------===//
195 
196 /// Checks if the positional specification defined is valid and reports errors
197 /// otherwise.
199  bool inverted, bool all);
200 
201 /// Populates `result` with the positional identifiers relative to `maxNumber`.
202 /// If `isAll` is set, the result will contain all numbers from `0` to
203 /// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
204 /// values from `rawList` are are interpreted as counting backwards from
205 /// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
206 /// numbers remain as is. If `isInverted` is set, populates `result` with those
207 /// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
208 /// `rawList`. If `rawList` contains values that are greater than or equal to
209 /// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
210 /// given location. `maxNumber` must be positive. If `rawList` contains
211 /// duplicate numbers or numbers that become duplicate after negative value
212 /// remapping, emits a silenceable error.
214 expandTargetSpecification(Location loc, bool isAll, bool isInverted,
215  ArrayRef<int64_t> rawList, int64_t maxNumber,
216  SmallVectorImpl<int64_t> &result);
217 
218 } // namespace transform
219 } // namespace mlir
220 
221 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h.inc"
222 
223 #endif // MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
The result of a transform IR operation application.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Helper class for implementing traits.
Definition: OpDefinition.h:373
Operation * getOperation()
Return the ultimate Operation being worked on.
Definition: OpDefinition.h:376
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
static LogicalResult verifyTrait(Operation *op)
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)
static LogicalResult verifyTrait(Operation *op)
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, TransformResults &results, TransformState &state)
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
DiagnosedSilenceableFailure matchOptionalOperation(OpTy op, TransformResults &results, TransformState &state)
Dispatch matchOperation based on Operation* or std::optional<Operation*> first operand.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op, DenseI64ArrayAttr rawDimList, UnitAttr isInverted, UnitAttr isAll)
Prints a positional index specification for transform match operations.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
ParseResult parseTransformMatchDims(OpAsmParser &parser, DenseI64ArrayAttr &rawDimList, UnitAttr &isInverted, UnitAttr &isAll)
Parses a positional index specification for transform match operations.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.