22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
/// Construct a code generator that builds MLIR through `builder` (bound to
/// `mlirContext`), caches the ODS context owned by the PDLL AST `context`, and
/// keeps `sourceMgr` for translating PDLL source locations into MLIR
/// locations. The constructor body is not visible in this excerpt.
37 CodeGen(MLIRContext *mlirContext,
const ast::Context &context,
38 const llvm::SourceMgr &sourceMgr)
39 : builder(mlirContext), odsContext(context.getODSContext()),
40 sourceMgr(sourceMgr) {
/// Generate an MLIR module for the given PDLL AST module.
45 OwningOpRef<ModuleOp> generate(
const ast::Module &module);
/// Convert a PDLL source location into an MLIR location.
49 Location genLoc(llvm::SMLoc loc);
/// Convert a PDLL source range into an MLIR location, using only the start
/// point of the range.
50 Location genLoc(llvm::SMRange loc) {
return genLoc(loc.Start); }
/// Convert a PDLL AST type into the corresponding MLIR (PDL dialect) type.
53 Type genType(ast::Type type);
/// Generate MLIR for an arbitrary AST node by dispatching on its dynamic kind
/// to the matching `genImpl` / `genExpr` overload.
56 void gen(
const ast::Node *node);
// Statement code generation.
62 void genImpl(
const ast::CompoundStmt *stmt);
63 void genImpl(
const ast::EraseStmt *stmt);
64 void genImpl(
const ast::LetStmt *stmt);
65 void genImpl(
const ast::ReplaceStmt *stmt);
66 void genImpl(
const ast::RewriteStmt *stmt);
67 void genImpl(
const ast::ReturnStmt *stmt);
// Declaration code generation.
73 void genImpl(
const ast::UserConstraintDecl *decl);
74 void genImpl(
const ast::UserRewriteDecl *decl);
75 void genImpl(
const ast::PatternDecl *decl);
/// Generate the values for a variable, reusing previously generated values
/// when the variable has already been visited in the current scope.
79 SmallVector<Value> genVar(
const ast::VariableDecl *varDecl);
/// Generate a value for a variable that has no initializer expression.
84 Value genNonInitializerVar(
const ast::VariableDecl *varDecl, Location loc);
/// Apply the constraints attached to `varDecl` to its generated `values`.
88 void applyVarConstraints(
const ast::VariableDecl *varDecl,
ValueRange values);
// Expression code generation. `genSingleExpr` is for expressions that yield
// exactly one value; `genExpr` is for expressions that may yield several.
94 Value genSingleExpr(
const ast::Expr *expr);
95 SmallVector<Value> genExpr(
const ast::Expr *expr);
96 Value genExprImpl(
const ast::AttributeExpr *expr);
97 SmallVector<Value> genExprImpl(
const ast::CallExpr *expr);
98 SmallVector<Value> genExprImpl(
const ast::DeclRefExpr *expr);
99 Value genExprImpl(
const ast::MemberAccessExpr *expr);
100 Value genExprImpl(
const ast::OperationExpr *expr);
101 Value genExprImpl(
const ast::RangeExpr *expr);
102 SmallVector<Value> genExprImpl(
const ast::TupleExpr *expr);
103 Value genExprImpl(
const ast::TypeExpr *expr);
// Calls to user-defined constraints and rewrites. `isNegated` only has an
// effect for constraint calls.
105 SmallVector<Value> genConstraintCall(
const ast::UserConstraintDecl *decl,
107 bool isNegated =
false);
108 SmallVector<Value> genRewriteCall(
const ast::UserRewriteDecl *decl,
/// Shared implementation of constraint and rewrite calls; `PDLOpT` is the PDL
/// op created when the callee is implemented natively.
110 template <
typename PDLOpT,
typename T>
111 SmallVector<Value> genConstraintOrRewriteCall(
const T *decl, Location loc,
113 bool isNegated =
false);
/// Maps each AST variable to the MLIR values generated for it. The table is
/// scoped: entries inserted while a `VariableMapTy::ScopeTy` is alive are
/// removed when that scope object is destroyed.
123 using VariableMapTy =
124 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
125 VariableMapTy variables;
/// The ODS context taken from the PDLL AST context.
128 const ods::Context &odsContext;
/// Source manager used to resolve PDLL source locations.
131 const llvm::SourceMgr &sourceMgr;
// Fragment, presumably of CodeGen::generate (its header is not visible here):
// create the top-level builtin module at the AST module's location and direct
// all subsequent insertions into the start of the module body.
136 OwningOpRef<ModuleOp> mlirModule =
137 ModuleOp::create(builder, genLoc(module.
getLoc()));
138 builder.setInsertionPointToStart(mlirModule->getBody());
/// Translate a PDLL source location into a file:line:column MLIR location.
/// The line number comes from the buffer containing `loc`. The column is the
/// 1-based offset of `loc` from the start of that line.
147Location CodeGen::genLoc(llvm::SMLoc loc) {
148 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
// NOTE(review): FindBufferContainingLoc returns 0 when `loc` is in no known
// buffer. The handling of that case is in lines not visible in this excerpt.
152 auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
153 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
155 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
156 auto *buffer = sourceMgr.getMemoryBuffer(fileID);
// The buffer identifier (typically the file name) is used as the filename.
159 buffer->getBufferIdentifier(), lineNo, column);
/// Map a PDLL AST type onto its PDL dialect counterpart. The attribute,
/// operation, type and value AST types map one-to-one onto the PDL types of
/// the same name. The range-type case continues past this excerpt.
162Type CodeGen::genType(ast::Type type) {
164 .Case([&](ast::AttributeType astType) -> Type {
165 return builder.getType<pdl::AttributeType>();
167 .Case([&](ast::OperationType astType) -> Type {
168 return builder.getType<pdl::OperationType>();
170 .Case([&](ast::TypeType astType) -> Type {
171 return builder.getType<pdl::TypeType>();
173 .Case([&](ast::ValueType astType) -> Type {
174 return builder.getType<pdl::ValueType>();
176 .Case([&](ast::RangeType astType) -> Type {
/// Generate MLIR for `node` by dispatching on its dynamic kind. Statements and
/// declarations go to the matching `genImpl` overload. Any other expression
/// goes to `genExpr`, and its results are discarded here.
181void CodeGen::gen(
const ast::Node *node) {
183 .Case<
const ast::CompoundStmt,
const ast::EraseStmt,
const ast::LetStmt,
184 const ast::ReplaceStmt,
const ast::RewriteStmt,
185 const ast::ReturnStmt,
const ast::UserConstraintDecl,
186 const ast::UserRewriteDecl,
const ast::PatternDecl>(
187 [&](
auto derivedNode) { this->genImpl(derivedNode); })
188 .Case([&](
const ast::Expr *expr) { genExpr(expr); });
/// Generate each child statement of a compound statement inside a fresh
/// variable scope. Variable bindings created while generating the children
/// are dropped when `varScope` is destroyed.
195void CodeGen::genImpl(
const ast::CompoundStmt *stmt) {
196 VariableMapTy::ScopeTy varScope(variables);
// The loop body is not visible in this excerpt. Presumably it calls
// gen(childStmt); confirm against the full source.
197 for (
const ast::Stmt *childStmt : stmt->
getChildren())
// Fragment of a helper whose header is not visible here. It is presumably the
// one that nests the builder under a pdl.rewrite op. This line creates a
// pdl.rewrite rooted at `rootExpr` with a null name attribute (StringAttr()).
208 pdl::RewriteOp::create(builder, loc, rootExpr, StringAttr(),
/// Generate a `pdl.erase` of the statement's root operation expression. The
/// builder's original insertion point is restored on exit by `insertGuard`.
214void CodeGen::genImpl(
const ast::EraseStmt *stmt) {
215 OpBuilder::InsertionGuard insertGuard(builder);
217 Location loc = genLoc(stmt->
getLoc());
// NOTE(review): `guard` nests inside `insertGuard`, which already restores the
// insertion point on exit, so the second guard looks redundant. The lines
// between them are not visible. Confirm before removing it.
220 OpBuilder::InsertionGuard guard(builder);
222 pdl::EraseOp::create(builder, loc, rootExpr);
/// A `let` statement only needs its variable materialized. `genVar` generates
/// the values and records them in the variable map.
225void CodeGen::genImpl(
const ast::LetStmt *stmt) { genVar(stmt->
getVarDecl()); }
/// Generate a `pdl.replace` of the statement's root operation. If the
/// replacement is exactly one value of PDL operation type, it is passed as the
/// replacement operation. Otherwise a null Value is passed in that position;
/// the remaining operands of the create call are not visible here.
227void CodeGen::genImpl(
const ast::ReplaceStmt *stmt) {
228 OpBuilder::InsertionGuard insertGuard(builder);
230 Location loc = genLoc(stmt->
getLoc());
// NOTE(review): nested InsertionGuard alongside `insertGuard`; looks
// redundant (see the same pattern in the EraseStmt overload).
233 OpBuilder::InsertionGuard guard(builder);
236 SmallVector<Value> replValues;
// Each replacement expression contributes exactly one value (loop header not
// visible in this excerpt).
238 replValues.push_back(genSingleExpr(replExpr));
242 bool usesReplOperation =
243 replValues.size() == 1 &&
244 isa<pdl::OperationType>(replValues.front().getType());
245 pdl::ReplaceOp::create(
246 builder, loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
/// Code generation for a `rewrite` statement. Only the insertion-point guards
/// are visible in this excerpt; the rest of the body is not shown.
250void CodeGen::genImpl(
const ast::RewriteStmt *stmt) {
251 OpBuilder::InsertionGuard insertGuard(builder);
255 OpBuilder::InsertionGuard guard(builder);
/// Code generation for a `return` statement (body not visible in this excerpt).
260void CodeGen::genImpl(
const ast::ReturnStmt *stmt) {
/// Code generation for a user constraint declaration (body not visible in this
/// excerpt).
269void CodeGen::genImpl(
const ast::UserConstraintDecl *decl) {
/// Code generation for a user rewrite declaration (body not visible in this
/// excerpt).
275void CodeGen::genImpl(
const ast::UserRewriteDecl *decl) {
/// Create a `pdl.pattern` op for the declaration and position the builder at
/// the start of the pattern body. The pattern takes the declaration's name if
/// it has one and is unnamed (std::nullopt) otherwise. `savedInsertPoint`
/// restores the builder's original insertion point on exit.
281void CodeGen::genImpl(
const ast::PatternDecl *decl) {
282 const ast::Name *name = decl->
getName();
286 pdl::PatternOp pattern = pdl::PatternOp::create(
288 name ? std::optional<StringRef>(name->
getName())
289 : std::optional<StringRef>());
291 OpBuilder::InsertionGuard savedInsertPoint(builder);
292 builder.setInsertionPointToStart(pattern.getBody());
/// Generate the values for `varDecl`, memoized through the scoped variable
/// map. New values come from the initializer expression if there is one, and
/// from `genNonInitializerVar` otherwise. The variable's constraints are then
/// applied and the values are recorded in the current scope.
296SmallVector<Value> CodeGen::genVar(
const ast::VariableDecl *varDecl) {
// Already generated in an enclosing scope: the reuse path's return statement
// is not visible in this excerpt.
297 auto it = variables.begin(varDecl);
298 if (it != variables.end())
303 SmallVector<Value> values;
304 if (
const ast::Expr *initExpr = varDecl->
getInitExpr())
305 values = genExpr(initExpr);
// Otherwise branch (the `else` line is not visible in this excerpt).
307 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->
getLoc())));
310 applyVarConstraints(varDecl, values);
312 variables.insert(varDecl, values);
/// Materialize a variable that has no initializer, based on its declared type:
///  - Value          -> pdl.operand (optionally type-constrained)
///  - Type           -> pdl.type with no fixed type attribute
///  - Attribute      -> pdl.attribute (optionally type-constrained)
///  - Operation      -> pdl.operation with range operands and result types
///  - Range<Value>   -> pdl.operands (optionally type-constrained)
///  - Range<Type>    -> pdl.types
/// Any other type is treated as unreachable.
316Value CodeGen::genNonInitializerVar(
const ast::VariableDecl *varDecl,
// Returns the value of the type-constraint expression attached to one of the
// variable's attribute/value/value-range constraints. What it returns when no
// such constraint is present is not visible in this excerpt.
319 auto getTypeConstraint = [&]() -> Value {
320 for (
const ast::ConstraintRef &constraint : varDecl->
getConstraints()) {
323 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
324 ast::ValueRangeConstraintDecl>(
325 [&,
this](
auto *cst) -> Value {
326 if (
auto *typeConstraintExpr = cst->getTypeExpr())
327 return this->genSingleExpr(typeConstraintExpr);
338 ast::Type type = varDecl->
getType();
339 Type mlirType = genType(type);
340 if (isa<ast::ValueType>(type))
341 return pdl::OperandOp::create(builder, loc, mlirType, getTypeConstraint());
342 if (isa<ast::TypeType>(type))
343 return pdl::TypeOp::create(builder, loc, mlirType, TypeAttr());
344 if (isa<ast::AttributeType>(type))
345 return pdl::AttributeOp::create(builder, loc, getTypeConstraint());
346 if (ast::OperationType opType = dyn_cast<ast::OperationType>(type)) {
// Operations get range-typed operand and result-type placeholders.
347 Value operands = pdl::OperandsOp::create(
348 builder, loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
350 Value results = pdl::TypesOp::create(
351 builder, loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
353 return pdl::OperationOp::create(builder, loc, opType.getName(), operands,
358 if (ast::RangeType rangeTy = dyn_cast<ast::RangeType>(type)) {
359 ast::Type eleTy = rangeTy.getElementType();
360 if (isa<ast::ValueType>(eleTy))
361 return pdl::OperandsOp::create(builder, loc, mlirType,
362 getTypeConstraint());
363 if (isa<ast::TypeType>(eleTy))
364 return pdl::TypesOp::create(builder, loc, mlirType,
368 llvm_unreachable(
"invalid non-initialized variable type");
/// Apply the constraints attached to `varDecl` to `values`. In the visible
/// code, only user-defined constraints are emitted as calls, at the location
/// where the constraint is referenced. The loop over the variable's constraint
/// refs (`ref`) is not visible in this excerpt.
371void CodeGen::applyVarConstraints(
const ast::VariableDecl *varDecl,
376 if (
const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
377 genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
/// Generate an expression that must produce exactly one value. Attribute,
/// member-access, operation, range and type expressions yield one value
/// directly. Call, decl-ref and tuple expressions yield a list, and
/// `llvm::getSingleElement` asserts that the list has exactly one element.
384Value CodeGen::genSingleExpr(
const ast::Expr *expr) {
386 .Case<
const ast::AttributeExpr,
const ast::MemberAccessExpr,
387 const ast::OperationExpr,
const ast::RangeExpr,
388 const ast::TypeExpr>(
389 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
390 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
391 [&](
auto derivedNode) {
392 return llvm::getSingleElement(this->genExprImpl(derivedNode));
/// Generate an expression that may produce any number of values. Call,
/// decl-ref and tuple expressions return their value lists directly. Every
/// other expression is generated through `genSingleExpr` and wrapped in a
/// one-element list.
396SmallVector<Value> CodeGen::genExpr(
const ast::Expr *expr) {
398 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
399 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
400 .Default([&](
const ast::Expr *expr) -> SmallVector<Value> {
401 return {genSingleExpr(expr)};
/// Generate a `pdl.attribute` holding the constant attribute written in the
/// expression. How `attr` is obtained (presumably by parsing the expression's
/// text) is not visible in this excerpt; it must be non-null.
405Value CodeGen::genExprImpl(
const ast::AttributeExpr *expr) {
407 assert(attr &&
"invalid MLIR attribute data");
408 return pdl::AttributeOp::create(builder, genLoc(expr->
getLoc()), attr);
/// Generate a call to a user-defined constraint or rewrite. Each argument must
/// produce a single value. The callee must be a direct reference to a
/// constraint or rewrite declaration; constraint calls honour the call's
/// negation flag.
411SmallVector<Value> CodeGen::genExprImpl(
const ast::CallExpr *expr) {
412 Location loc = genLoc(expr->
getLoc());
413 SmallVector<Value> arguments;
// Executed once per call argument (the loop header is not visible here).
415 arguments.push_back(genSingleExpr(arg));
418 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->
getCallableExpr());
419 assert(callableExpr &&
"unhandled CallExpr callable");
422 const ast::Decl *callable = callableExpr->getDecl();
423 if (
const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
424 return genConstraintCall(decl, loc, arguments, expr->
getIsNegated());
425 if (
const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
426 return genRewriteCall(decl, loc, arguments);
427 llvm_unreachable(
"unhandled CallExpr callable");
/// A reference to a variable yields that variable's (possibly cached) values.
/// References to any other kind of declaration are treated as unreachable.
430SmallVector<Value> CodeGen::genExprImpl(
const ast::DeclRefExpr *expr) {
431 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->
getDecl()))
432 return genVar(varDecl);
433 llvm_unreachable(
"unknown decl reference expression");
/// Generate a member access on an operation or a tuple.
///  - Operation, all-results access: pdl.result #0 when the access has PDL
///    value type, otherwise pdl.results.
///  - Operation without ODS info: numeric index only, lowered to pdl.result.
///  - Operation with ODS info: numeric index or result name, lowered to
///    pdl.results.
///  - Tuple: numeric index or element name selects one of the parent values.
/// The declarations of `parentType` and `name` (the accessed member name) are
/// not visible in this excerpt.
436Value CodeGen::genExprImpl(
const ast::MemberAccessExpr *expr) {
437 Location loc = genLoc(expr->
getLoc());
439 SmallVector<Value> parentExprs = genExpr(expr->
getParentExpr());
443 if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
444 if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
445 Type mlirType = genType(expr->
getType());
446 if (isa<pdl::ValueType>(mlirType))
447 return pdl::ResultOp::create(builder, loc, mlirType, parentExprs[0],
448 builder.getI32IntegerAttr(0));
449 return pdl::ResultsOp::create(builder, loc, mlirType, parentExprs[0]);
452 const ods::Operation *odsOp = opType.getODSOperation();
// Per the assert message, this branch handles unregistered operations
// (the guarding condition is not visible in this excerpt).
454 assert(llvm::isDigit(name[0]) &&
455 "unregistered op only allows numeric indexing");
// NOTE(review): `resultIndex` is uninitialized, and the return value of
// getAsInteger (true on failure, leaving the output untouched) is ignored. A
// name such as "1x" would therefore leave `resultIndex` indeterminate.
// Presumably the parser rules this out; confirm.
456 unsigned resultIndex;
457 name.getAsInteger(10, resultIndex);
458 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
459 return pdl::ResultOp::create(builder, loc, genType(expr->
getType()),
460 parentExprs[0], index);
// ODS-known operation: resolve the member to a result index. The index starts
// out-of-range (results.size()) so that an unresolved member trips the assert.
464 ArrayRef<ods::OperandOrResult> results = odsOp->
getResults();
465 unsigned resultIndex = results.size();
466 if (llvm::isDigit(name[0])) {
467 name.getAsInteger(10, resultIndex);
469 auto findFn = [&](
const ods::OperandOrResult &
result) {
470 return result.getName() == name;
472 resultIndex = llvm::find_if(results, findFn) - results.begin();
474 assert(resultIndex < results.size() &&
"invalid result index");
477 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
478 return pdl::ResultsOp::create(builder, loc, genType(expr->
getType()),
479 parentExprs[0], index);
// Tuples: the parent already generated one value per element, so the access
// selects one of them by numeric index or by element name.
483 if (
auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
484 auto elementNames = tupleType.getElementNames();
488 if (llvm::isDigit(name[0]))
489 name.getAsInteger(10, index);
491 index = llvm::find(elementNames, name) - elementNames.begin();
493 assert(index < parentExprs.size() &&
"invalid result index");
494 return parentExprs[index];
497 llvm_unreachable(
"unhandled member access expression");
/// Generate a `pdl.operation` from an operation expression. The operation name
/// is optional. Each operand, named attribute value and result type must
/// produce exactly one value.
500Value CodeGen::genExprImpl(
const ast::OperationExpr *expr) {
501 Location loc = genLoc(expr->
getLoc());
502 std::optional<StringRef> opName = expr->
getName();
505 SmallVector<Value> operands;
506 for (
const ast::Expr *operand : expr->
getOperands())
507 operands.push_back(genSingleExpr(operand));
// Attribute names and values are kept in parallel lists, in declaration order.
510 SmallVector<StringRef> attrNames;
511 SmallVector<Value> attrValues;
512 for (
const ast::NamedAttributeDecl *attr : expr->
getAttributes()) {
513 attrNames.push_back(attr->getName().getName());
514 attrValues.push_back(genSingleExpr(attr->getValue()));
518 SmallVector<Value> results;
// Executed once per result-type expression (the loop header is not visible
// here).
520 results.push_back(genSingleExpr(
result));
522 return pdl::OperationOp::create(builder, loc, opName, operands, attrNames,
523 attrValues, results);
/// Generate a `pdl.range` from a range expression. Elements are generated with
/// `genExpr` and appended, so an element that produces several values
/// contributes all of them (the range is flattened).
526Value CodeGen::genExprImpl(
const ast::RangeExpr *expr) {
527 SmallVector<Value> elements;
528 for (
const ast::Expr *element : expr->
getElements())
529 llvm::append_range(elements, genExpr(element));
531 return pdl::RangeOp::create(builder, genLoc(expr->
getLoc()),
532 genType(expr->
getType()), elements);
/// A tuple generates one value per element, and each element must produce
/// exactly one value. The return statement is not visible in this excerpt.
535SmallVector<Value> CodeGen::genExprImpl(
const ast::TupleExpr *expr) {
536 SmallVector<Value> elements;
537 for (
const ast::Expr *element : expr->
getElements())
538 elements.push_back(genSingleExpr(element));
/// Generate a `pdl.type` fixed to the concrete MLIR type written in the
/// expression. How `type` is obtained (presumably by parsing the expression's
/// text) is not visible in this excerpt; it must be non-null.
542Value CodeGen::genExprImpl(
const ast::TypeExpr *expr) {
544 assert(type &&
"invalid MLIR type data");
545 return pdl::TypeOp::create(builder, genLoc(expr->
getLoc()),
546 builder.getType<pdl::TypeType>(),
547 TypeAttr::get(type));
/// Generate a call to a user-defined constraint. Constraints declared on the
/// callee's inputs are applied to the call arguments before the call, and
/// constraints on its results are applied to the call results afterwards.
/// `llvm::zip` stops at the shorter of each pair of ranges.
551CodeGen::genConstraintCall(
const ast::UserConstraintDecl *decl, Location loc,
554 for (
auto it : llvm::zip(decl->
getInputs(), inputs))
555 applyVarConstraints(std::get<0>(it), std::get<1>(it));
558 SmallVector<Value> results =
559 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
560 decl, loc, inputs, isNegated);
563 for (
auto it : llvm::zip(decl->
getResults(), results))
564 applyVarConstraints(std::get<0>(it), std::get<1>(it));
/// Generate a call to a user-defined rewrite, using `pdl.apply_native_rewrite`
/// when the rewrite is implemented natively.
568SmallVector<Value> CodeGen::genRewriteCall(
const ast::UserRewriteDecl *decl,
570 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
/// Shared lowering for constraint and rewrite calls.
///  - Native callee: emit a `PDLOpT` named after the declaration. A tuple
///    result type expands to one result type per element; otherwise there is
///    a single result type. Negation is applied only for constraint ops.
///  - Callee with a PDLL body: inline the body in a fresh variable scope, with
///    the declaration's inputs bound to `inputs`. The result is the value list
///    of the body's trailing `return` statement, or empty in the cases whose
///    conditions are not visible in this excerpt.
574template <
typename PDLOpT,
typename T>
576CodeGen::genConstraintOrRewriteCall(
const T *decl, Location loc,
578 const ast::CompoundStmt *cstBody = decl->getBody();
582 ast::Type declResultType = decl->getResultType();
583 SmallVector<Type> resultTypes;
584 if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
585 for (ast::Type type : tupleType.getElementTypes())
586 resultTypes.push_back(genType(type));
588 resultTypes.push_back(genType(declResultType));
590 PDLOpT pdlOp = PDLOpT::create(builder, loc, resultTypes,
591 decl->getName().getName(), inputs);
// NOTE(review): this is a plain `if`, so the cast below is still instantiated
// when PDLOpT is the rewrite op (the branch is just never taken at runtime).
// `if constexpr` would make the compile-time intent explicit.
592 if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
593 cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(
true);
594 return pdlOp->getResults();
// Inlined body: bindings made below are dropped when `varScope` ends.
598 VariableMapTy::ScopeTy varScope(variables);
603 for (
auto it : llvm::zip(inputs, decl->getInputs()))
604 variables.insert(std::get<1>(it), {std::get<0>(it)});
611 return SmallVector<Value>();
612 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->
getChildren().back());
614 return SmallVector<Value>();
617 return genExpr(returnStmt->getResultExpr());
// Fragment of the public entry point codegenPDLLToMLIR (its leading parameters
// are not visible here). It builds a CodeGen for the given contexts, generates
// the module, and verifies the result. The failure handling after the verify
// check is not visible in this excerpt.
626 const llvm::SourceMgr &sourceMgr,
const ast::Module &module) {
627 CodeGen codegen(mlirContext, context, sourceMgr);
629 if (failed(
verify(*mlirModule)))
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc)
If the given builder is nested under a PDL PatternOp, build a rewrite operation and update the builde...
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
StringRef getValue() const
Get the raw value of this expression.
Expr * getCallableExpr() const
Return the callable of this call.
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
bool getIsNegated() const
Returns whether the result of this call is to be negated.
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
This class represents the main context of the PDLL AST.
Decl * getDecl() const
Get the decl referenced by this expression.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Type getType() const
Return the type of this expression.
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
const Expr * getParentExpr() const
Get the parent expression of this access.
StringRef getMemberName() const
Return the name of the member being accessed.
This class represents a top-level AST module.
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
SMRange getLoc() const
Return the location of this node.
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
const CompoundStmt * getBody() const
Return the body of this pattern.
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
RangeType getType() const
Return the range result type of this expression.
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Type getElementType() const
Return the element type of this range.
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
StringRef getValue() const
Get the raw value of this expression.
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Type getType() const
Return the type of the decl.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
StringRef getName() const
Return the raw string name.