MLIR 23.0.0git
TuneExtensionOps.cpp
Go to the documentation of this file.
1//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "llvm/Support/Debug.h"
14
16
17using namespace mlir;
18
// Forward declaration of the parser helper for the selected region of
// AlternativesOp. It is declared before the generated op classes are included
// further down so that the generated code can refer to it.
19static ParseResult parseAlternativesOpSelectedRegion(
20 OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
21 std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
22
// NOTE(review): the opening line of the following declaration (return type,
// function name and first parameter) is absent from this listing. Presumably
// it is the matching printer helper -- confirm against the original file.
24 Operation *op,
25 IntegerAttr selectedRegionAttr,
26 Value selectedRegionParam);
27
28#define GET_OP_CLASSES
29#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
30
31#define DEBUG_TYPE "transform-tune"
32#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
33
34//===----------------------------------------------------------------------===//
35// KnobOp
36//===----------------------------------------------------------------------===//
37
// Populates the effects of KnobOp: the op results are registered as produced
// handles and the payload is marked as only being read.
// NOTE(review): the line holding the parameter list (which must introduce
// `effects`) is absent from this listing.
38void transform::tune::KnobOp::getEffects(
40 producesHandle(getOperation()->getOpResults(), effects);
41 onlyReadsPayload(effects);
42}
43
// Applies the knob. When the `selected` attribute is set, it is recorded as
// the param associated with the op result; otherwise a definite failure is
// emitted because the choice has not been resolved.
// NOTE(review): the return type, the remaining parameters (which must
// introduce `results`) and the statement at listing line 50 are absent from
// this listing. Line 50 is presumably the early success return that keeps the
// failure below from being reached when `selected` is set -- confirm.
45transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter,
48 if (getSelected()) {
49 results.setParams(llvm::cast<OpResult>(getResult()), *getSelected());
51 }
52
// No `selected` attribute: the knob cannot be resolved at apply time.
53 return emitDefiniteFailure()
54 << "non-deterministic choice " << getName()
55 << " is only resolved through providing a `selected` attr";
56}
57
58LogicalResult transform::tune::KnobOp::verify() {
59 if (auto selected = getSelected()) {
60 if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
61 if (!llvm::is_contained(optionsArray, selected))
62 return emitOpError("provided `selected` attribute is not an element of "
63 "`options` array of attributes");
64 } else
65 LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected
66 << " is an element of `options` attribute "
67 << getOptions());
68 }
69
70 return success();
71}
72
73//===----------------------------------------------------------------------===//
74// AlternativesOp
75//===----------------------------------------------------------------------===//
76
// Parses the selected region of AlternativesOp, which is written either as an
// integer literal or as an SSA operand.
//   - On an integer, `selectedRegionAttr` is set to the matching index
//     attribute.
//   - On an operand, `selectedRegionParam` is set to the unresolved operand.
//   - If neither form is present, an error is emitted at the current location.
// NOTE(review): the first line of the signature (return type and function
// name) and listing line 91 (which must declare `param`) are absent from this
// listing.
78 OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
79 std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
80 size_t selectedRegionIdx;
// First alternative: an optional integer literal.
81 OptionalParseResult attrParseRes =
82 parser.parseOptionalInteger(selectedRegionIdx);
83 if (attrParseRes.has_value()) {
// An integer was present but failed to parse.
84 if (failed(*attrParseRes))
85 return failure();
86
87 selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx);
88 return success();
89 }
90
// Second alternative: an optional SSA operand.
92 auto paramParseRes = parser.parseOptionalOperand(param);
93 if (paramParseRes.has_value()) {
94 if (failed(*paramParseRes))
95 return failure();
96
97 selectedRegionParam = param;
98 return success();
99 }
100
// Neither an integer nor an operand was found.
101 return parser.emitError(parser.getCurrentLocation())
102 << "expected either an integer attribute or a transform.param operand";
103}
104
// Prints the selected region of AlternativesOp: the integer value of
// `selectedRegionAttr` when that attribute is set, and `selectedRegionParam`
// when that value is set. The two checks are independent, so both are printed
// if both happen to be set. `op` is not used in the visible body.
// NOTE(review): the first line of the signature (return type, function name
// and the parameter that introduces `printer`) is absent from this listing.
106 Operation *op,
107 IntegerAttr selectedRegionAttr,
108 Value selectedRegionParam) {
109 if (selectedRegionAttr)
110 printer << selectedRegionAttr.getValue();
111 if (selectedRegionParam)
112 printer << selectedRegionParam;
113}
114
115OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
116 RegionSuccessor successor) {
117 // No operands will be forwarded to the region(s).
118 return getOperands().slice(0, 0);
119}
120
// Appends the possible successors of `point` to `regions`.
//   - Branching from the parent op with a selected-region attribute set: only
//     the selected alternative is a successor.
//   - Branching from the parent op without that attribute: every alternative
//     is a successor.
//   - Branching from anywhere else: the only successor is the parent op.
// NOTE(review): the line holding the parameter list (which must introduce
// `point` and `regions`) is absent from this listing.
121void transform::tune::AlternativesOp::getSuccessorRegions(
// The if/else statements below are brace-less and nested: the first `else`
// pairs with the inner `if`, the last `else` pairs with `point.isParent()`.
123 if (point.isParent())
124 if (auto selectedRegionIdx = getSelectedRegionAttr())
125 regions.emplace_back(
126 &getAlternatives()[selectedRegionIdx->getSExtValue()]);
127 else
128 for (Region &alternative : getAlternatives())
129 regions.emplace_back(&alternative);
130 else
131 regions.push_back(RegionSuccessor::parent());
132}
133
// Returns the values that receive inputs when control reaches `successor`:
// the op's results when the successor is the parent op, and an empty range
// for any other successor.
// NOTE(review): the line holding the return type is absent from this listing.
135transform::tune::AlternativesOp::getSuccessorInputs(RegionSuccessor successor) {
136 return successor.isParent() ? ValueRange(getOperation()->getResults())
137 : ValueRange();
138}
139
// Fills `bounds` with one entry per region.
//   - With a selected-region attribute: the selected region gets bounds (1, 1)
//     and every other region gets (0, 0).
//   - Without it: every region gets bounds (0, 1).
// `operands` is deliberately unused.
// NOTE(review): the line holding the parameter list (which must introduce
// `operands` and `bounds`) is absent from this listing.
// NOTE(review): the index taken from the attribute is not range-checked here;
// presumably the op verifier guarantees it is valid -- confirm.
140void transform::tune::AlternativesOp::getRegionInvocationBounds(
142 (void)operands;
143 bounds.reserve(getNumRegions());
144
145 if (auto selectedRegionIdx = getSelectedRegionAttr()) {
146 bounds.resize(getNumRegions(), InvocationBounds(0, 0));
147 bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
148 } else {
149 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
150 }
151}
152
// Populates the effects of AlternativesOp: the selected-region param operand
// is only read and the op results are produced handles. Effects of the ops
// nested in the regions are not reported (see the TODO below).
// NOTE(review): the line holding the parameter list (which must introduce
// `effects`) is absent from this listing.
153void transform::tune::AlternativesOp::getEffects(
155 onlyReadsHandle(getSelectedRegionParamMutable(), effects);
156 producesHandle(getOperation()->getOpResults(), effects);
157 // TODO: should effects from regions be forwarded?
158}
159
// Applies the alternatives op by running the transform ops of exactly one
// region.
//   1. Resolve the region index from the selected-region attribute and/or the
//      selected-region param; if both are present, the param value is the one
//      that ends up being used.
//   2. Fail definitively if no index is available or it is out of range.
//   3. Apply each non-terminator op of the chosen region's first block in
//      order. A definite failure is returned as is; a silenceable failure is
//      returned after every op result has been set to an empty list.
//   4. Forward the terminator operands of the block to the op results.
// NOTE(review): several lines are absent from this listing -- the return
// type, the remaining parameters (which must introduce `results` and
// `state`), listing line 196 (which must declare `result`) and listing line
// 210 (presumably the final success return).
161transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
164 std::optional<int64_t> selectedRegionIdx;
165
// Statically selected region, if any.
166 if (auto selectedRegionAttr = getSelectedRegionAttr())
167 selectedRegionIdx = selectedRegionAttr->getSExtValue();
168
// Region selected through a param; overrides the value set above.
169 if (Value selectedRegionParam = getSelectedRegionParam()) {
170 ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
171 IntegerAttr selectedRegionAttr;
// NOTE(review): when `associatedAttrs` is empty the size check below is true,
// and the diagnostic then reads `associatedAttrs[0]`, which is out of bounds.
// Consider reporting the element only when the list is non-empty.
172 if (associatedAttrs.size() != 1 ||
173 !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
174 return emitDefiniteFailure()
175 << "param should hold exactly one integer attribute, got: "
176 << associatedAttrs[0];
177 selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
178 }
179
// Neither the attribute nor the param provided an index.
180 if (!selectedRegionIdx)
181 return emitDefiniteFailure() << "non-deterministic choice " << getName()
182 << " is only resolved through providing a "
183 "`selected_region` attr/param";
184
// Range check; needed here because a param value is only known at apply time.
185 if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
186 return emitDefiniteFailure()
187 << "'selected_region' attribute/param specifies region at index "
188 << *selectedRegionIdx << " while op has only " << getNumRegions()
189 << " regions";
190
191 Region &selectedRegion = getRegion(*selectedRegionIdx);
// `scope` is kept alive for the remainder of the function.
192 auto scope = state.make_region_scope(selectedRegion);
193 Block &block = selectedRegion.front();
194 // Apply the region's ops one by one.
195 for (Operation &transform : block.without_terminator()) {
197 state.applyTransform(cast<transform::TransformOpInterface>(transform));
198 if (result.isDefiniteFailure())
199 return result;
200
// On a silenceable failure every result is still given an (empty) mapping
// before the failure is propagated.
201 if (result.isSilenceableFailure()) {
202 for (const auto &res : getResults())
203 results.set(res, {});
204 return result;
205 }
206 }
207 // Forward the operation mapping for values yielded from the region to the
208 // values produced by the alternatives op.
209 transform::detail::forwardTerminatorOperands(&block, state, results);
211}
212
213LogicalResult transform::tune::AlternativesOp::verify() {
214 for (auto *region : getRegions()) {
215 auto yieldTerminator =
216 llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
217 if (!yieldTerminator)
218 return emitOpError() << "expected '"
219 << transform::YieldOp::getOperationName()
220 << "' as terminator";
221
222 if (yieldTerminator->getNumOperands() != getNumResults())
223 return yieldTerminator.emitOpError()
224 << "expected terminator to have as many operands as the parent op "
225 "has results";
226
227 for (auto [i, operandType, resultType] : llvm::zip_equal(
228 llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
229 yieldTerminator->getOperands().getType(), getResultTypes())) {
230 if (operandType == resultType)
231 continue;
232 return yieldTerminator.emitOpError()
233 << "the type of the terminator operand #" << i
234 << " must match the type of the corresponding parent op result ("
235 << operandType << " vs " << resultType << ")";
236 }
237 }
238
239 if (auto selectedRegionAttr = getSelectedRegionAttr()) {
240 int64_t regionIdx = selectedRegionAttr->getSExtValue();
241 if (regionIdx < 0 || regionIdx >= getNumRegions())
242 return emitOpError()
243 << "'selected_region' attribute specifies region at index "
244 << regionIdx << " while op has only " << getNumRegions()
245 << " regions";
246 }
247
248 return success();
249}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseAlternativesOpSelectedRegion(OpAsmParser &parser, IntegerAttr &selectedRegionAttr, std::optional< OpAsmParser::UnresolvedOperand > &selectedRegionParam)
static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, Operation *op, IntegerAttr selectedRegionAttr, Value selectedRegionParam)
#define DBGS()
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Block represents an ordered list of Operations.
Definition Block.h:33
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
This is the representation of an operand reference.