13#include "llvm/Support/Debug.h"
20 OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
21 std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
25 IntegerAttr selectedRegionAttr,
26 Value selectedRegionParam);
29#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
31#define DEBUG_TYPE "transform-tune"
32#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
38void transform::tune::KnobOp::getEffects(
40 producesHandle(getOperation()->getOpResults(), effects);
41 onlyReadsPayload(effects);
49 results.
setParams(llvm::cast<OpResult>(getResult()), *getSelected());
54 <<
"non-deterministic choice " << getName()
55 <<
" is only resolved through providing a `selected` attr";
58LogicalResult transform::tune::KnobOp::verify() {
59 if (
auto selected = getSelected()) {
60 if (
auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
61 if (!llvm::is_contained(optionsArray, selected))
62 return emitOpError(
"provided `selected` attribute is not an element of "
63 "`options` array of attributes");
65 LLVM_DEBUG(
DBGS() <<
"cannot verify `selected` attribute " << selected
66 <<
" is an element of `options` attribute "
78 OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
79 std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
80 size_t selectedRegionIdx;
84 if (failed(*attrParseRes))
93 if (paramParseRes.has_value()) {
94 if (failed(*paramParseRes))
97 selectedRegionParam = param;
102 <<
"expected either an integer attribute or a transform.param operand";
107 IntegerAttr selectedRegionAttr,
108 Value selectedRegionParam) {
109 if (selectedRegionAttr)
110 printer << selectedRegionAttr.getValue();
111 if (selectedRegionParam)
112 printer << selectedRegionParam;
115OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
118 return getOperands().slice(0, 0);
121void transform::tune::AlternativesOp::getSuccessorRegions(
124 if (
auto selectedRegionIdx = getSelectedRegionAttr())
125 regions.emplace_back(
126 &getAlternatives()[selectedRegionIdx->getSExtValue()]);
128 for (
Region &alternative : getAlternatives())
129 regions.emplace_back(&alternative);
135transform::tune::AlternativesOp::getSuccessorInputs(
RegionSuccessor successor) {
140void transform::tune::AlternativesOp::getRegionInvocationBounds(
143 bounds.reserve(getNumRegions());
145 if (
auto selectedRegionIdx = getSelectedRegionAttr()) {
153void transform::tune::AlternativesOp::getEffects(
164 std::optional<int64_t> selectedRegionIdx;
166 if (
auto selectedRegionAttr = getSelectedRegionAttr())
167 selectedRegionIdx = selectedRegionAttr->getSExtValue();
169 if (
Value selectedRegionParam = getSelectedRegionParam()) {
171 IntegerAttr selectedRegionAttr;
172 if (associatedAttrs.size() != 1 ||
173 !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
175 <<
"param should hold exactly one integer attribute, got: "
176 << associatedAttrs[0];
177 selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
180 if (!selectedRegionIdx)
182 <<
" is only resolved through providing a "
183 "`selected_region` attr/param";
185 if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
187 <<
"'selected_region' attribute/param specifies region at index "
188 << *selectedRegionIdx <<
" while op has only " << getNumRegions()
191 Region &selectedRegion = getRegion(*selectedRegionIdx);
198 if (
result.isDefiniteFailure())
201 if (
result.isSilenceableFailure()) {
202 for (
const auto &res : getResults())
203 results.
set(res, {});
213LogicalResult transform::tune::AlternativesOp::verify() {
214 for (
auto *region : getRegions()) {
215 auto yieldTerminator =
216 llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
217 if (!yieldTerminator)
219 << transform::YieldOp::getOperationName()
220 <<
"' as terminator";
222 if (yieldTerminator->getNumOperands() != getNumResults())
223 return yieldTerminator.emitOpError()
224 <<
"expected terminator to have as many operands as the parent op "
227 for (
auto [i, operandType, resultType] : llvm::zip_equal(
228 llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
229 yieldTerminator->getOperands().getType(), getResultTypes())) {
230 if (operandType == resultType)
232 return yieldTerminator.emitOpError()
233 <<
"the type of the terminator operand #" << i
234 <<
" must match the type of the corresponding parent op result ("
235 << operandType <<
" vs " << resultType <<
")";
239 if (
auto selectedRegionAttr = getSelectedRegionAttr()) {
240 int64_t regionIdx = selectedRegionAttr->getSExtValue();
241 if (regionIdx < 0 || regionIdx >= getNumRegions())
243 <<
"'selected_region' attribute specifies region at index "
244 << regionIdx <<
" while op has only " << getNumRegions()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseAlternativesOpSelectedRegion(OpAsmParser &parser, IntegerAttr &selectedRegionAttr, std::optional< OpAsmParser::UnresolvedOperand > &selectedRegionParam)
static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, Operation *op, IntegerAttr selectedRegionAttr, Value selectedRegionParam)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Block represents an ordered list of Operations.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
This is the representation of an operand reference.