14 #include "llvm/ADT/TypeSwitch.h"
20 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
26 void PDLDialect::initialize() {
29 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
51 if (!llvm::isa_and_nonnull<PatternOp>(op->
getParentOp()))
56 "expected a bindable user when defined in the matcher body of a "
64 if (!isa<PatternOp>(op->
getParentOp()) || isa<RewriteOp>(op))
68 if (!visited.insert(op).second)
73 .Case<OperationOp>([&visited](
auto operation) {
74 for (
Value operand : operation.getOperandValues())
75 visit(operand.getDefiningOp(), visited);
77 .Case<ResultOp, ResultsOp>([&visited](
auto result) {
78 visit(result.getParent().getDefiningOp(), visited);
91 if (getNumOperands() == 0)
92 return emitOpError(
"expected at least one argument");
93 if (llvm::any_of(getResults(), [](
OpResult result) {
94 return isa<OperationType>(result.
getType());
97 "returning an operation from a constraint is not supported");
107 if (getNumOperands() == 0 && getNumResults() == 0)
108 return emitOpError(
"expected at least one argument or result");
117 Value attrType = getValueType();
118 std::optional<Attribute> attrValue = getValue();
121 if (isa<RewriteOp>((*this)->getParentOp()))
123 "expected constant value when specified within a `pdl.rewrite`");
127 return emitOpError(
"expected only one of [`type`, `value`] to be set");
150 ArrayAttr &attrNamesAttr) {
154 auto parseOperands = [&]() {
160 attrNames.push_back(nameAttr);
161 attrOperands.push_back(operand);
173 ArrayAttr attrNames) {
174 if (attrNames.empty())
177 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
178 [&](
int i) { p << attrNames[i] <<
" = " << attrArgs[i]; });
187 Block *rewriterBlock = op->getBlock();
188 auto canInferTypeFromUse = [&](
OpOperand &use) {
191 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
192 if (!replOpUser || use.getOperandNumber() == 0)
195 Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
196 return replacedOp->
getBlock() != rewriterBlock ||
202 if (llvm::any_of(op.getOp().getUses(), canInferTypeFromUse))
206 if (resultTypes.empty()) {
209 std::optional<StringRef> rawOpName = op.getOpName();
212 std::optional<RegisteredOperationName> opName =
221 bool expectedAtLeastOneResult =
224 if (expectedAtLeastOneResult) {
226 .emitOpError(
"must have inferable or constrained result types when "
227 "nested within `pdl.rewrite`")
229 .append(
"operation is created in a non-inferrable context, but '",
230 *opName,
"' does not implement InferTypeOpInterface");
237 Operation *resultTypeOp = it.value().getDefiningOp();
238 assert(resultTypeOp &&
"expected valid result type operation");
242 if (isa<ApplyNativeRewriteOp>(resultTypeOp))
247 auto constrainsInput = [rewriterBlock](
Operation *user) {
248 return user->getBlock() != rewriterBlock &&
249 isa<OperandOp, OperandsOp, OperationOp>(user);
251 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
252 if (typeOp.getConstantType() ||
253 llvm::any_of(typeOp->getUsers(), constrainsInput))
255 }
else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
256 if (typeOp.getConstantTypes() ||
257 llvm::any_of(typeOp->getUsers(), constrainsInput))
262 .emitOpError(
"must have inferable or constrained result types when "
263 "nested within `pdl.rewrite`")
265 .append(
"result type #", it.index(),
" was not constrained");
271 bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp());
272 if (isWithinRewrite && !getOpName())
273 return emitOpError(
"must have an operation name when nested within "
275 ArrayAttr attributeNames = getAttributeValueNamesAttr();
276 auto attributeValues = getAttributeValues();
277 if (attributeNames.size() != attributeValues.size()) {
279 <<
"expected the same number of attribute values and attribute "
281 << attributeNames.size() <<
" names and " << attributeValues.size()
287 if (isWithinRewrite && !mightHaveTypeInference()) {
295 bool OperationOp::hasTypeInference() {
296 if (std::optional<StringRef> rawOpName = getOpName()) {
298 return opName.hasInterface<InferTypeOpInterface>();
303 bool OperationOp::mightHaveTypeInference() {
304 if (std::optional<StringRef> rawOpName = getOpName()) {
306 return opName.mightHaveInterface<InferTypeOpInterface>();
315 LogicalResult PatternOp::verifyRegions() {
316 Region &body = getBodyRegion();
318 auto rewriteOp = dyn_cast<RewriteOp>(term);
320 return emitOpError(
"expected body to terminate with `pdl.rewrite`")
321 .attachNote(term->
getLoc())
322 .append(
"see terminator defined here");
328 if (!isa_and_nonnull<PDLDialect>(op->
getDialect())) {
329 emitOpError(
"expected only `pdl` operations within the pattern body")
330 .attachNote(op->getLoc())
331 .append(
"see non-`pdl` operation defined here");
332 return WalkResult::interrupt();
341 return emitOpError(
"the pattern must contain at least one `pdl.operation`");
353 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
357 bool hasUserInRewrite =
false;
360 if (isa<RewriteOp>(user) ||
361 (region && isa<RewriteOp>(region->
getParentOp()))) {
362 hasUserInRewrite =
true;
368 if (!hasUserInRewrite)
375 }
else if (!visited.count(&op)) {
377 return emitOpError(
"the operations must form a connected component")
379 .append(
"see a disconnected value / operation here");
387 std::optional<uint16_t> benefit,
388 std::optional<StringRef> name) {
391 state.regions[0]->emplaceBlock();
395 RewriteOp PatternOp::getRewriter() {
396 return cast<RewriteOp>(getBodyRegion().front().getTerminator());
400 StringRef PatternOp::getDefaultDialect() {
401 return PDLDialect::getDialectNamespace();
411 if (!argumentTypes.empty()) {
421 if (argumentTypes.empty())
422 p <<
": " << resultType;
427 for (
Type operandType : getOperandTypes()) {
429 if (operandElementType != elementType) {
430 return emitOpError(
"expected operand to have element type ")
431 << elementType <<
", but got " << operandElementType;
442 if (getReplOperation() && !getReplValues().empty())
443 return emitOpError() <<
"expected no replacement values to be provided"
444 " when the replacement operation is present";
464 IntegerAttr index,
Type resultType) {
466 p <<
" -> " << resultType;
471 return emitOpError() <<
"expected `pdl.range<value>` result type when "
472 "no index is specified, but got: "
482 LogicalResult RewriteOp::verifyRegions() {
483 Region &rewriteRegion = getBodyRegion();
487 if (!rewriteRegion.
empty()) {
489 <<
"expected rewrite region to be empty when rewrite is external";
495 if (rewriteRegion.
empty()) {
496 return emitOpError() <<
"expected rewrite region to be non-empty if "
497 "external name is not specified";
501 if (!getExternalArgs().empty()) {
502 return emitOpError() <<
"expected no external arguments when the "
503 "rewrite is specified inline";
510 StringRef RewriteOp::getDefaultDialect() {
511 return PDLDialect::getDialectNamespace();
519 if (!getConstantTypeAttr())
529 if (!getConstantTypesAttr())
538 #define GET_OP_CLASSES
539 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, Type &resultType)
static LogicalResult verifyHasBindingUse(Operation *op)
Returns success if the given operation is not in the main matcher body or is used by a "binding" oper...
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op, OperandRange resultTypes)
Verifies that the result types of this operation, defined within a pdl.rewrite, can be inferred.
static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes, Type resultType)
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, Type &resultType)
static ParseResult parseOperationOpAttributes(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
static bool hasBindingUse(Operation *op)
Returns true if the given operation is used by a "binding" pdl operation.
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, OperandRange attrArgs, ArrayAttr attrNames)
static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, IntegerAttr index, Type resultType)
static MLIRContext * getContext(OpFoldResult val)
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI16IntegerAttr(int16_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops which have an unknown number of results.
This class provides return value APIs for ops that are known to have zero results.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
user_range getUsers()
Returns a range of all users.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * getParentOp()
Return the parent operation this region is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type getRangeElementTypeOrSelf(Type type)
If the given type is a range, return its element type, otherwise return the type itself.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.