22 #include "llvm/ADT/ScopedHashTable.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
38 const llvm::SourceMgr &sourceMgr)
39 : builder(mlirContext), odsContext(context.getODSContext()),
40 sourceMgr(sourceMgr) {
50 Location genLoc(llvm::SMRange loc) {
return genLoc(loc.Start); }
107 bool isNegated =
false);
110 template <
typename PDLOpT,
typename T>
113 bool isNegated =
false);
123 using VariableMapTy =
124 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
125 VariableMapTy variables;
131 const llvm::SourceMgr &sourceMgr;
137 builder.create<ModuleOp>(genLoc(module.getLoc()));
138 builder.setInsertionPointToStart(mlirModule->getBody());
147 Location CodeGen::genLoc(llvm::SMLoc loc) {
148 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
152 auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
153 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
155 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
156 auto *buffer = sourceMgr.getMemoryBuffer(fileID);
159 buffer->getBufferIdentifier(), lineNo, column);
165 return builder.getType<pdl::AttributeType>();
168 return builder.getType<pdl::OperationType>();
171 return builder.getType<pdl::TypeType>();
174 return builder.getType<pdl::ValueType>();
181 void CodeGen::gen(
const ast::Node *node) {
187 [&](
auto derivedNode) { this->genImpl(derivedNode); })
188 .Case([&](
const ast::Expr *expr) { genExpr(expr); });
196 VariableMapTy::ScopeTy varScope(variables);
208 builder.
create<pdl::RewriteOp>(loc, rootExpr, StringAttr(),
216 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
217 Location loc = genLoc(stmt->getLoc());
222 builder.create<pdl::EraseOp>(loc, rootExpr);
229 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
230 Location loc = genLoc(stmt->getLoc());
238 replValues.push_back(genSingleExpr(replExpr));
242 bool usesReplOperation =
243 replValues.size() == 1 &&
244 isa<pdl::OperationType>(replValues.front().getType());
245 builder.create<pdl::ReplaceOp>(
246 loc, rootExpr, usesReplOperation ? replValues[0] :
Value(),
252 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
286 pdl::PatternOp pattern = builder.
create<pdl::PatternOp>(
288 name ? std::optional<StringRef>(name->
getName())
289 : std::optional<StringRef>());
292 builder.setInsertionPointToStart(pattern.getBody());
297 auto it = variables.begin(varDecl);
298 if (it != variables.end())
305 values = genExpr(initExpr);
307 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
310 applyVarConstraints(varDecl, values);
312 variables.insert(varDecl, values);
319 auto getTypeConstraint = [&]() ->
Value {
325 [&,
this](
auto *cst) ->
Value {
326 if (
auto *typeConstraintExpr = cst->getTypeExpr())
327 return this->genSingleExpr(typeConstraintExpr);
339 Type mlirType = genType(type);
340 if (isa<ast::ValueType>(type))
341 return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
342 if (isa<ast::TypeType>(type))
343 return builder.create<pdl::TypeOp>(loc, mlirType, TypeAttr());
344 if (isa<ast::AttributeType>(type))
345 return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
347 Value operands = builder.create<pdl::OperandsOp>(
350 Value results = builder.create<pdl::TypesOp>(
353 return builder.create<pdl::OperationOp>(loc, opType.getName(), operands,
359 ast::Type eleTy = rangeTy.getElementType();
360 if (isa<ast::ValueType>(eleTy))
361 return builder.create<pdl::OperandsOp>(loc, mlirType,
362 getTypeConstraint());
363 if (isa<ast::TypeType>(eleTy))
364 return builder.create<pdl::TypesOp>(loc, mlirType, ArrayAttr());
367 llvm_unreachable(
"invalid non-initialized variable type");
375 if (
const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
376 genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
388 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
389 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
390 [&](
auto derivedNode) {
391 return llvm::getSingleElement(this->genExprImpl(derivedNode));
397 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
398 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
400 return {genSingleExpr(expr)};
406 assert(attr &&
"invalid MLIR attribute data");
407 return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
411 Location loc = genLoc(expr->getLoc());
414 arguments.push_back(genSingleExpr(arg));
417 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->
getCallableExpr());
418 assert(callableExpr &&
"unhandled CallExpr callable");
421 const ast::Decl *callable = callableExpr->getDecl();
422 if (
const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
423 return genConstraintCall(decl, loc, arguments, expr->
getIsNegated());
424 if (
const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
425 return genRewriteCall(decl, loc, arguments);
426 llvm_unreachable(
"unhandled CallExpr callable");
430 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->
getDecl()))
431 return genVar(varDecl);
432 llvm_unreachable(
"unknown decl reference expression");
436 Location loc = genLoc(expr->getLoc());
443 if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
444 Type mlirType = genType(expr->getType());
445 if (isa<pdl::ValueType>(mlirType))
446 return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
447 builder.getI32IntegerAttr(0));
448 return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
453 assert(llvm::isDigit(name[0]) &&
454 "unregistered op only allows numeric indexing");
455 unsigned resultIndex;
456 name.getAsInteger(10, resultIndex);
457 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
458 return builder.create<pdl::ResultOp>(loc, genType(expr->getType()),
459 parentExprs[0], index);
464 unsigned resultIndex = results.size();
465 if (llvm::isDigit(name[0])) {
466 name.getAsInteger(10, resultIndex);
469 return result.getName() == name;
471 resultIndex = llvm::find_if(results, findFn) - results.begin();
473 assert(resultIndex < results.size() &&
"invalid result index");
476 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
477 return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
478 parentExprs[0], index);
482 if (
auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
483 auto elementNames = tupleType.getElementNames();
487 if (llvm::isDigit(name[0]))
488 name.getAsInteger(10, index);
490 index = llvm::find(elementNames, name) - elementNames.begin();
492 assert(index < parentExprs.size() &&
"invalid result index");
493 return parentExprs[index];
496 llvm_unreachable(
"unhandled member access expression");
500 Location loc = genLoc(expr->getLoc());
501 std::optional<StringRef> opName = expr->
getName();
506 operands.push_back(genSingleExpr(operand));
512 attrNames.push_back(attr->getName().getName());
513 attrValues.push_back(genSingleExpr(attr->getValue()));
519 results.push_back(genSingleExpr(result));
521 return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
522 attrValues, results);
528 llvm::append_range(elements, genExpr(element));
530 return builder.create<pdl::RangeOp>(genLoc(expr->getLoc()),
531 genType(expr->
getType()), elements);
537 elements.push_back(genSingleExpr(element));
543 assert(type &&
"invalid MLIR type data");
544 return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
545 builder.getType<pdl::TypeType>(),
553 for (
auto it : llvm::zip(decl->
getInputs(), inputs))
554 applyVarConstraints(std::get<0>(it), std::get<1>(it));
558 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
559 decl, loc, inputs, isNegated);
562 for (
auto it : llvm::zip(decl->
getResults(), results))
563 applyVarConstraints(std::get<0>(it), std::get<1>(it));
569 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
573 template <
typename PDLOpT,
typename T>
575 CodeGen::genConstraintOrRewriteCall(
const T *decl,
Location loc,
581 ast::Type declResultType = decl->getResultType();
583 if (
ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
584 for (
ast::Type type : tupleType.getElementTypes())
585 resultTypes.push_back(genType(type));
587 resultTypes.push_back(genType(declResultType));
589 PDLOpT pdlOp = builder.create<PDLOpT>(loc, resultTypes,
590 decl->getName().getName(), inputs);
591 if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
592 cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(
true);
593 return pdlOp->getResults();
597 VariableMapTy::ScopeTy varScope(variables);
602 for (
auto it : llvm::zip(inputs, decl->getInputs()))
603 variables.insert(std::get<1>(it), {std::get<0>(it)});
611 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->
getChildren().back());
616 return genExpr(returnStmt->getResultExpr());
625 const llvm::SourceMgr &sourceMgr,
const ast::Module &module) {
626 CodeGen codegen(mlirContext, context, sourceMgr);
628 if (failed(
verify(*mlirModule)))
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc)
If the given builder is nested under a PDL PatternOp, build a rewrite operation and update the builde...
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Attributes are known-constant values of operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that...
StringRef getValue() const
Get the raw value of this expression.
This class represents a PDLL type that corresponds to an mlir::Attribute.
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Expr * getCallableExpr() const
Return the callable of this call.
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
bool getIsNegated() const
Returns whether the result of this call is to be negated.
This statement represents a compound statement, which contains a collection of other statements.
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
This class represents the main context of the PDLL AST.
This expression represents a reference to a Decl node.
Decl * getDecl() const
Get the decl referenced by this expression.
This class represents the base Decl node.
This statement represents the erase statement in PDLL.
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This statement represents a let statement in PDLL.
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
This expression represents a named member or field access of a given parent expression.
const Expr * getParentExpr() const
Get the parent expression of this access.
StringRef getMemberName() const
Return the name of the member being accessed.
This class represents a top-level AST module.
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
This Decl represents a NamedAttribute, and contains a string name and attribute value.
This class represents a base AST node.
This expression represents the structural form of an MLIR Operation.
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
This class represents a PDLL type that corresponds to an mlir::Operation.
This Decl represents a single Pattern.
const CompoundStmt * getBody() const
Return the body of this pattern.
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
This expression builds a range from a set of element values (which may be ranges themselves).
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
RangeType getType() const
Return the range result type of this expression.
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Type getElementType() const
Return the element type of this range.
This statement represents the replace statement in PDLL.
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
This statement represents a return from a "callable" like decl, e.g.
This statement represents an operation rewrite that contains a block of nested rewrite commands.
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
This class represents a base AST Statement node.
This expression builds a tuple from a set of element values.
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
This class represents a PDLL tuple type, i.e.
This expression represents a literal MLIR Type, and contains the textual assembly format of that type...
StringRef getValue() const
Get the raw value of this expression.
This class represents a PDLL type that corresponds to an mlir::Type.
This decl represents a user defined constraint.
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
This decl represents a user defined rewrite.
The class represents a Value constraint, and constrains a variable to be a Value.
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
This class represents a PDLL type that corresponds to an mlir::Value.
This Decl represents the definition of a PDLL variable.
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Type getType() const
Return the type of the decl.
This class contains all of the registered ODS operation classes.
This class provides an ODS representation of a specific operation operand or result.
This class provides an ODS representation of a specific operation.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This class represents a reference to a constraint, and contains a constraint and the location of the ...
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.