14 #include "llvm/ADT/DenseSet.h"
15 #include "llvm/ADT/TypeSwitch.h"
21 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
27 void PDLDialect::initialize() {
30 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
52 if (!llvm::isa_and_nonnull<PatternOp>(op->
getParentOp()))
57 "expected a bindable user when defined in the matcher body of a "
65 if (!isa<PatternOp>(op->
getParentOp()) || isa<RewriteOp>(op))
69 if (visited.contains(op))
77 .Case<OperationOp>([&visited](
auto operation) {
78 for (
Value operand : operation.getOperandValues())
79 visit(operand.getDefiningOp(), visited);
81 .Case<ResultOp, ResultsOp>([&visited](
auto result) {
82 visit(result.getParent().getDefiningOp(), visited);
95 if (getNumOperands() == 0)
96 return emitOpError(
"expected at least one argument");
97 if (llvm::any_of(getResults(), [](
OpResult result) {
98 return isa<OperationType>(result.
getType());
101 "returning an operation from a constraint is not supported");
111 if (getNumOperands() == 0 && getNumResults() == 0)
112 return emitOpError(
"expected at least one argument or result");
121 Value attrType = getValueType();
122 std::optional<Attribute> attrValue = getValue();
125 if (isa<RewriteOp>((*this)->getParentOp()))
127 "expected constant value when specified within a `pdl.rewrite`");
131 return emitOpError(
"expected only one of [`type`, `value`] to be set");
154 ArrayAttr &attrNamesAttr) {
158 auto parseOperands = [&]() {
164 attrNames.push_back(nameAttr);
165 attrOperands.push_back(operand);
177 ArrayAttr attrNames) {
178 if (attrNames.empty())
181 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
182 [&](
int i) { p << attrNames[i] <<
" = " << attrArgs[i]; });
192 auto canInferTypeFromUse = [&](
OpOperand &use) {
195 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
196 if (!replOpUser || use.getOperandNumber() == 0)
199 Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
200 return replacedOp->
getBlock() != rewriterBlock ||
206 if (llvm::any_of(op.getOp().
getUses(), canInferTypeFromUse))
210 if (resultTypes.empty()) {
213 std::optional<StringRef> rawOpName = op.getOpName();
216 std::optional<RegisteredOperationName> opName =
225 bool expectedAtLeastOneResult =
228 if (expectedAtLeastOneResult) {
230 .
emitOpError(
"must have inferable or constrained result types when "
231 "nested within `pdl.rewrite`")
233 .
append(
"operation is created in a non-inferrable context, but '",
234 *opName,
"' does not implement InferTypeOpInterface");
241 Operation *resultTypeOp = it.value().getDefiningOp();
242 assert(resultTypeOp &&
"expected valid result type operation");
246 if (isa<ApplyNativeRewriteOp>(resultTypeOp))
251 auto constrainsInput = [rewriterBlock](
Operation *user) {
252 return user->getBlock() != rewriterBlock &&
253 isa<OperandOp, OperandsOp, OperationOp>(user);
255 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
256 if (typeOp.getConstantType() ||
257 llvm::any_of(typeOp->getUsers(), constrainsInput))
259 }
else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
260 if (typeOp.getConstantTypes() ||
261 llvm::any_of(typeOp->getUsers(), constrainsInput))
266 .
emitOpError(
"must have inferable or constrained result types when "
267 "nested within `pdl.rewrite`")
269 .
append(
"result type #", it.index(),
" was not constrained");
275 bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp());
276 if (isWithinRewrite && !getOpName())
277 return emitOpError(
"must have an operation name when nested within "
279 ArrayAttr attributeNames = getAttributeValueNamesAttr();
280 auto attributeValues = getAttributeValues();
281 if (attributeNames.size() != attributeValues.size()) {
283 <<
"expected the same number of attribute values and attribute "
285 << attributeNames.size() <<
" names and " << attributeValues.size()
291 if (isWithinRewrite && !mightHaveTypeInference()) {
299 bool OperationOp::hasTypeInference() {
300 if (std::optional<StringRef> rawOpName = getOpName()) {
302 return opName.hasInterface<InferTypeOpInterface>();
307 bool OperationOp::mightHaveTypeInference() {
308 if (std::optional<StringRef> rawOpName = getOpName()) {
310 return opName.mightHaveInterface<InferTypeOpInterface>();
320 Region &body = getBodyRegion();
322 auto rewriteOp = dyn_cast<RewriteOp>(term);
324 return emitOpError(
"expected body to terminate with `pdl.rewrite`")
325 .attachNote(term->
getLoc())
326 .append(
"see terminator defined here");
332 if (!isa_and_nonnull<PDLDialect>(op->
getDialect())) {
333 emitOpError(
"expected only `pdl` operations within the pattern body")
334 .attachNote(op->getLoc())
335 .append(
"see non-`pdl` operation defined here");
336 return WalkResult::interrupt();
345 return emitOpError(
"the pattern must contain at least one `pdl.operation`");
357 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
361 bool hasUserInRewrite =
false;
364 if (isa<RewriteOp>(user) ||
365 (region && isa<RewriteOp>(region->
getParentOp()))) {
366 hasUserInRewrite =
true;
372 if (!hasUserInRewrite)
379 }
else if (!visited.count(&op)) {
381 return emitOpError(
"the operations must form a connected component")
383 .append(
"see a disconnected value / operation here");
391 std::optional<uint16_t> benefit,
392 std::optional<StringRef> name) {
395 state.regions[0]->emplaceBlock();
399 RewriteOp PatternOp::getRewriter() {
400 return cast<RewriteOp>(getBodyRegion().front().getTerminator());
404 StringRef PatternOp::getDefaultDialect() {
405 return PDLDialect::getDialectNamespace();
415 if (!argumentTypes.empty()) {
425 if (argumentTypes.empty())
426 p <<
": " << resultType;
430 Type elementType = getType().getElementType();
431 for (
Type operandType : getOperandTypes()) {
433 if (operandElementType != elementType) {
434 return emitOpError(
"expected operand to have element type ")
435 << elementType <<
", but got " << operandElementType;
446 if (getReplOperation() && !getReplValues().empty())
447 return emitOpError() <<
"expected no replacement values to be provided"
448 " when the replacement operation is present";
468 IntegerAttr index,
Type resultType) {
470 p <<
" -> " << resultType;
474 if (!
getIndex() && llvm::isa<pdl::ValueType>(getType())) {
475 return emitOpError() <<
"expected `pdl.range<value>` result type when "
476 "no index is specified, but got: "
487 Region &rewriteRegion = getBodyRegion();
491 if (!rewriteRegion.
empty()) {
493 <<
"expected rewrite region to be empty when rewrite is external";
499 if (rewriteRegion.
empty()) {
500 return emitOpError() <<
"expected rewrite region to be non-empty if "
501 "external name is not specified";
505 if (!getExternalArgs().empty()) {
506 return emitOpError() <<
"expected no external arguments when the "
507 "rewrite is specified inline";
514 StringRef RewriteOp::getDefaultDialect() {
515 return PDLDialect::getDialectNamespace();
523 if (!getConstantTypeAttr())
533 if (!getConstantTypesAttr())
542 #define GET_OP_CLASSES
543 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, Type &resultType)
static LogicalResult verifyHasBindingUse(Operation *op)
Returns success if the given operation is not in the main matcher body or is used by a "binding" oper...
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op, OperandRange resultTypes)
Verifies that the result types of this operation, defined within a pdl.rewrite, can be inferred.
static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes, Type resultType)
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, Type &resultType)
static ParseResult parseOperationOpAttributes(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
static bool hasBindingUse(Operation *op)
Returns true if the given operation is used by a "binding" pdl operation.
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, OperandRange attrArgs, ArrayAttr attrNames)
static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, IntegerAttr index, Type resultType)
static MLIRContext * getContext(OpFoldResult val)
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI16IntegerAttr(int16_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops which have an unknown number of results.
This class provides return value APIs for ops that are known to have zero results.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
user_range getUsers()
Returns a range of all users.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * getParentOp()
Return the parent operation this region is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type getRangeElementTypeOrSelf(Type type)
If the given type is a range, return its element type, otherwise return the type itself.
MPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.