MLIR  22.0.0git
TuneExtensionOps.cpp
Go to the documentation of this file.
1 //===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "llvm/Support/Debug.h"
14 
16 
17 using namespace mlir;
18 
// Forward declarations of the custom parse/print hooks for the AlternativesOp
// selected-region operand. NOTE(review): presumably declared here because the
// generated op definitions included below reference them.
19 static ParseResult parseAlternativesOpSelectedRegion(
20  OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
21  std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
22 
// NOTE(review): source line 23 (the start of the
// printAlternativesOpSelectedRegion declaration) is missing from this listing.
24  Operation *op,
25  IntegerAttr selectedRegionAttr,
26  Value selectedRegionParam);
27 
28 #define GET_OP_CLASSES
29 #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
30 
// Debug tag for this file; DBGS() prefixes debug output with "[transform-tune] ".
31 #define DEBUG_TYPE "transform-tune"
32 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
33 
34 //===----------------------------------------------------------------------===//
35 // KnobOp
36 //===----------------------------------------------------------------------===//
37 
// Memory effects of KnobOp: the op's results are produced as new handles and
// the payload IR is only read.
// NOTE(review): source line 39 (the `effects` parameter) is missing from this
// listing.
38 void transform::tune::KnobOp::getEffects(
40  producesHandle(getOperation()->getOpResults(), effects);
41  onlyReadsPayload(effects);
42 }
43 
45 transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter,
48  if (getSelected()) {
49  results.setParams(llvm::cast<OpResult>(getResult()), *getSelected());
51  }
52 
53  return emitDefiniteFailure()
54  << "non-deterministic choice " << getName()
55  << " is only resolved through providing a `selected` attr";
56 }
57 
58 LogicalResult transform::tune::KnobOp::verify() {
59  if (auto selected = getSelected()) {
60  if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
61  if (!llvm::is_contained(optionsArray, selected))
62  return emitOpError("provided `selected` attribute is not an element of "
63  "`options` array of attributes");
64  } else
65  LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected
66  << " is an element of `options` attribute "
67  << getOptions());
68  }
69 
70  return success();
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // AlternativesOp
75 //===----------------------------------------------------------------------===//
76 
// Custom parser for the AlternativesOp selected-region operand. It accepts
// either of two forms:
//  - an integer literal, stored as an index attribute in `selectedRegionAttr`;
//  - an operand, stored in `selectedRegionParam`.
// It emits an error if neither form is present.
// NOTE(review): source lines 77 (the function header) and 91 (the declaration
// of `param`) are missing from this listing.
78  OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
79  std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
80  size_t selectedRegionIdx;
// First alternative: an integer literal naming the region directly.
81  OptionalParseResult attrParseRes =
82  parser.parseOptionalInteger(selectedRegionIdx);
83  if (attrParseRes.has_value()) {
84  if (failed(*attrParseRes))
85  return failure();
86 
87  selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx);
88  return success();
89  }
90 
// Second alternative: an operand (the error below expects a transform.param).
92  auto paramParseRes = parser.parseOptionalOperand(param);
93  if (paramParseRes.has_value()) {
94  if (failed(*paramParseRes))
95  return failure();
96 
97  selectedRegionParam = param;
98  return success();
99  }
100 
101  return parser.emitError(parser.getCurrentLocation())
102  << "expected either an integer attribute or a transform.param operand";
103 }
104 
// Custom printer mirroring the selected-region parser. It prints the integer
// value of `selectedRegionAttr` if it is set, then `selectedRegionParam` if it
// is set. `op` is not used.
// NOTE(review): source line 105 (the function header) is missing from this
// listing.
106  Operation *op,
107  IntegerAttr selectedRegionAttr,
108  Value selectedRegionParam) {
109  if (selectedRegionAttr)
110  printer << selectedRegionAttr.getValue();
111  if (selectedRegionParam)
112  printer << selectedRegionParam;
113 }
114 
115 OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
116  RegionSuccessor successor) {
117  // No operands will be forwarded to the region(s).
118  return getOperands().slice(0, 0);
119 }
120 
// Region-branch successors of AlternativesOp.
//  - From the parent op, control enters only the selected region when a
//    `selected_region` attribute is present, and otherwise any one of the
//    alternatives.
//  - From inside a region, control returns to the parent op together with
//    the op's results.
// NOTE(review): source lines 122 (the parameter list) and 127 (the second
// argument of the first emplace_back) are missing from this listing.
// NOTE(review): the unbraced nested if/else relies on the first `else` binding
// to the inner `if` and the second to the outer one. Braces would make this
// explicit.
121 void transform::tune::AlternativesOp::getSuccessorRegions(
123  if (point.isParent())
124  if (auto selectedRegionIdx = getSelectedRegionAttr())
125  regions.emplace_back(
126  &getAlternatives()[selectedRegionIdx->getSExtValue()],
128  else
129  for (Region &alternative : getAlternatives())
130  regions.emplace_back(&alternative, Block::BlockArgListType());
131  else
132  regions.emplace_back(getOperation(), getOperation()->getResults());
133 }
134 
// Invocation bounds for the alternative regions.
//  - With a `selected_region` attribute, the selected region runs exactly once
//    and every other region never runs. The index is assumed to be in range,
//    which the op verifier checks.
//  - Otherwise each region runs at most once.
// The `operands` argument is ignored.
// NOTE(review): source line 136 (the parameter list) is missing from this
// listing.
135 void transform::tune::AlternativesOp::getRegionInvocationBounds(
137  (void)operands;
138  bounds.reserve(getNumRegions());
139 
140  if (auto selectedRegionIdx = getSelectedRegionAttr()) {
141  bounds.resize(getNumRegions(), InvocationBounds(0, 0));
142  bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
143  } else {
144  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
145  }
146 }
147 
// Memory effects of AlternativesOp: the selected-region param operand (if
// any) is only read, and the op's results are produced as new handles. Effects
// of ops nested in the regions are not forwarded (see the TODO).
// NOTE(review): source line 149 (the `effects` parameter) is missing from this
// listing.
148 void transform::tune::AlternativesOp::getEffects(
150  onlyReadsHandle(getSelectedRegionParamMutable(), effects);
151  producesHandle(getOperation()->getOpResults(), effects);
152  // TODO: should effects from regions be forwarded?
153 }
154 
156 transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
158  transform::TransformState &state) {
159  std::optional<size_t> selectedRegionIdx;
160 
161  if (auto selectedRegionAttr = getSelectedRegionAttr())
162  selectedRegionIdx = selectedRegionAttr->getSExtValue();
163 
164  if (Value selectedRegionParam = getSelectedRegionParam()) {
165  ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
166  IntegerAttr selectedRegionAttr;
167  if (associatedAttrs.size() != 1 ||
168  !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
169  return emitDefiniteFailure()
170  << "param should hold exactly one integer attribute, got: "
171  << associatedAttrs[0];
172  selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
173  }
174 
175  if (!selectedRegionIdx)
176  return emitDefiniteFailure() << "non-deterministic choice " << getName()
177  << " is only resolved through providing a "
178  "`selected_region` attr/param";
179 
180  if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
181  return emitDefiniteFailure()
182  << "'selected_region' attribute/param specifies region at index "
183  << *selectedRegionIdx << " while op has only " << getNumRegions()
184  << " regions";
185 
186  Region &selectedRegion = getRegion(*selectedRegionIdx);
187  auto scope = state.make_region_scope(selectedRegion);
188  Block &block = selectedRegion.front();
189  // Apply the region's ops one by one.
190  for (Operation &transform : block.without_terminator()) {
192  state.applyTransform(cast<transform::TransformOpInterface>(transform));
193  if (result.isDefiniteFailure())
194  return result;
195 
196  if (result.isSilenceableFailure()) {
197  for (const auto &res : getResults())
198  results.set(res, {});
199  return result;
200  }
201  }
202  // Forward the operation mapping for values yielded from the region to the
203  // values produced by the alternatives op.
204  transform::detail::forwardTerminatorOperands(&block, state, results);
206 }
207 
209  for (auto *region : getRegions()) {
210  auto yieldTerminator =
211  llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
212  if (!yieldTerminator)
213  return emitOpError() << "expected '"
214  << transform::YieldOp::getOperationName()
215  << "' as terminator";
216 
217  if (yieldTerminator->getNumOperands() != getNumResults())
218  return yieldTerminator.emitOpError()
219  << "expected terminator to have as many operands as the parent op "
220  "has results";
221 
222  for (auto [i, operandType, resultType] : llvm::zip_equal(
223  llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
224  yieldTerminator->getOperands().getType(), getResultTypes())) {
225  if (operandType == resultType)
226  continue;
227  return yieldTerminator.emitOpError()
228  << "the type of the terminator operand #" << i
229  << " must match the type of the corresponding parent op result ("
230  << operandType << " vs " << resultType << ")";
231  }
232  }
233 
234  if (auto selectedRegionAttr = getSelectedRegionAttr()) {
235  size_t regionIdx = selectedRegionAttr->getSExtValue();
236  if (regionIdx < 0 || regionIdx >= getNumRegions())
237  return emitOpError()
238  << "'selected_region' attribute specifies region at index "
239  << regionIdx << " while op has only " << getNumRegions()
240  << " regions";
241  }
242 
243  return success();
244 }
static ParseResult parseAlternativesOpSelectedRegion(OpAsmParser &parser, IntegerAttr &selectedRegionAttr, std::optional< OpAsmParser::UnresolvedOperand > &selectedRegionParam)
static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, Operation *op, IntegerAttr selectedRegionAttr, Value selectedRegionParam)
#define DBGS()
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.