22 #include "llvm/ADT/ScopedHashTable.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
38 const llvm::SourceMgr &sourceMgr)
39 : builder(mlirContext), odsContext(context.getODSContext()),
40 sourceMgr(sourceMgr) {
/// Convenience overload: produces a location for a source range by forwarding
/// the range's start position (`loc.Start`) to the `llvm::SMLoc` overload of
/// `genLoc`.
50 Location genLoc(llvm::SMRange loc) {
return genLoc(loc.Start); }
107 bool isNegated =
false);
110 template <
typename PDLOpT,
typename T>
113 bool isNegated =
false);
/// Scoped hash table keyed by PDLL variable declaration; each entry holds the
/// list of values recorded for that variable.
123 using VariableMapTy =
124 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
/// The variable table instance; scopes are opened on it via
/// `VariableMapTy::ScopeTy`.
125 VariableMapTy variables;
131 const llvm::SourceMgr &sourceMgr;
137 builder.create<ModuleOp>(genLoc(module.getLoc()));
138 builder.setInsertionPointToStart(mlirModule->getBody());
147 Location CodeGen::genLoc(llvm::SMLoc loc) {
148 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
152 auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
153 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
155 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
156 auto *buffer = sourceMgr.getMemoryBuffer(fileID);
159 buffer->getBufferIdentifier(), lineNo, column);
165 return builder.getType<pdl::AttributeType>();
168 return builder.getType<pdl::OperationType>();
171 return builder.getType<pdl::TypeType>();
174 return builder.getType<pdl::ValueType>();
181 void CodeGen::gen(
const ast::Node *node) {
187 [&](
auto derivedNode) { this->genImpl(derivedNode); })
188 .Case([&](
const ast::Expr *expr) { genExpr(expr); });
196 VariableMapTy::ScopeTy varScope(variables);
208 builder.
create<pdl::RewriteOp>(loc, rootExpr, StringAttr(),
216 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
217 Location loc = genLoc(stmt->getLoc());
222 builder.create<pdl::EraseOp>(loc, rootExpr);
229 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
230 Location loc = genLoc(stmt->getLoc());
238 replValues.push_back(genSingleExpr(replExpr));
242 bool usesReplOperation =
243 replValues.size() == 1 &&
244 isa<pdl::OperationType>(replValues.front().getType());
245 builder.create<pdl::ReplaceOp>(
246 loc, rootExpr, usesReplOperation ? replValues[0] :
Value(),
252 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
286 pdl::PatternOp pattern = builder.
create<pdl::PatternOp>(
288 name ? std::optional<StringRef>(name->
getName())
289 : std::optional<StringRef>());
292 builder.setInsertionPointToStart(pattern.getBody());
297 auto it = variables.begin(varDecl);
298 if (it != variables.end())
305 values = genExpr(initExpr);
307 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
310 applyVarConstraints(varDecl, values);
312 variables.insert(varDecl, values);
319 auto getTypeConstraint = [&]() ->
Value {
325 [&,
this](
auto *cst) ->
Value {
326 if (
auto *typeConstraintExpr = cst->getTypeExpr())
327 return this->genSingleExpr(typeConstraintExpr);
339 Type mlirType = genType(type);
340 if (isa<ast::ValueType>(type))
341 return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
342 if (isa<ast::TypeType>(type))
343 return builder.create<pdl::TypeOp>(loc, mlirType, TypeAttr());
344 if (isa<ast::AttributeType>(type))
345 return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
347 Value operands = builder.create<pdl::OperandsOp>(
350 Value results = builder.create<pdl::TypesOp>(
353 return builder.create<pdl::OperationOp>(
354 loc, opType.getName(), operands, std::nullopt,
ValueRange(), results);
358 ast::Type eleTy = rangeTy.getElementType();
359 if (isa<ast::ValueType>(eleTy))
360 return builder.create<pdl::OperandsOp>(loc, mlirType,
361 getTypeConstraint());
362 if (isa<ast::TypeType>(eleTy))
363 return builder.create<pdl::TypesOp>(loc, mlirType, ArrayAttr());
366 llvm_unreachable(
"invalid non-initialized variable type");
374 if (
const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
375 genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
387 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
388 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
389 [&](
auto derivedNode) {
391 assert(results.size() == 1 &&
"expected single expression result");
398 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
399 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
401 return {genSingleExpr(expr)};
407 assert(attr &&
"invalid MLIR attribute data");
408 return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
412 Location loc = genLoc(expr->getLoc());
415 arguments.push_back(genSingleExpr(arg));
418 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->
getCallableExpr());
419 assert(callableExpr &&
"unhandled CallExpr callable");
422 const ast::Decl *callable = callableExpr->getDecl();
423 if (
const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
424 return genConstraintCall(decl, loc, arguments, expr->
getIsNegated());
425 if (
const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
426 return genRewriteCall(decl, loc, arguments);
427 llvm_unreachable(
"unhandled CallExpr callable");
431 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->
getDecl()))
432 return genVar(varDecl);
433 llvm_unreachable(
"unknown decl reference expression");
437 Location loc = genLoc(expr->getLoc());
444 if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
445 Type mlirType = genType(expr->getType());
446 if (isa<pdl::ValueType>(mlirType))
447 return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
448 builder.getI32IntegerAttr(0));
449 return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
454 assert(llvm::isDigit(name[0]) &&
455 "unregistered op only allows numeric indexing");
456 unsigned resultIndex;
457 name.getAsInteger(10, resultIndex);
458 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
459 return builder.create<pdl::ResultOp>(loc, genType(expr->getType()),
460 parentExprs[0], index);
465 unsigned resultIndex = results.size();
466 if (llvm::isDigit(name[0])) {
467 name.getAsInteger(10, resultIndex);
470 return result.getName() == name;
472 resultIndex = llvm::find_if(results, findFn) - results.begin();
474 assert(resultIndex < results.size() &&
"invalid result index");
477 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
478 return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
479 parentExprs[0], index);
483 if (
auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
484 auto elementNames = tupleType.getElementNames();
488 if (llvm::isDigit(name[0]))
489 name.getAsInteger(10, index);
491 index = llvm::find(elementNames, name) - elementNames.begin();
493 assert(index < parentExprs.size() &&
"invalid result index");
494 return parentExprs[index];
497 llvm_unreachable(
"unhandled member access expression");
501 Location loc = genLoc(expr->getLoc());
502 std::optional<StringRef> opName = expr->
getName();
507 operands.push_back(genSingleExpr(operand));
513 attrNames.push_back(attr->getName().getName());
514 attrValues.push_back(genSingleExpr(attr->getValue()));
520 results.push_back(genSingleExpr(result));
522 return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
523 attrValues, results);
529 llvm::append_range(elements, genExpr(element));
531 return builder.create<pdl::RangeOp>(genLoc(expr->getLoc()),
532 genType(expr->
getType()), elements);
538 elements.push_back(genSingleExpr(element));
544 assert(type &&
"invalid MLIR type data");
545 return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
546 builder.getType<pdl::TypeType>(),
554 for (
auto it : llvm::zip(decl->
getInputs(), inputs))
555 applyVarConstraints(std::get<0>(it), std::get<1>(it));
559 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
560 decl, loc, inputs, isNegated);
563 for (
auto it : llvm::zip(decl->
getResults(), results))
564 applyVarConstraints(std::get<0>(it), std::get<1>(it));
570 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
574 template <
typename PDLOpT,
typename T>
576 CodeGen::genConstraintOrRewriteCall(
const T *decl,
Location loc,
582 ast::Type declResultType = decl->getResultType();
584 if (
ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
585 for (
ast::Type type : tupleType.getElementTypes())
586 resultTypes.push_back(genType(type));
588 resultTypes.push_back(genType(declResultType));
590 PDLOpT pdlOp = builder.create<PDLOpT>(loc, resultTypes,
591 decl->getName().getName(), inputs);
592 if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
593 cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(
true);
594 return pdlOp->getResults();
598 VariableMapTy::ScopeTy varScope(variables);
603 for (
auto it : llvm::zip(inputs, decl->getInputs()))
604 variables.insert(std::get<1>(it), {std::get<0>(it)});
612 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->
getChildren().back());
617 return genExpr(returnStmt->getResultExpr());
626 const llvm::SourceMgr &sourceMgr,
const ast::Module &module) {
627 CodeGen codegen(mlirContext, context, sourceMgr);
629 if (failed(
verify(*mlirModule)))
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc)
If the given builder is nested under a PDL PatternOp, build a rewrite operation and update the builder to nest under it.
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Attributes are known-constant values of operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that attribute.
StringRef getValue() const
Get the raw value of this expression.
This class represents a PDLL type that corresponds to an mlir::Attribute.
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Expr * getCallableExpr() const
Return the callable of this call.
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
bool getIsNegated() const
Returns whether the result of this call is to be negated.
This statement represents a compound statement, which contains a collection of other statements.
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
This class represents the main context of the PDLL AST.
This expression represents a reference to a Decl node.
Decl * getDecl() const
Get the decl referenced by this expression.
This class represents the base Decl node.
This statement represents the erase statement in PDLL.
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This statement represents a let statement in PDLL.
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
This expression represents a named member or field access of a given parent expression.
const Expr * getParentExpr() const
Get the parent expression of this access.
StringRef getMemberName() const
Return the name of the member being accessed.
This class represents a top-level AST module.
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
This Decl represents a NamedAttribute, and contains a string name and attribute value.
This class represents a base AST node.
This expression represents the structural form of an MLIR Operation.
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
This class represents a PDLL type that corresponds to an mlir::Operation.
This Decl represents a single Pattern.
const CompoundStmt * getBody() const
Return the body of this pattern.
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
This expression builds a range from a set of element values (which may be ranges themselves).
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
RangeType getType() const
Return the range result type of this expression.
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Type getElementType() const
Return the element type of this range.
This statement represents the replace statement in PDLL.
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
This statement represents a return from a "callable"-like decl, e.g. a user-defined constraint or rewrite.
This statement represents an operation rewrite that contains a block of nested rewrite commands.
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
This class represents a base AST Statement node.
This expression builds a tuple from a set of element values.
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
This class represents a PDLL tuple type, i.e. an ordered set of element types with optional names.
This expression represents a literal MLIR Type, and contains the textual assembly format of that type...
StringRef getValue() const
Get the raw value of this expression.
This class represents a PDLL type that corresponds to an mlir::Type.
This decl represents a user defined constraint.
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
This decl represents a user defined rewrite.
The class represents a Value constraint, and constrains a variable to be a Value.
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
This class represents a PDLL type that corresponds to an mlir::Value.
This Decl represents the definition of a PDLL variable.
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Type getType() const
Return the type of the decl.
This class contains all of the registered ODS operation classes.
This class provides an ODS representation of a specific operation operand or result.
This class provides an ODS representation of a specific operation.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This class represents a reference to a constraint, and contains a constraint and the location of the reference.
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.