14 #include "llvm/ADT/DenseSet.h"
15 #include "llvm/ADT/TypeSwitch.h"
21 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
27 void PDLDialect::initialize() {
30 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
52 if (!llvm::isa_and_nonnull<PatternOp>(op->
getParentOp()))
57 "expected a bindable user when defined in the matcher body of a "
65 if (!isa<PatternOp>(op->
getParentOp()) || isa<RewriteOp>(op))
69 if (visited.contains(op))
77 .Case<OperationOp>([&visited](
auto operation) {
78 for (
Value operand : operation.getOperandValues())
79 visit(operand.getDefiningOp(), visited);
81 .Case<ResultOp, ResultsOp>([&visited](
auto result) {
82 visit(result.getParent().getDefiningOp(), visited);
95 if (getNumOperands() == 0)
96 return emitOpError(
"expected at least one argument");
105 if (getNumOperands() == 0 && getNumResults() == 0)
106 return emitOpError(
"expected at least one argument or result");
115 Value attrType = getValueType();
116 std::optional<Attribute> attrValue = getValue();
119 if (isa<RewriteOp>((*this)->getParentOp()))
121 "expected constant value when specified within a `pdl.rewrite`");
125 return emitOpError(
"expected only one of [`type`, `value`] to be set");
148 ArrayAttr &attrNamesAttr) {
152 auto parseOperands = [&]() {
158 attrNames.push_back(nameAttr);
159 attrOperands.push_back(operand);
171 ArrayAttr attrNames) {
172 if (attrNames.empty())
175 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
176 [&](
int i) { p << attrNames[i] <<
" = " << attrArgs[i]; });
186 auto canInferTypeFromUse = [&](
OpOperand &use) {
189 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
190 if (!replOpUser || use.getOperandNumber() == 0)
193 Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
194 return replacedOp->
getBlock() != rewriterBlock ||
200 if (llvm::any_of(op.getOp().
getUses(), canInferTypeFromUse))
204 if (resultTypes.empty()) {
207 std::optional<StringRef> rawOpName = op.getOpName();
210 std::optional<RegisteredOperationName> opName =
219 bool expectedAtLeastOneResult =
222 if (expectedAtLeastOneResult) {
224 .
emitOpError(
"must have inferable or constrained result types when "
225 "nested within `pdl.rewrite`")
227 .
append(
"operation is created in a non-inferrable context, but '",
228 *opName,
"' does not implement InferTypeOpInterface");
235 Operation *resultTypeOp = it.value().getDefiningOp();
236 assert(resultTypeOp &&
"expected valid result type operation");
240 if (isa<ApplyNativeRewriteOp>(resultTypeOp))
245 auto constrainsInput = [rewriterBlock](
Operation *user) {
246 return user->getBlock() != rewriterBlock &&
247 isa<OperandOp, OperandsOp, OperationOp>(user);
249 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
250 if (typeOp.getConstantType() ||
251 llvm::any_of(typeOp->getUsers(), constrainsInput))
253 }
else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
254 if (typeOp.getConstantTypes() ||
255 llvm::any_of(typeOp->getUsers(), constrainsInput))
260 .
emitOpError(
"must have inferable or constrained result types when "
261 "nested within `pdl.rewrite`")
263 .
append(
"result type #", it.index(),
" was not constrained");
269 bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp());
270 if (isWithinRewrite && !getOpName())
271 return emitOpError(
"must have an operation name when nested within "
273 ArrayAttr attributeNames = getAttributeValueNamesAttr();
274 auto attributeValues = getAttributeValues();
275 if (attributeNames.size() != attributeValues.size()) {
277 <<
"expected the same number of attribute values and attribute "
279 << attributeNames.size() <<
" names and " << attributeValues.size()
285 if (isWithinRewrite && !mightHaveTypeInference()) {
293 bool OperationOp::hasTypeInference() {
294 if (std::optional<StringRef> rawOpName = getOpName()) {
296 return opName.hasInterface<InferTypeOpInterface>();
301 bool OperationOp::mightHaveTypeInference() {
302 if (std::optional<StringRef> rawOpName = getOpName()) {
304 return opName.mightHaveInterface<InferTypeOpInterface>();
314 Region &body = getBodyRegion();
316 auto rewriteOp = dyn_cast<RewriteOp>(term);
318 return emitOpError(
"expected body to terminate with `pdl.rewrite`")
319 .attachNote(term->
getLoc())
320 .append(
"see terminator defined here");
326 if (!isa_and_nonnull<PDLDialect>(op->
getDialect())) {
327 emitOpError(
"expected only `pdl` operations within the pattern body")
328 .attachNote(op->getLoc())
329 .append(
"see non-`pdl` operation defined here");
330 return WalkResult::interrupt();
339 return emitOpError(
"the pattern must contain at least one `pdl.operation`");
351 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
355 bool hasUserInRewrite =
false;
358 if (isa<RewriteOp>(user) ||
359 (region && isa<RewriteOp>(region->
getParentOp()))) {
360 hasUserInRewrite =
true;
366 if (!hasUserInRewrite)
373 }
else if (!visited.count(&op)) {
375 return emitOpError(
"the operations must form a connected component")
377 .append(
"see a disconnected value / operation here");
385 std::optional<uint16_t> benefit,
386 std::optional<StringRef> name) {
389 state.regions[0]->emplaceBlock();
393 RewriteOp PatternOp::getRewriter() {
394 return cast<RewriteOp>(getBodyRegion().front().getTerminator());
398 StringRef PatternOp::getDefaultDialect() {
399 return PDLDialect::getDialectNamespace();
409 if (!argumentTypes.empty()) {
419 if (argumentTypes.empty())
420 p <<
": " << resultType;
424 Type elementType = getType().getElementType();
425 for (
Type operandType : getOperandTypes()) {
427 if (operandElementType != elementType) {
428 return emitOpError(
"expected operand to have element type ")
429 << elementType <<
", but got " << operandElementType;
440 if (getReplOperation() && !getReplValues().empty())
441 return emitOpError() <<
"expected no replacement values to be provided"
442 " when the replacement operation is present";
462 IntegerAttr index,
Type resultType) {
464 p <<
" -> " << resultType;
468 if (!getIndex() && llvm::isa<pdl::ValueType>(getType())) {
469 return emitOpError() <<
"expected `pdl.range<value>` result type when "
470 "no index is specified, but got: "
481 Region &rewriteRegion = getBodyRegion();
485 if (!rewriteRegion.
empty()) {
487 <<
"expected rewrite region to be empty when rewrite is external";
493 if (rewriteRegion.
empty()) {
494 return emitOpError() <<
"expected rewrite region to be non-empty if "
495 "external name is not specified";
499 if (!getExternalArgs().empty()) {
500 return emitOpError() <<
"expected no external arguments when the "
501 "rewrite is specified inline";
508 StringRef RewriteOp::getDefaultDialect() {
509 return PDLDialect::getDialectNamespace();
517 if (!getConstantTypeAttr())
527 if (!getConstantTypesAttr())
536 #define GET_OP_CLASSES
537 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, Type &resultType)
static LogicalResult verifyHasBindingUse(Operation *op)
Returns success if the given operation is not in the main matcher body or is used by a "binding" oper...
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op, OperandRange resultTypes)
Verifies that the result types of this operation, defined within a pdl.rewrite, can be inferred.
static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes, Type resultType)
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, Type &resultType)
static ParseResult parseOperationOpAttributes(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
static bool hasBindingUse(Operation *op)
Returns true if the given operation is used by a "binding" pdl operation.
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, OperandRange attrArgs, ArrayAttr attrNames)
static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, IntegerAttr index, Type resultType)
static MLIRContext * getContext(OpFoldResult val)
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI16IntegerAttr(int16_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents an operand of an operation.
This class provides the API for ops which have an unknown number of results.
This class provides return value APIs for ops that are known to have zero results.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
user_range getUsers()
Returns a range of all users.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * getParentOp()
Return the parent operation this region is attached to.
std::enable_if_t< std::is_same< RetT, void >::value, RetT > walk(FnT &&callback)
Walk the operations in this region.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Type getRangeElementTypeOrSelf(Type type)
If the given type is a range, return its element type, otherwise return the type itself.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.