13 #include "llvm/Support/Debug.h"
20 OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
21 std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
25 IntegerAttr selectedRegionAttr,
26 Value selectedRegionParam);
28 #define GET_OP_CLASSES
29 #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
31 #define DEBUG_TYPE "transform-tune"
32 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
38 void transform::tune::KnobOp::getEffects(
49 results.
setParams(llvm::cast<OpResult>(getResult()), *getSelected());
54 <<
"non-deterministic choice " << getName()
55 <<
" is only resolved through providing a `selected` attr";
59 if (
auto selected = getSelected()) {
60 if (
auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
61 if (!llvm::is_contained(optionsArray, selected))
62 return emitOpError(
"provided `selected` attribute is not an element of "
63 "`options` array of attributes");
65 LLVM_DEBUG(
DBGS() <<
"cannot verify `selected` attribute " << selected
66 <<
" is an element of `options` attribute "
78 OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
79 std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
80 size_t selectedRegionIdx;
93 if (paramParseRes.has_value()) {
94 if (
failed(*paramParseRes))
97 selectedRegionParam = param;
102 <<
"expected either an integer attribute or a transform.param operand";
107 IntegerAttr selectedRegionAttr,
108 Value selectedRegionParam) {
109 if (selectedRegionAttr)
110 printer << selectedRegionAttr.getValue();
111 if (selectedRegionParam)
112 printer << selectedRegionParam;
115 OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
118 return getOperands().slice(0, 0);
121 void transform::tune::AlternativesOp::getSuccessorRegions(
124 if (
auto selectedRegionIdx = getSelectedRegionAttr())
125 regions.emplace_back(
126 &getAlternatives()[selectedRegionIdx->getSExtValue()],
129 for (
Region &alternative : getAlternatives())
132 regions.emplace_back(getOperation(), getOperation()->getResults());
135 void transform::tune::AlternativesOp::getRegionInvocationBounds(
138 bounds.reserve(getNumRegions());
140 if (
auto selectedRegionIdx = getSelectedRegionAttr()) {
148 void transform::tune::AlternativesOp::getEffects(
159 std::optional<size_t> selectedRegionIdx;
161 if (
auto selectedRegionAttr = getSelectedRegionAttr())
162 selectedRegionIdx = selectedRegionAttr->getSExtValue();
164 if (
Value selectedRegionParam = getSelectedRegionParam()) {
166 IntegerAttr selectedRegionAttr;
167 if (associatedAttrs.size() != 1 ||
168 !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
170 <<
"param should hold exactly one integer attribute, got: "
171 << associatedAttrs[0];
172 selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
175 if (!selectedRegionIdx)
177 <<
" is only resolved through providing a "
178 "`selected_region` attr/param";
180 if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
182 <<
"'selected_region' attribute/param specifies region at index "
183 << *selectedRegionIdx <<
" while op has only " << getNumRegions()
186 Region &selectedRegion = getRegion(*selectedRegionIdx);
187 auto scope = state.make_region_scope(selectedRegion);
192 state.applyTransform(cast<transform::TransformOpInterface>(transform));
197 for (
const auto &res : getResults())
198 results.
set(res, {});
209 for (
auto *region : getRegions()) {
210 auto yieldTerminator =
211 llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
212 if (!yieldTerminator)
213 return emitOpError() <<
"expected '"
214 << transform::YieldOp::getOperationName()
215 <<
"' as terminator";
217 if (yieldTerminator->getNumOperands() != getNumResults())
218 return yieldTerminator.emitOpError()
219 <<
"expected terminator to have as many operands as the parent op "
222 for (
auto [i, operandType, resultType] : llvm::zip_equal(
223 llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
224 yieldTerminator->getOperands().getType(), getResultTypes())) {
225 if (operandType == resultType)
227 return yieldTerminator.emitOpError()
228 <<
"the type of the terminator operand #" << i
229 <<
" must match the type of the corresponding parent op result ("
230 << operandType <<
" vs " << resultType <<
")";
234 if (
auto selectedRegionAttr = getSelectedRegionAttr()) {
235 size_t regionIdx = selectedRegionAttr->getSExtValue();
236 if (regionIdx < 0 || regionIdx >= getNumRegions())
238 <<
"'selected_region' attribute specifies region at index "
239 << regionIdx <<
" while op has only " << getNumRegions()
static ParseResult parseAlternativesOpSelectedRegion(OpAsmParser &parser, IntegerAttr &selectedRegionAttr, std::optional< OpAsmParser::UnresolvedOperand > &selectedRegionParam)
static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, Operation *op, IntegerAttr selectedRegionAttr, Value selectedRegionParam)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.