22 #include "llvm/ADT/ScopedHashTable.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
38 const llvm::SourceMgr &sourceMgr)
39 : builder(mlirContext), odsContext(context.getODSContext()),
40 sourceMgr(sourceMgr) {
50 Location genLoc(llvm::SMRange loc) {
return genLoc(loc.Start); }
107 bool isNegated =
false);
110 template <
typename PDLOpT,
typename T>
113 bool isNegated =
false);
123 using VariableMapTy =
124 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
125 VariableMapTy variables;
131 const llvm::SourceMgr &sourceMgr;
137 builder.create<ModuleOp>(genLoc(module.getLoc()));
138 builder.setInsertionPointToStart(mlirModule->getBody());
147 Location CodeGen::genLoc(llvm::SMLoc loc) {
148 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
152 auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
153 unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
155 (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
156 auto *buffer = sourceMgr.getMemoryBuffer(fileID);
159 buffer->getBufferIdentifier(), lineNo, column);
165 return builder.getType<pdl::AttributeType>();
168 return builder.getType<pdl::OperationType>();
171 return builder.getType<pdl::TypeType>();
174 return builder.getType<pdl::ValueType>();
181 void CodeGen::gen(
const ast::Node *node) {
187 [&](
auto derivedNode) { this->genImpl(derivedNode); })
188 .Case([&](
const ast::Expr *expr) { genExpr(expr); });
196 VariableMapTy::ScopeTy varScope(variables);
208 builder.
create<pdl::RewriteOp>(loc, rootExpr, StringAttr(),
216 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
217 Location loc = genLoc(stmt->getLoc());
222 builder.create<pdl::EraseOp>(loc, rootExpr);
229 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
230 Location loc = genLoc(stmt->getLoc());
238 replValues.push_back(genSingleExpr(replExpr));
242 bool usesReplOperation =
243 replValues.size() == 1 &&
244 isa<pdl::OperationType>(replValues.front().getType());
245 builder.create<pdl::ReplaceOp>(
246 loc, rootExpr, usesReplOperation ? replValues[0] :
Value(),
252 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
286 pdl::PatternOp pattern = builder.
create<pdl::PatternOp>(
288 name ? std::optional<StringRef>(name->
getName())
289 : std::optional<StringRef>());
292 builder.setInsertionPointToStart(pattern.getBody());
297 auto it = variables.begin(varDecl);
298 if (it != variables.end())
305 values = genExpr(initExpr);
307 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
310 applyVarConstraints(varDecl, values);
312 variables.insert(varDecl, values);
319 auto getTypeConstraint = [&]() ->
Value {
325 [&,
this](
auto *cst) ->
Value {
326 if (
auto *typeConstraintExpr = cst->getTypeExpr())
327 return this->genSingleExpr(typeConstraintExpr);
339 Type mlirType = genType(type);
340 if (isa<ast::ValueType>(type))
341 return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
342 if (isa<ast::TypeType>(type))
343 return builder.create<pdl::TypeOp>(loc, mlirType, TypeAttr());
344 if (isa<ast::AttributeType>(type))
345 return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
347 Value operands = builder.create<pdl::OperandsOp>(
350 Value results = builder.create<pdl::TypesOp>(
353 return builder.create<pdl::OperationOp>(
354 loc, opType.getName(), operands, std::nullopt,
ValueRange(), results);
358 ast::Type eleTy = rangeTy.getElementType();
359 if (isa<ast::ValueType>(eleTy))
360 return builder.create<pdl::OperandsOp>(loc, mlirType,
361 getTypeConstraint());
362 if (isa<ast::TypeType>(eleTy))
363 return builder.create<pdl::TypesOp>(loc, mlirType, ArrayAttr());
366 llvm_unreachable(
"invalid non-initialized variable type");
374 if (
const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
375 genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
387 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
388 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
389 [&](
auto derivedNode) {
390 return llvm::getSingleElement(this->genExprImpl(derivedNode));
396 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
397 [&](
auto derivedNode) {
return this->genExprImpl(derivedNode); })
399 return {genSingleExpr(expr)};
405 assert(attr &&
"invalid MLIR attribute data");
406 return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
410 Location loc = genLoc(expr->getLoc());
413 arguments.push_back(genSingleExpr(arg));
416 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->
getCallableExpr());
417 assert(callableExpr &&
"unhandled CallExpr callable");
420 const ast::Decl *callable = callableExpr->getDecl();
421 if (
const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
422 return genConstraintCall(decl, loc, arguments, expr->
getIsNegated());
423 if (
const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
424 return genRewriteCall(decl, loc, arguments);
425 llvm_unreachable(
"unhandled CallExpr callable");
429 if (
const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->
getDecl()))
430 return genVar(varDecl);
431 llvm_unreachable(
"unknown decl reference expression");
435 Location loc = genLoc(expr->getLoc());
442 if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
443 Type mlirType = genType(expr->getType());
444 if (isa<pdl::ValueType>(mlirType))
445 return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
446 builder.getI32IntegerAttr(0));
447 return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
452 assert(llvm::isDigit(name[0]) &&
453 "unregistered op only allows numeric indexing");
454 unsigned resultIndex;
455 name.getAsInteger(10, resultIndex);
456 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
457 return builder.create<pdl::ResultOp>(loc, genType(expr->getType()),
458 parentExprs[0], index);
463 unsigned resultIndex = results.size();
464 if (llvm::isDigit(name[0])) {
465 name.getAsInteger(10, resultIndex);
468 return result.getName() == name;
470 resultIndex = llvm::find_if(results, findFn) - results.begin();
472 assert(resultIndex < results.size() &&
"invalid result index");
475 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
476 return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
477 parentExprs[0], index);
481 if (
auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
482 auto elementNames = tupleType.getElementNames();
486 if (llvm::isDigit(name[0]))
487 name.getAsInteger(10, index);
489 index = llvm::find(elementNames, name) - elementNames.begin();
491 assert(index < parentExprs.size() &&
"invalid result index");
492 return parentExprs[index];
495 llvm_unreachable(
"unhandled member access expression");
499 Location loc = genLoc(expr->getLoc());
500 std::optional<StringRef> opName = expr->
getName();
505 operands.push_back(genSingleExpr(operand));
511 attrNames.push_back(attr->getName().getName());
512 attrValues.push_back(genSingleExpr(attr->getValue()));
518 results.push_back(genSingleExpr(result));
520 return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
521 attrValues, results);
527 llvm::append_range(elements, genExpr(element));
529 return builder.create<pdl::RangeOp>(genLoc(expr->getLoc()),
530 genType(expr->
getType()), elements);
536 elements.push_back(genSingleExpr(element));
542 assert(type &&
"invalid MLIR type data");
543 return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
544 builder.getType<pdl::TypeType>(),
552 for (
auto it : llvm::zip(decl->
getInputs(), inputs))
553 applyVarConstraints(std::get<0>(it), std::get<1>(it));
557 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
558 decl, loc, inputs, isNegated);
561 for (
auto it : llvm::zip(decl->
getResults(), results))
562 applyVarConstraints(std::get<0>(it), std::get<1>(it));
568 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
572 template <
typename PDLOpT,
typename T>
574 CodeGen::genConstraintOrRewriteCall(
const T *decl,
Location loc,
580 ast::Type declResultType = decl->getResultType();
582 if (
ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
583 for (
ast::Type type : tupleType.getElementTypes())
584 resultTypes.push_back(genType(type));
586 resultTypes.push_back(genType(declResultType));
588 PDLOpT pdlOp = builder.create<PDLOpT>(loc, resultTypes,
589 decl->getName().getName(), inputs);
590 if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
591 cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(
true);
592 return pdlOp->getResults();
596 VariableMapTy::ScopeTy varScope(variables);
601 for (
auto it : llvm::zip(inputs, decl->getInputs()))
602 variables.insert(std::get<1>(it), {std::get<0>(it)});
610 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->
getChildren().back());
615 return genExpr(returnStmt->getResultExpr());
624 const llvm::SourceMgr &sourceMgr,
const ast::Module &module) {
625 CodeGen codegen(mlirContext, context, sourceMgr);
627 if (failed(
verify(*mlirModule)))
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc)
If the given builder is nested under a PDL PatternOp, build a rewrite operation and update the builde...
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Attributes are known-constant values of operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that...
StringRef getValue() const
Get the raw value of this expression.
This class represents a PDLL type that corresponds to an mlir::Attribute.
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Expr * getCallableExpr() const
Return the callable of this call.
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
bool getIsNegated() const
Returns whether the result of this call is to be negated.
This statement represents a compound statement, which contains a collection of other statements.
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
This class represents the main context of the PDLL AST.
This expression represents a reference to a Decl node.
Decl * getDecl() const
Get the decl referenced by this expression.
This class represents the base Decl node.
This statement represents the erase statement in PDLL.
This class represents a base AST Expression node.
Type getType() const
Return the type of this expression.
This statement represents a let statement in PDLL.
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
This expression represents a named member or field access of a given parent expression.
const Expr * getParentExpr() const
Get the parent expression of this access.
StringRef getMemberName() const
Return the name of the member being accessed.
This class represents a top-level AST module.
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
This Decl represents a NamedAttribute, and contains a string name and attribute value.
This class represents a base AST node.
This expression represents the structural form of an MLIR Operation.
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
This class represents a PDLL type that corresponds to an mlir::Operation.
This Decl represents a single Pattern.
const CompoundStmt * getBody() const
Return the body of this pattern.
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
This expression builds a range from a set of element values (which may be ranges themselves).
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
RangeType getType() const
Return the range result type of this expression.
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Type getElementType() const
Return the element type of this range.
This statement represents the replace statement in PDLL.
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
This statement represents a return from a "callable" like decl, e.g.
This statement represents an operation rewrite that contains a block of nested rewrite commands.
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
This class represents a base AST Statement node.
This expression builds a tuple from a set of element values.
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
This class represents a PDLL tuple type, i.e.
This expression represents a literal MLIR Type, and contains the textual assembly format of that type...
StringRef getValue() const
Get the raw value of this expression.
This class represents a PDLL type that corresponds to an mlir::Type.
This decl represents a user defined constraint.
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
This decl represents a user defined rewrite.
The class represents a Value constraint, and constrains a variable to be a Value.
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
This class represents a PDLL type that corresponds to an mlir::Value.
This Decl represents the definition of a PDLL variable.
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Type getType() const
Return the type of the decl.
This class contains all of the registered ODS operation classes.
This class provides an ODS representation of a specific operation operand or result.
This class provides an ODS representation of a specific operation.
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This class represents a reference to a constraint, and contains a constraint and the location of the ...
This class provides a convenient API for interacting with source names.
StringRef getName() const
Return the raw string name.