14 #include "llvm/ADT/DenseSet.h"
15 #include "llvm/ADT/TypeSwitch.h"
21 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
27 void PDLDialect::initialize() {
30 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
52 if (!llvm::isa_and_nonnull<PatternOp>(op->
getParentOp()))
57 "expected a bindable user when defined in the matcher body of a "
65 if (!isa<PatternOp>(op->
getParentOp()) || isa<RewriteOp>(op))
69 if (!visited.insert(op).second)
74 .Case<OperationOp>([&visited](
auto operation) {
75 for (
Value operand : operation.getOperandValues())
76 visit(operand.getDefiningOp(), visited);
78 .Case<ResultOp, ResultsOp>([&visited](
auto result) {
79 visit(result.getParent().getDefiningOp(), visited);
92 if (getNumOperands() == 0)
93 return emitOpError(
"expected at least one argument");
94 if (llvm::any_of(getResults(), [](
OpResult result) {
95 return isa<OperationType>(result.
getType());
98 "returning an operation from a constraint is not supported");
108 if (getNumOperands() == 0 && getNumResults() == 0)
109 return emitOpError(
"expected at least one argument or result");
118 Value attrType = getValueType();
119 std::optional<Attribute> attrValue = getValue();
122 if (isa<RewriteOp>((*this)->getParentOp()))
124 "expected constant value when specified within a `pdl.rewrite`");
128 return emitOpError(
"expected only one of [`type`, `value`] to be set");
151 ArrayAttr &attrNamesAttr) {
155 auto parseOperands = [&]() {
161 attrNames.push_back(nameAttr);
162 attrOperands.push_back(operand);
174 ArrayAttr attrNames) {
175 if (attrNames.empty())
178 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
179 [&](
int i) { p << attrNames[i] <<
" = " << attrArgs[i]; });
189 auto canInferTypeFromUse = [&](
OpOperand &use) {
192 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
193 if (!replOpUser || use.getOperandNumber() == 0)
196 Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
197 return replacedOp->
getBlock() != rewriterBlock ||
203 if (llvm::any_of(op.getOp().
getUses(), canInferTypeFromUse))
207 if (resultTypes.empty()) {
210 std::optional<StringRef> rawOpName = op.getOpName();
213 std::optional<RegisteredOperationName> opName =
222 bool expectedAtLeastOneResult =
225 if (expectedAtLeastOneResult) {
227 .
emitOpError(
"must have inferable or constrained result types when "
228 "nested within `pdl.rewrite`")
230 .
append(
"operation is created in a non-inferrable context, but '",
231 *opName,
"' does not implement InferTypeOpInterface");
238 Operation *resultTypeOp = it.value().getDefiningOp();
239 assert(resultTypeOp &&
"expected valid result type operation");
243 if (isa<ApplyNativeRewriteOp>(resultTypeOp))
248 auto constrainsInput = [rewriterBlock](
Operation *user) {
249 return user->getBlock() != rewriterBlock &&
250 isa<OperandOp, OperandsOp, OperationOp>(user);
252 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
253 if (typeOp.getConstantType() ||
254 llvm::any_of(typeOp->getUsers(), constrainsInput))
256 }
else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
257 if (typeOp.getConstantTypes() ||
258 llvm::any_of(typeOp->getUsers(), constrainsInput))
263 .
emitOpError(
"must have inferable or constrained result types when "
264 "nested within `pdl.rewrite`")
266 .
append(
"result type #", it.index(),
" was not constrained");
272 bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp());
273 if (isWithinRewrite && !getOpName())
274 return emitOpError(
"must have an operation name when nested within "
276 ArrayAttr attributeNames = getAttributeValueNamesAttr();
277 auto attributeValues = getAttributeValues();
278 if (attributeNames.size() != attributeValues.size()) {
280 <<
"expected the same number of attribute values and attribute "
282 << attributeNames.size() <<
" names and " << attributeValues.size()
288 if (isWithinRewrite && !mightHaveTypeInference()) {
296 bool OperationOp::hasTypeInference() {
297 if (std::optional<StringRef> rawOpName = getOpName()) {
299 return opName.hasInterface<InferTypeOpInterface>();
304 bool OperationOp::mightHaveTypeInference() {
305 if (std::optional<StringRef> rawOpName = getOpName()) {
307 return opName.mightHaveInterface<InferTypeOpInterface>();
316 LogicalResult PatternOp::verifyRegions() {
317 Region &body = getBodyRegion();
319 auto rewriteOp = dyn_cast<RewriteOp>(term);
321 return emitOpError(
"expected body to terminate with `pdl.rewrite`")
322 .attachNote(term->
getLoc())
323 .append(
"see terminator defined here");
329 if (!isa_and_nonnull<PDLDialect>(op->
getDialect())) {
330 emitOpError(
"expected only `pdl` operations within the pattern body")
331 .attachNote(op->getLoc())
332 .append(
"see non-`pdl` operation defined here");
333 return WalkResult::interrupt();
342 return emitOpError(
"the pattern must contain at least one `pdl.operation`");
354 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
358 bool hasUserInRewrite =
false;
361 if (isa<RewriteOp>(user) ||
362 (region && isa<RewriteOp>(region->
getParentOp()))) {
363 hasUserInRewrite =
true;
369 if (!hasUserInRewrite)
376 }
else if (!visited.count(&op)) {
378 return emitOpError(
"the operations must form a connected component")
380 .append(
"see a disconnected value / operation here");
388 std::optional<uint16_t> benefit,
389 std::optional<StringRef> name) {
392 state.regions[0]->emplaceBlock();
396 RewriteOp PatternOp::getRewriter() {
397 return cast<RewriteOp>(getBodyRegion().front().getTerminator());
401 StringRef PatternOp::getDefaultDialect() {
402 return PDLDialect::getDialectNamespace();
412 if (!argumentTypes.empty()) {
422 if (argumentTypes.empty())
423 p <<
": " << resultType;
428 for (
Type operandType : getOperandTypes()) {
430 if (operandElementType != elementType) {
431 return emitOpError(
"expected operand to have element type ")
432 << elementType <<
", but got " << operandElementType;
443 if (getReplOperation() && !getReplValues().empty())
444 return emitOpError() <<
"expected no replacement values to be provided"
445 " when the replacement operation is present";
465 IntegerAttr index,
Type resultType) {
467 p <<
" -> " << resultType;
472 return emitOpError() <<
"expected `pdl.range<value>` result type when "
473 "no index is specified, but got: "
483 LogicalResult RewriteOp::verifyRegions() {
484 Region &rewriteRegion = getBodyRegion();
488 if (!rewriteRegion.
empty()) {
490 <<
"expected rewrite region to be empty when rewrite is external";
496 if (rewriteRegion.
empty()) {
497 return emitOpError() <<
"expected rewrite region to be non-empty if "
498 "external name is not specified";
502 if (!getExternalArgs().empty()) {
503 return emitOpError() <<
"expected no external arguments when the "
504 "rewrite is specified inline";
511 StringRef RewriteOp::getDefaultDialect() {
512 return PDLDialect::getDialectNamespace();
520 if (!getConstantTypeAttr())
530 if (!getConstantTypesAttr())
539 #define GET_OP_CLASSES
540 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, Type &resultType)
static LogicalResult verifyHasBindingUse(Operation *op)
Returns success if the given operation is not in the main matcher body or is used by a "binding" oper...
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op, OperandRange resultTypes)
Verifies that the result types of this operation, defined within a pdl.rewrite, can be inferred.
static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes, Type resultType)
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, Type &resultType)
static ParseResult parseOperationOpAttributes(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
static bool hasBindingUse(Operation *op)
Returns true if the given operation is used by a "binding" pdl operation.
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, OperandRange attrArgs, ArrayAttr attrNames)
static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, IntegerAttr index, Type resultType)
static MLIRContext * getContext(OpFoldResult val)
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI16IntegerAttr(int16_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops which have an unknown number of results.
This class provides return value APIs for ops that are known to have zero results.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
user_range getUsers()
Returns a range of all users.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * getParentOp()
Return the parent operation this region is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type getRangeElementTypeOrSelf(Type type)
If the given type is a range, return its element type, otherwise return the type itself.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.