22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
// NOTE(review): this is a fragment of the CodeGen class. The embedded source
// line numbers jump (e.g. 40 -> 45), so some lines are not shown here.
//
// Construct a code generator. The OpBuilder is created over `mlirContext`.
// The ODS context is taken from the PDLL AST context. The source manager is
// kept by reference so that SMLoc values can later be turned into MLIR
// locations.
37 CodeGen(MLIRContext *mlirContext,
const ast::Context &context,
38 const llvm::SourceMgr &sourceMgr)
39 : builder(mlirContext), odsContext(context.getODSContext()),
40 sourceMgr(sourceMgr) {
// Create an MLIR module and generate into it the code for the given PDLL
// module.
45 OwningOpRef<ModuleOp> generate(
const ast::Module &module);
// Convert a source location into an MLIR location.
49 Location genLoc(llvm::SMLoc loc);
// Convenience overload that uses the start of the given source range.
50 Location genLoc(llvm::SMRange loc) {
return genLoc(loc.Start); }
// Convert a PDLL AST type into the corresponding PDL dialect type.
53 Type genType(ast::Type type);
// Generate code for an arbitrary AST node, dispatching on its concrete kind.
56 void gen(
const ast::Node *node);
// Statement code generation: one overload per statement kind.
62 void genImpl(
const ast::CompoundStmt *stmt);
63 void genImpl(
const ast::EraseStmt *stmt);
64 void genImpl(
const ast::LetStmt *stmt);
65 void genImpl(
const ast::ReplaceStmt *stmt);
66 void genImpl(
const ast::RewriteStmt *stmt);
67 void genImpl(
const ast::ReturnStmt *stmt);
// Declaration code generation: one overload per declaration kind.
73 void genImpl(
const ast::UserConstraintDecl *decl);
74 void genImpl(
const ast::UserRewriteDecl *decl);
75 void genImpl(
const ast::PatternDecl *decl);
// Return the values bound to a variable, generating them on first use.
79 SmallVector<Value> genVar(
const ast::VariableDecl *varDecl);
// Generate a value for a variable that has no initializer expression.
84 Value genNonInitializerVar(
const ast::VariableDecl *varDecl, Location loc);
// Apply the constraints referenced by a variable to its generated values.
88 void applyVarConstraints(
const ast::VariableDecl *varDecl,
ValueRange values);
// Generate an expression that must produce exactly one value.
94 Value genSingleExpr(
const ast::Expr *expr);
// Generate an expression that may produce any number of values.
95 SmallVector<Value> genExpr(
const ast::Expr *expr);
// Per-expression-kind generators. Call, decl-ref and tuple expressions return
// a list of values; the other kinds return a single value.
96 Value genExprImpl(
const ast::AttributeExpr *expr);
97 SmallVector<Value> genExprImpl(
const ast::CallExpr *expr);
98 SmallVector<Value> genExprImpl(
const ast::DeclRefExpr *expr);
99 Value genExprImpl(
const ast::MemberAccessExpr *expr);
100 Value genExprImpl(
const ast::OperationExpr *expr);
101 Value genExprImpl(
const ast::RangeExpr *expr);
102 SmallVector<Value> genExprImpl(
const ast::TupleExpr *expr);
103 Value genExprImpl(
const ast::TypeExpr *expr);
// Generate a call to a user constraint. `isNegated` requests a negated
// constraint application.
// NOTE(review): source lines 106 and 109 (remaining parameters) are not part
// of this excerpt.
105 SmallVector<Value> genConstraintCall(
const ast::UserConstraintDecl *decl,
107 bool isNegated =
false);
// Generate a call to a user rewrite.
108 SmallVector<Value> genRewriteCall(
const ast::UserRewriteDecl *decl,
// Shared implementation for constraint and rewrite calls. PDLOpT is the PDL op
// emitted for declarations that are lowered to a native call.
110 template <
typename PDLOpT,
typename T>
111 SmallVector<Value> genConstraintOrRewriteCall(
const T *decl, Location loc,
113 bool isNegated =
false);
// Scoped mapping from variable declarations to their generated values. Nested
// scopes are opened with VariableMapTy::ScopeTy.
123 using VariableMapTy =
124 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
125 VariableMapTy variables;
// ODS context obtained from the PDLL AST context.
128 const ods::Context &odsContext;
// Source manager used to translate SMLoc values into MLIR locations.
131 const llvm::SourceMgr &sourceMgr;
// NOTE(review): fragment of CodeGen::generate; the function header and the
// rest of the body are not part of this excerpt.
// Create the top-level module at the PDLL module's location, then insert the
// generated operations at the start of its body.
136 OwningOpRef<ModuleOp> mlirModule =
137 ModuleOp::create(builder, genLoc(module.
getLoc()));
138 builder.setInsertionPointToStart(mlirModule->getBody());
// Translate an SMLoc into an MLIR location:
//   1. Find the buffer that contains `loc`.
//   2. Compute the line and column.
//   3. Use the buffer identifier as the file name.
// NOTE(review): source lines 149 and 152-153 are not shown, so the location
// construction call is only partially visible here.
// NOTE(review): `fileID` is not passed to getLineAndColumn. If that API
// searches for the containing buffer again when no ID is given, passing
// `fileID` would avoid a second lookup. Verify against the SourceMgr API.
147Location CodeGen::genLoc(llvm::SMLoc loc) {
148 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
150 auto [lineNo, column] = sourceMgr.getLineAndColumn(loc);
151 auto *buffer = sourceMgr.getMemoryBuffer(fileID);
154 buffer->getBufferIdentifier(), lineNo, column);
// Map a PDLL AST type to its PDL dialect type. Attribute, operation, type and
// value types map one-to-one onto the PDL types of the same name.
// NOTE(review): the range-type conversion is cut off after source line 171
// and is not visible in this excerpt.
157Type CodeGen::genType(ast::Type type) {
159 .Case([&](ast::AttributeType astType) -> Type {
160 return builder.getType<pdl::AttributeType>();
162 .Case([&](ast::OperationType astType) -> Type {
163 return builder.getType<pdl::OperationType>();
165 .Case([&](ast::TypeType astType) -> Type {
166 return builder.getType<pdl::TypeType>();
168 .Case([&](ast::ValueType astType) -> Type {
169 return builder.getType<pdl::ValueType>();
171 .Case([&](ast::RangeType astType) -> Type {
// Generate code for an AST node by dispatching on its concrete class.
// Statements and declarations go to the matching genImpl overload. Any other
// expression is generated with genExpr, and its resulting values are
// discarded.
176void CodeGen::gen(
const ast::Node *node) {
178 .Case<
const ast::CompoundStmt,
const ast::EraseStmt,
const ast::LetStmt,
179 const ast::ReplaceStmt,
const ast::RewriteStmt,
180 const ast::ReturnStmt,
const ast::UserConstraintDecl,
181 const ast::UserRewriteDecl,
const ast::PatternDecl>(
182 [&](
auto derivedNode) { this->genImpl(derivedNode); })
183 .Case([&](
const ast::Expr *expr) { genExpr(expr); });
// Generate each child statement of a compound statement inside a new variable
// scope, so bindings created while generating the children are scoped to this
// statement.
190void CodeGen::genImpl(
const ast::CompoundStmt *stmt) {
191 VariableMapTy::ScopeTy varScope(variables);
192 for (
const ast::Stmt *childStmt : stmt->
getChildren())
// NOTE(review): the next line (source line 203) belongs to a different
// function whose header is not in this excerpt. It is presumably
// checkAndNestUnderRewriteOp, which builds the pdl.rewrite op.
203 pdl::RewriteOp::create(builder, loc, rootExpr, StringAttr(),
// Generate a pdl.erase of the root operation expression of an erase statement.
// NOTE(review): source lines 211, 213-214 and 216 are not shown. They
// presumably generate `rootExpr` and nest the builder under a rewrite op.
209void CodeGen::genImpl(
const ast::EraseStmt *stmt) {
210 OpBuilder::InsertionGuard insertGuard(builder);
212 Location loc = genLoc(stmt->
getLoc());
// NOTE(review): this second guard looks redundant with `insertGuard` above,
// which already restores the builder's insertion point on exit.
215 OpBuilder::InsertionGuard guard(builder);
217 pdl::EraseOp::create(builder, loc, rootExpr);
// A let statement only needs its variable declaration generated (and bound).
220void CodeGen::genImpl(
const ast::LetStmt *stmt) { genVar(stmt->
getVarDecl()); }
// Generate a pdl.replace for a replace statement. Each replacement expression
// is generated as a single value. If there is exactly one replacement and it
// has PDL operation type, it is passed as the replacement operation;
// otherwise a null Value is passed in that position.
// NOTE(review): several source lines (224, 226-227, 229-230, 232, 234-236,
// 242+) are not part of this excerpt.
222void CodeGen::genImpl(
const ast::ReplaceStmt *stmt) {
223 OpBuilder::InsertionGuard insertGuard(builder);
225 Location loc = genLoc(stmt->
getLoc());
// NOTE(review): appears redundant with `insertGuard` above.
228 OpBuilder::InsertionGuard guard(builder);
231 SmallVector<Value> replValues;
233 replValues.push_back(genSingleExpr(replExpr));
237 bool usesReplOperation =
238 replValues.size() == 1 &&
239 isa<pdl::OperationType>(replValues.front().getType());
240 pdl::ReplaceOp::create(
241 builder, loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
// Generate code for a rewrite statement.
// NOTE(review): only the two insertion guards of this function are visible in
// this excerpt. As in the erase/replace handlers, the second guard appears
// redundant with the first.
245void CodeGen::genImpl(
const ast::RewriteStmt *stmt) {
246 OpBuilder::InsertionGuard insertGuard(builder);
250 OpBuilder::InsertionGuard guard(builder);
// NOTE(review): only the signatures of the following genImpl overloads
// (return statement, user constraint decl, user rewrite decl) appear in this
// excerpt; their bodies are not shown.
255void CodeGen::genImpl(
const ast::ReturnStmt *stmt) {
264void CodeGen::genImpl(
const ast::UserConstraintDecl *decl) {
270void CodeGen::genImpl(
const ast::UserRewriteDecl *decl) {
// Generate a pdl.pattern for a pattern declaration. The pattern gets the
// declaration's name if it has one, otherwise no name. The builder is then
// positioned at the start of the pattern body; the insertion guard restores
// the previous insertion point on exit.
// NOTE(review): the remaining create() arguments (source line 282) and the
// body generation after line 287 are not part of this excerpt.
276void CodeGen::genImpl(
const ast::PatternDecl *decl) {
277 const ast::Name *name = decl->
getName();
281 pdl::PatternOp pattern = pdl::PatternOp::create(
283 name ? std::optional<StringRef>(name->
getName())
284 : std::optional<StringRef>());
286 OpBuilder::InsertionGuard savedInsertPoint(builder);
287 builder.setInsertionPointToStart(pattern.getBody());
// Return the values for a variable, generating them on first use.
// The function first looks the variable up in `variables`.
// NOTE(review): the statement guarded by that lookup (source line 294) is not
// shown; presumably it returns the cached values.
// Otherwise it:
//   - generates the initializer expression, or a non-initializer value when
//     there is no initializer;
//   - applies the variable's constraints;
//   - records the values in `variables` so later references reuse them.
291SmallVector<Value> CodeGen::genVar(
const ast::VariableDecl *varDecl) {
292 auto it = variables.begin(varDecl);
293 if (it != variables.end())
298 SmallVector<Value> values;
299 if (
const ast::Expr *initExpr = varDecl->
getInitExpr())
300 values = genExpr(initExpr);
302 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->
getLoc())));
305 applyVarConstraints(varDecl, values);
307 variables.insert(varDecl, values);
// Generate a PDL matcher value for a variable that has no initializer,
// choosing the PDL op from the variable's AST type:
//   value            -> pdl.operand (with an optional type constraint)
//   type             -> pdl.type (no fixed type)
//   attribute        -> pdl.attribute (with an optional type constraint)
//   operation        -> pdl.operation over fresh operand and result-type ranges
//   range of values  -> pdl.operands
//   range of types   -> pdl.types
// Any other type is unreachable.
// NOTE(review): several source lines inside this function are not part of
// this excerpt.
311Value CodeGen::genNonInitializerVar(
const ast::VariableDecl *varDecl,
// Find the first attribute/value/value-range constraint on the variable that
// has a type expression, and generate that expression as the type constraint.
314 auto getTypeConstraint = [&]() -> Value {
315 for (
const ast::ConstraintRef &constraint : varDecl->
getConstraints()) {
318 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
319 ast::ValueRangeConstraintDecl>(
320 [&,
this](
auto *cst) -> Value {
321 if (
auto *typeConstraintExpr = cst->getTypeExpr())
322 return this->genSingleExpr(typeConstraintExpr);
333 ast::Type type = varDecl->
getType();
334 Type mlirType = genType(type);
335 if (isa<ast::ValueType>(type))
336 return pdl::OperandOp::create(builder, loc, mlirType, getTypeConstraint());
337 if (isa<ast::TypeType>(type))
338 return pdl::TypeOp::create(builder, loc, mlirType, TypeAttr());
339 if (isa<ast::AttributeType>(type))
340 return pdl::AttributeOp::create(builder, loc, getTypeConstraint());
341 if (ast::OperationType opType = dyn_cast<ast::OperationType>(type)) {
// An operation variable matches an operation with an unconstrained operand
// range and an unconstrained result-type range.
342 Value operands = pdl::OperandsOp::create(
343 builder, loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
345 Value results = pdl::TypesOp::create(
346 builder, loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
348 return pdl::OperationOp::create(builder, loc, opType.getName(), operands,
353 if (ast::RangeType rangeTy = dyn_cast<ast::RangeType>(type)) {
354 ast::Type eleTy = rangeTy.getElementType();
355 if (isa<ast::ValueType>(eleTy))
356 return pdl::OperandsOp::create(builder, loc, mlirType,
357 getTypeConstraint());
358 if (isa<ast::TypeType>(eleTy))
359 return pdl::TypesOp::create(builder, loc, mlirType,
363 llvm_unreachable(
"invalid non-initialized variable type");
// Apply the constraints referenced by `varDecl` to `values`. For each
// user-defined constraint, a constraint call is generated at the location
// where the constraint is referenced.
// NOTE(review): the loop header over the variable's constraint references
// (source lines 367-370) is not part of this excerpt.
366void CodeGen::applyVarConstraints(
const ast::VariableDecl *varDecl,
371 if (
const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
372 genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
// Generate an expression that must produce exactly one value.
// Single-value expression kinds are generated directly. Call, decl-ref and
// tuple expressions return a list; llvm::getSingleElement is used to extract
// their only element.
379Value CodeGen::genSingleExpr(
const ast::Expr *expr) {
381 .Case<
const ast::AttributeExpr,
const ast::MemberAccessExpr,
382 const ast::OperationExpr,
const ast::RangeExpr,
383 const ast::TypeExpr>(
384 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
385 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
386 [&](
auto derivedNode) {
387 return llvm::getSingleElement(this->genExprImpl(derivedNode));
// Generate an expression that may produce multiple values. Call, decl-ref and
// tuple expressions return their value lists directly. Every other expression
// is generated through genSingleExpr and wrapped in a one-element list.
391SmallVector<Value> CodeGen::genExpr(
const ast::Expr *expr) {
393 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
394 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
395 .Default([&](
const ast::Expr *expr) -> SmallVector<Value> {
396 return {genSingleExpr(expr)};
// Generate a pdl.attribute holding the constant attribute written in an
// attribute expression.
// NOTE(review): the line that produces `attr` (source line 401) is not shown;
// it presumably parses the expression's text with parseAttribute.
400Value CodeGen::genExprImpl(
const ast::AttributeExpr *expr) {
402 assert(attr &&
"invalid MLIR attribute data");
403 return pdl::AttributeOp::create(builder, genLoc(expr->
getLoc()), attr);
// Generate a call expression. Each argument is generated as a single value.
// The callable must be a reference to either:
//   - a user constraint, which is called with the expression's negation flag;
//   - a user rewrite.
// Any other callable is unreachable.
406SmallVector<Value> CodeGen::genExprImpl(
const ast::CallExpr *expr) {
407 Location loc = genLoc(expr->
getLoc());
408 SmallVector<Value> arguments;
410 arguments.push_back(genSingleExpr(arg));
413 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->
getCallableExpr());
414 assert(callableExpr &&
"unhandled CallExpr callable");
417 const ast::Decl *callable = callableExpr->getDecl();
418 if (
const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
419 return genConstraintCall(decl, loc, arguments, expr->
getIsNegated());
420 if (
const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
421 return genRewriteCall(decl, loc, arguments);
422 llvm_unreachable(
"unhandled CallExpr callable");
// A reference to a variable yields that variable's values, which are generated
// on first use. References to other kinds of declarations are unreachable.
425SmallVector<Value> CodeGen::genExprImpl(
const ast::DeclRefExpr *expr) {
426 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->
getDecl()))
427 return genVar(varDecl);
428 llvm_unreachable(
"unknown decl reference expression");
// Generate a member access expression.
// Operation parent:
//   - "all results" access: pdl.result #0 if the access has single-value
//     type, otherwise pdl.results;
//   - operation with no ODS definition: the member name must be a numeric
//     result index, giving a pdl.result;
//   - ODS operation: the member is a numeric index or a named result, giving
//     a pdl.results at that index.
// Tuple parent: the member is a numeric index or element name, and the
// matching parent value is returned directly.
// NOTE(review): several source lines (433, 435-437, 445-448, 456-458, 463,
// 466, 468, 470-471, 475-477, 480-482, 485, 487) are not part of this
// excerpt, including where `name` and `parentType` are defined.
431Value CodeGen::genExprImpl(
const ast::MemberAccessExpr *expr) {
432 Location loc = genLoc(expr->
getLoc());
434 SmallVector<Value> parentExprs = genExpr(expr->
getParentExpr());
438 if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
439 if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
440 Type mlirType = genType(expr->
getType());
441 if (isa<pdl::ValueType>(mlirType))
442 return pdl::ResultOp::create(builder, loc, mlirType, parentExprs[0],
443 builder.getI32IntegerAttr(0));
444 return pdl::ResultsOp::create(builder, loc, mlirType, parentExprs[0]);
447 const ods::Operation *odsOp = opType.getODSOperation();
// NOTE(review): `name[0]` assumes a non-empty member name. The
// getAsInteger() return value (failure flag) is ignored here and below; both
// presumably rely on the parser having validated the member name.
449 assert(llvm::isDigit(name[0]) &&
450 "unregistered op only allows numeric indexing");
451 unsigned resultIndex;
452 name.getAsInteger(10, resultIndex);
453 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
454 return pdl::ResultOp::create(builder, loc, genType(expr->
getType()),
455 parentExprs[0], index);
459 ArrayRef<ods::OperandOrResult> results = odsOp->
getResults();
460 unsigned resultIndex = results.size();
461 if (llvm::isDigit(name[0])) {
462 name.getAsInteger(10, resultIndex);
// Look the result up by name; a miss yields results.size(), which the
// assert below rejects.
464 auto findFn = [&](
const ods::OperandOrResult &
result) {
465 return result.getName() == name;
467 resultIndex = llvm::find_if(results, findFn) - results.begin();
469 assert(resultIndex < results.size() &&
"invalid result index");
472 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
473 return pdl::ResultsOp::create(builder, loc, genType(expr->
getType()),
474 parentExprs[0], index);
478 if (
auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
479 auto elementNames = tupleType.getElementNames();
483 if (llvm::isDigit(name[0]))
484 name.getAsInteger(10, index);
486 index = llvm::find(elementNames, name) - elementNames.begin();
488 assert(index < parentExprs.size() &&
"invalid result index");
489 return parentExprs[index];
492 llvm_unreachable(
"unhandled member access expression");
// Generate a pdl.operation from an operation expression. Operands, attribute
// values and result types are each generated as single values; attribute
// names are taken from the named attribute declarations.
// NOTE(review): source lines 498-499, 503-504, 510-512, 514 and 516 are not
// part of this excerpt (including the result-type loop header).
495Value CodeGen::genExprImpl(
const ast::OperationExpr *expr) {
496 Location loc = genLoc(expr->
getLoc());
497 std::optional<StringRef> opName = expr->
getName();
500 SmallVector<Value> operands;
501 for (
const ast::Expr *operand : expr->
getOperands())
502 operands.push_back(genSingleExpr(operand));
505 SmallVector<StringRef> attrNames;
506 SmallVector<Value> attrValues;
507 for (
const ast::NamedAttributeDecl *attr : expr->
getAttributes()) {
508 attrNames.push_back(attr->getName().getName());
509 attrValues.push_back(genSingleExpr(attr->getValue()));
513 SmallVector<Value> results;
515 results.push_back(genSingleExpr(
result));
517 return pdl::OperationOp::create(builder, loc, opName, operands, attrNames,
518 attrValues, results);
// Generate a pdl.range from a range expression. Each element may contribute
// multiple values, and all of them are flattened into the range.
521Value CodeGen::genExprImpl(
const ast::RangeExpr *expr) {
522 SmallVector<Value> elements;
523 for (
const ast::Expr *element : expr->
getElements())
524 llvm::append_range(elements, genExpr(element));
526 return pdl::RangeOp::create(builder, genLoc(expr->
getLoc()),
527 genType(expr->
getType()), elements);
// Generate a tuple expression as the list of its elements, each generated as a
// single value.
// NOTE(review): the return statement (source lines 534-535) is not part of
// this excerpt.
530SmallVector<Value> CodeGen::genExprImpl(
const ast::TupleExpr *expr) {
531 SmallVector<Value> elements;
532 for (
const ast::Expr *element : expr->
getElements())
533 elements.push_back(genSingleExpr(element));
// Generate a pdl.type fixed to the constant type written in a type expression.
// NOTE(review): the line that produces `type` (source line 538) is not shown;
// it presumably parses the expression's text with parseType.
537Value CodeGen::genExprImpl(
const ast::TypeExpr *expr) {
539 assert(type &&
"invalid MLIR type data");
540 return pdl::TypeOp::create(builder, genLoc(expr->
getLoc()),
541 builder.getType<pdl::TypeType>(),
542 TypeAttr::get(type));
// Generate a call to a user constraint:
//   1. Apply the constraints of each declared input to the matching argument.
//   2. Emit the call (as pdl.apply_native_constraint when the constraint is
//      lowered to a native call).
//   3. Apply the constraints of each declared result to the produced values.
// NOTE(review): the return statement (source lines 560-561) is not part of
// this excerpt.
546CodeGen::genConstraintCall(
const ast::UserConstraintDecl *decl, Location loc,
549 for (
auto it : llvm::zip(decl->
getInputs(), inputs))
550 applyVarConstraints(std::get<0>(it), std::get<1>(it));
553 SmallVector<Value> results =
554 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
555 decl, loc, inputs, isNegated);
558 for (
auto it : llvm::zip(decl->
getResults(), results))
559 applyVarConstraints(std::get<0>(it), std::get<1>(it));
// Generate a call to a user rewrite, using pdl.apply_native_rewrite when the
// rewrite is lowered to a native call.
563SmallVector<Value> CodeGen::genRewriteCall(
const ast::UserRewriteDecl *decl,
565 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
// Shared implementation of constraint and rewrite calls.
// Native path (presumably taken when the declaration has no PDLL body; the
// guarding condition on source lines 574-576 is not shown):
//   - emit a PDLOpT named after the declaration;
//   - its result types come from the declared result type, with a tuple type
//     flattened into its element types.
// Body path:
//   - open a new variable scope;
//   - bind the call inputs to the declared inputs;
//   - return the values of the body's trailing return statement, if any.
// NOTE(review): several source lines are not part of this excerpt.
569template <
typename PDLOpT,
typename T>
571CodeGen::genConstraintOrRewriteCall(
const T *decl, Location loc,
573 const ast::CompoundStmt *cstBody = decl->getBody();
577 ast::Type declResultType = decl->getResultType();
578 SmallVector<Type> resultTypes;
579 if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
580 for (ast::Type type : tupleType.getElementTypes())
581 resultTypes.push_back(genType(type));
583 resultTypes.push_back(genType(declResultType));
585 PDLOpT pdlOp = PDLOpT::create(builder, loc, resultTypes,
586 decl->getName().getName(), inputs);
// Only constraint ops carry a negation flag.
// NOTE(review): this is a runtime `if`, so the cast below is still
// instantiated when PDLOpT is the rewrite op. `if constexpr` would state the
// compile-time intent more directly.
587 if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
588 cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(
true);
589 return pdlOp->getResults();
// Generate the body inline in its own variable scope.
593 VariableMapTy::ScopeTy varScope(variables);
// Bind each declared input variable to the corresponding call argument.
598 for (
auto it : llvm::zip(inputs, decl->getInputs()))
599 variables.insert(std::get<1>(it), {std::get<0>(it)});
606 return SmallVector<Value>();
// The body's results are those of a trailing return statement. Without one,
// an empty list is returned (source line 609).
607 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->
getChildren().back());
609 return SmallVector<Value>();
612 return genExpr(returnStmt->getResultExpr());
// Fragment of the public entry point codegenPDLLToMLIR, which, given a PDLL
// module, generates an MLIR PDL pattern module within the given MLIR context.
// It constructs a CodeGen and verifies the generated module.
// NOTE(review): the function header, the generate() call (source line 623)
// and the handling of a verification failure are not part of this excerpt.
621 const llvm::SourceMgr &sourceMgr,
const ast::Module &module) {
622 CodeGen codegen(mlirContext, context, sourceMgr);
624 if (failed(
verify(*mlirModule)))
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc)
If the given builder is nested under a PDL PatternOp, build a rewrite operation and update the builder to nest under it.
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
StringRef getValue() const
Get the raw value of this expression.
Expr * getCallableExpr() const
Return the callable of this call.
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
bool getIsNegated() const
Returns whether the result of this call is to be negated.
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
This class represents the main context of the PDLL AST.
Decl * getDecl() const
Get the decl referenced by this expression.
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Type getType() const
Return the type of this expression.
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
const Expr * getParentExpr() const
Get the parent expression of this access.
StringRef getMemberName() const
Return the name of the member being accessed.
This class represents a top-level AST module.
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
SMRange getLoc() const
Return the location of this node.
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
const CompoundStmt * getBody() const
Return the body of this pattern.
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
RangeType getType() const
Return the range result type of this expression.
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Type getElementType() const
Return the element type of this range.
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
StringRef getValue() const
Get the raw value of this expression.
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Type getType() const
Return the type of the decl.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
StringRef getName() const
Return the raw string name.