14#include "llvm/ADT/TypeSwitch.h"
20#include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
26void PDLDialect::initialize() {
29#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
51 if (!llvm::isa_and_nonnull<PatternOp>(op->
getParentOp()))
56 "expected a bindable user when defined in the matcher body of a "
64 if (!isa<PatternOp>(op->
getParentOp()) || isa<RewriteOp>(op))
68 if (!visited.insert(op).second)
73 .Case<OperationOp>([&visited](
auto operation) {
74 for (
Value operand : operation.getOperandValues())
75 visit(operand.getDefiningOp(), visited);
77 .Case<ResultOp, ResultsOp>([&visited](
auto result) {
90LogicalResult ApplyNativeConstraintOp::verify() {
91 if (getNumOperands() == 0)
92 return emitOpError(
"expected at least one argument");
94 return isa<OperationType>(
result.getType());
97 "returning an operation from a constraint is not supported");
106LogicalResult ApplyNativeRewriteOp::verify() {
107 if (getNumOperands() == 0 && getNumResults() == 0)
108 return emitOpError(
"expected at least one argument or result");
116LogicalResult AttributeOp::verify() {
118 std::optional<Attribute> attrValue = getValue();
121 if (isa<RewriteOp>((*this)->getParentOp()))
123 "expected constant value when specified within a `pdl.rewrite`");
127 return emitOpError(
"expected only one of [`type`, `value`] to be set");
154 auto parseOperands = [&]() {
160 attrNames.push_back(nameAttr);
161 attrOperands.push_back(operand);
174 if (attrNames.empty())
177 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
178 [&](
int i) { p << attrNames[i] <<
" = " << attrArgs[i]; });
187 Block *rewriterBlock = op->getBlock();
188 auto canInferTypeFromUse = [&](
OpOperand &use) {
191 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
192 if (!replOpUser || use.getOperandNumber() == 0)
195 Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
196 return replacedOp->
getBlock() != rewriterBlock ||
202 if (llvm::any_of(op.getOp().getUses(), canInferTypeFromUse))
206 if (resultTypes.empty()) {
209 std::optional<StringRef> rawOpName = op.getOpName();
212 std::optional<RegisteredOperationName> opName =
221 bool expectedAtLeastOneResult =
224 if (expectedAtLeastOneResult) {
226 .emitOpError(
"must have inferable or constrained result types when "
227 "nested within `pdl.rewrite`")
229 .append(
"operation is created in a non-inferrable context, but '",
230 *opName,
"' does not implement InferTypeOpInterface");
236 for (
const auto &it : llvm::enumerate(resultTypes)) {
237 Operation *resultTypeOp = it.value().getDefiningOp();
238 assert(resultTypeOp &&
"expected valid result type operation");
242 if (isa<ApplyNativeRewriteOp>(resultTypeOp))
247 auto constrainsInput = [rewriterBlock](
Operation *user) {
248 return user->getBlock() != rewriterBlock &&
249 isa<OperandOp, OperandsOp, OperationOp>(user);
251 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
252 if (typeOp.getConstantType() ||
253 llvm::any_of(typeOp->getUsers(), constrainsInput))
255 }
else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
256 if (typeOp.getConstantTypes() ||
257 llvm::any_of(typeOp->getUsers(), constrainsInput))
262 .emitOpError(
"must have inferable or constrained result types when "
263 "nested within `pdl.rewrite`")
265 .append(
"result type #", it.index(),
" was not constrained");
270LogicalResult OperationOp::verify() {
271 bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp());
272 if (isWithinRewrite && !getOpName())
273 return emitOpError(
"must have an operation name when nested within "
275 ArrayAttr attributeNames = getAttributeValueNamesAttr();
276 auto attributeValues = getAttributeValues();
277 if (attributeNames.size() != attributeValues.size()) {
279 <<
"expected the same number of attribute values and attribute "
281 << attributeNames.size() <<
" names and " << attributeValues.size()
287 if (isWithinRewrite && !mightHaveTypeInference()) {
295bool OperationOp::hasTypeInference() {
296 if (std::optional<StringRef> rawOpName = getOpName()) {
298 return opName.hasInterface<InferTypeOpInterface>();
303bool OperationOp::mightHaveTypeInference() {
304 if (std::optional<StringRef> rawOpName = getOpName()) {
306 return opName.mightHaveInterface<InferTypeOpInterface>();
315LogicalResult PatternOp::verifyRegions() {
316 Region &body = getBodyRegion();
318 auto rewriteOp = dyn_cast<RewriteOp>(term);
320 return emitOpError(
"expected body to terminate with `pdl.rewrite`")
321 .attachNote(term->
getLoc())
322 .append(
"see terminator defined here");
327 WalkResult
result = body.
walk([&](Operation *op) -> WalkResult {
328 if (!isa_and_nonnull<PDLDialect>(op->
getDialect())) {
329 emitOpError(
"expected only `pdl` operations within the pattern body")
331 .append(
"see non-`pdl` operation defined here");
336 if (
result.wasInterrupted())
341 return emitOpError(
"the pattern must contain at least one `pdl.operation`");
351 for (Operation &op : body.
front()) {
353 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
357 bool hasUserInRewrite =
false;
358 for (Operation *user : op.
getUsers()) {
359 Region *region = user->getParentRegion();
360 if (isa<RewriteOp>(user) ||
361 (region && isa<RewriteOp>(region->
getParentOp()))) {
362 hasUserInRewrite =
true;
368 if (!hasUserInRewrite)
375 }
else if (!visited.count(&op)) {
377 return emitOpError(
"the operations must form a connected component")
379 .append(
"see a disconnected value / operation here");
386void PatternOp::build(OpBuilder &builder, OperationState &state,
387 std::optional<uint16_t> benefit,
388 std::optional<StringRef> name) {
391 state.
regions[0]->emplaceBlock();
395RewriteOp PatternOp::getRewriter() {
396 return cast<RewriteOp>(getBodyRegion().front().getTerminator());
400StringRef PatternOp::getDefaultDialect() {
401 return PDLDialect::getDialectNamespace();
411 if (!argumentTypes.empty()) {
421 if (argumentTypes.empty())
422 p <<
": " << resultType;
425LogicalResult RangeOp::verify() {
426 Type elementType =
getType().getElementType();
427 for (Type operandType : getOperandTypes()) {
429 if (operandElementType != elementType) {
430 return emitOpError(
"expected operand to have element type ")
431 << elementType <<
", but got " << operandElementType;
441LogicalResult ReplaceOp::verify() {
442 if (getReplOperation() && !getReplValues().empty())
443 return emitOpError() <<
"expected no replacement values to be provided"
444 " when the replacement operation is present";
466 p <<
" -> " << resultType;
469LogicalResult ResultsOp::verify() {
471 return emitOpError() <<
"expected `pdl.range<value>` result type when "
472 "no index is specified, but got: "
482LogicalResult RewriteOp::verifyRegions() {
483 Region &rewriteRegion = getBodyRegion();
487 if (!rewriteRegion.
empty()) {
489 <<
"expected rewrite region to be empty when rewrite is external";
495 if (rewriteRegion.
empty()) {
496 return emitOpError() <<
"expected rewrite region to be non-empty if "
497 "external name is not specified";
501 if (!getExternalArgs().empty()) {
502 return emitOpError() <<
"expected no external arguments when the "
503 "rewrite is specified inline";
510StringRef RewriteOp::getDefaultDialect() {
511 return PDLDialect::getDialectNamespace();
518LogicalResult TypeOp::verify() {
519 if (!getConstantTypeAttr())
528LogicalResult TypesOp::verify() {
529 if (!getConstantTypesAttr())
538#define GET_OP_CLASSES
539#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, Type &resultType)
static LogicalResult verifyHasBindingUse(Operation *op)
Returns success if the given operation is not in the main matcher body or is used by a "binding" oper...
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op, OperandRange resultTypes)
Verifies that the result types of this operation, defined within a pdl.rewrite, can be inferred.
static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes, Type resultType)
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, Type &resultType)
static ParseResult parseOperationOpAttributes(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
static bool hasBindingUse(Operation *op)
Returns true if the given operation is used by a "binding" pdl operation.
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, OperandRange attrArgs, ArrayAttr attrNames)
static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, IntegerAttr index, Type resultType)
static Type getValueType(Attribute attr)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI16IntegerAttr(int16_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops which have an unknown number of results.
This class provides return value APIs for ops that are known to have zero results.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
user_range getUsers()
Returns a range of all users.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Operation * getParentOp()
Return the parent operation this region is attached to.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
static WalkResult interrupt()
Type getRangeElementTypeOrSelf(Type type)
If the given type is a range, return its element type, otherwise return the type itself.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::TypeSwitch< T, ResultT > TypeSwitch
This is the representation of an operand reference.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.