MLIR 22.0.0git
TuneExtensionOps.cpp
Go to the documentation of this file.
1//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "llvm/Support/Debug.h"
14
16
17using namespace mlir;
18
19static ParseResult parseAlternativesOpSelectedRegion(
20 OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
21 std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
22
24 Operation *op,
25 IntegerAttr selectedRegionAttr,
26 Value selectedRegionParam);
27
28#define GET_OP_CLASSES
29#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
30
31#define DEBUG_TYPE "transform-tune"
32#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
33
34//===----------------------------------------------------------------------===//
35// KnobOp
36//===----------------------------------------------------------------------===//
37
38void transform::tune::KnobOp::getEffects(
40 producesHandle(getOperation()->getOpResults(), effects);
41 onlyReadsPayload(effects);
42}
43
45transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter,
48 if (getSelected()) {
49 results.setParams(llvm::cast<OpResult>(getResult()), *getSelected());
51 }
52
53 return emitDefiniteFailure()
54 << "non-deterministic choice " << getName()
55 << " is only resolved through providing a `selected` attr";
56}
57
58LogicalResult transform::tune::KnobOp::verify() {
59 if (auto selected = getSelected()) {
60 if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
61 if (!llvm::is_contained(optionsArray, selected))
62 return emitOpError("provided `selected` attribute is not an element of "
63 "`options` array of attributes");
64 } else
65 LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected
66 << " is an element of `options` attribute "
67 << getOptions());
68 }
69
70 return success();
71}
72
73//===----------------------------------------------------------------------===//
74// AlternativesOp
75//===----------------------------------------------------------------------===//
76
78 OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
79 std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
80 size_t selectedRegionIdx;
81 OptionalParseResult attrParseRes =
82 parser.parseOptionalInteger(selectedRegionIdx);
83 if (attrParseRes.has_value()) {
84 if (failed(*attrParseRes))
85 return failure();
86
87 selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx);
88 return success();
89 }
90
92 auto paramParseRes = parser.parseOptionalOperand(param);
93 if (paramParseRes.has_value()) {
94 if (failed(*paramParseRes))
95 return failure();
96
97 selectedRegionParam = param;
98 return success();
99 }
100
101 return parser.emitError(parser.getCurrentLocation())
102 << "expected either an integer attribute or a transform.param operand";
103}
104
106 Operation *op,
107 IntegerAttr selectedRegionAttr,
108 Value selectedRegionParam) {
109 if (selectedRegionAttr)
110 printer << selectedRegionAttr.getValue();
111 if (selectedRegionParam)
112 printer << selectedRegionParam;
113}
114
115OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
116 RegionSuccessor successor) {
117 // No operands will be forwarded to the region(s).
118 return getOperands().slice(0, 0);
119}
120
121void transform::tune::AlternativesOp::getSuccessorRegions(
123 if (point.isParent())
124 if (auto selectedRegionIdx = getSelectedRegionAttr())
125 regions.emplace_back(
126 &getAlternatives()[selectedRegionIdx->getSExtValue()],
128 else
129 for (Region &alternative : getAlternatives())
130 regions.emplace_back(&alternative, Block::BlockArgListType());
131 else
132 regions.emplace_back(getOperation(), getOperation()->getResults());
133}
134
135void transform::tune::AlternativesOp::getRegionInvocationBounds(
137 (void)operands;
138 bounds.reserve(getNumRegions());
139
140 if (auto selectedRegionIdx = getSelectedRegionAttr()) {
141 bounds.resize(getNumRegions(), InvocationBounds(0, 0));
142 bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
143 } else {
144 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
145 }
146}
147
148void transform::tune::AlternativesOp::getEffects(
150 onlyReadsHandle(getSelectedRegionParamMutable(), effects);
151 producesHandle(getOperation()->getOpResults(), effects);
152 // TODO: should effects from regions be forwarded?
153}
154
156transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
159 std::optional<size_t> selectedRegionIdx;
160
161 if (auto selectedRegionAttr = getSelectedRegionAttr())
162 selectedRegionIdx = selectedRegionAttr->getSExtValue();
163
164 if (Value selectedRegionParam = getSelectedRegionParam()) {
165 ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
166 IntegerAttr selectedRegionAttr;
167 if (associatedAttrs.size() != 1 ||
168 !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
169 return emitDefiniteFailure()
170 << "param should hold exactly one integer attribute, got: "
171 << associatedAttrs[0];
172 selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
173 }
174
175 if (!selectedRegionIdx)
176 return emitDefiniteFailure() << "non-deterministic choice " << getName()
177 << " is only resolved through providing a "
178 "`selected_region` attr/param";
179
180 if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
181 return emitDefiniteFailure()
182 << "'selected_region' attribute/param specifies region at index "
183 << *selectedRegionIdx << " while op has only " << getNumRegions()
184 << " regions";
185
186 Region &selectedRegion = getRegion(*selectedRegionIdx);
187 auto scope = state.make_region_scope(selectedRegion);
188 Block &block = selectedRegion.front();
189 // Apply the region's ops one by one.
190 for (Operation &transform : block.without_terminator()) {
192 state.applyTransform(cast<transform::TransformOpInterface>(transform));
193 if (result.isDefiniteFailure())
194 return result;
195
196 if (result.isSilenceableFailure()) {
197 for (const auto &res : getResults())
198 results.set(res, {});
199 return result;
200 }
201 }
202 // Forward the operation mapping for values yielded from the region to the
203 // values produced by the alternatives op.
204 transform::detail::forwardTerminatorOperands(&block, state, results);
206}
207
208LogicalResult transform::tune::AlternativesOp::verify() {
209 for (auto *region : getRegions()) {
210 auto yieldTerminator =
211 llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
212 if (!yieldTerminator)
213 return emitOpError() << "expected '"
214 << transform::YieldOp::getOperationName()
215 << "' as terminator";
216
217 if (yieldTerminator->getNumOperands() != getNumResults())
218 return yieldTerminator.emitOpError()
219 << "expected terminator to have as many operands as the parent op "
220 "has results";
221
222 for (auto [i, operandType, resultType] : llvm::zip_equal(
223 llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
224 yieldTerminator->getOperands().getType(), getResultTypes())) {
225 if (operandType == resultType)
226 continue;
227 return yieldTerminator.emitOpError()
228 << "the type of the terminator operand #" << i
229 << " must match the type of the corresponding parent op result ("
230 << operandType << " vs " << resultType << ")";
231 }
232 }
233
234 if (auto selectedRegionAttr = getSelectedRegionAttr()) {
235 size_t regionIdx = selectedRegionAttr->getSExtValue();
236 if (regionIdx < 0 || regionIdx >= getNumRegions())
237 return emitOpError()
238 << "'selected_region' attribute specifies region at index "
239 << regionIdx << " while op has only " << getNumRegions()
240 << " regions";
241 }
242
243 return success();
244}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseAlternativesOpSelectedRegion(OpAsmParser &parser, IntegerAttr &selectedRegionAttr, std::optional< OpAsmParser::UnresolvedOperand > &selectedRegionParam)
static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, Operation *op, IntegerAttr selectedRegionAttr, Value selectedRegionParam)
#define DBGS()
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:85
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
This is the representation of an operand reference.