MLIR 23.0.0git
MLIRGen.cpp
Go to the documentation of this file.
1//===- MLIRGen.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/Builders.h"
15#include "mlir/IR/BuiltinOps.h"
16#include "mlir/IR/Verifier.h"
22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::pdll;
29
30//===----------------------------------------------------------------------===//
31// CodeGen
32//===----------------------------------------------------------------------===//
33
34namespace {
35class CodeGen {
36public:
37 CodeGen(MLIRContext *mlirContext, const ast::Context &context,
38 const llvm::SourceMgr &sourceMgr)
39 : builder(mlirContext), odsContext(context.getODSContext()),
40 sourceMgr(sourceMgr) {
41 // Make sure that the PDL dialect is loaded.
42 mlirContext->loadDialect<pdl::PDLDialect>();
43 }
44
45 OwningOpRef<ModuleOp> generate(const ast::Module &module);
46
47private:
48 /// Generate an MLIR location from the given source location.
49 Location genLoc(llvm::SMLoc loc);
50 Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }
51
52 /// Generate an MLIR type from the given source type.
53 Type genType(ast::Type type);
54
55 /// Generate MLIR for the given AST node.
56 void gen(const ast::Node *node);
57
58 //===--------------------------------------------------------------------===//
59 // Statements
60 //===--------------------------------------------------------------------===//
61
62 void genImpl(const ast::CompoundStmt *stmt);
63 void genImpl(const ast::EraseStmt *stmt);
64 void genImpl(const ast::LetStmt *stmt);
65 void genImpl(const ast::ReplaceStmt *stmt);
66 void genImpl(const ast::RewriteStmt *stmt);
67 void genImpl(const ast::ReturnStmt *stmt);
68
69 //===--------------------------------------------------------------------===//
70 // Decls
71 //===--------------------------------------------------------------------===//
72
73 void genImpl(const ast::UserConstraintDecl *decl);
74 void genImpl(const ast::UserRewriteDecl *decl);
75 void genImpl(const ast::PatternDecl *decl);
76
77 /// Generate the set of MLIR values defined for the given variable decl, and
78 /// apply any attached constraints.
79 SmallVector<Value> genVar(const ast::VariableDecl *varDecl);
80
81 /// Generate the value for a variable that does not have an initializer
82 /// expression, i.e. create the PDL value based on the type/constraints of the
83 /// variable.
84 Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);
85
86 /// Apply the constraints of the given variable to `values`, which correspond
87 /// to the MLIR values of the variable.
88 void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);
89
90 //===--------------------------------------------------------------------===//
91 // Expressions
92 //===--------------------------------------------------------------------===//
93
94 Value genSingleExpr(const ast::Expr *expr);
95 SmallVector<Value> genExpr(const ast::Expr *expr);
96 Value genExprImpl(const ast::AttributeExpr *expr);
97 SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
98 SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
99 Value genExprImpl(const ast::MemberAccessExpr *expr);
100 Value genExprImpl(const ast::OperationExpr *expr);
101 Value genExprImpl(const ast::RangeExpr *expr);
102 SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
103 Value genExprImpl(const ast::TypeExpr *expr);
104
105 SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
106 Location loc, ValueRange inputs,
107 bool isNegated = false);
108 SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
109 Location loc, ValueRange inputs);
110 template <typename PDLOpT, typename T>
111 SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
112 ValueRange inputs,
113 bool isNegated = false);
114
115 //===--------------------------------------------------------------------===//
116 // Fields
117 //===--------------------------------------------------------------------===//
118
119 /// The MLIR builder used for building the resultant IR.
120 OpBuilder builder;
121
122 /// A map from variable declarations to the MLIR equivalent.
123 using VariableMapTy =
124 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
125 VariableMapTy variables;
126
127 /// A reference to the ODS context.
128 const ods::Context &odsContext;
129
130 /// The source manager of the PDLL ast.
131 const llvm::SourceMgr &sourceMgr;
132};
133} // namespace
134
135OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) {
136 OwningOpRef<ModuleOp> mlirModule =
137 ModuleOp::create(builder, genLoc(module.getLoc()));
138 builder.setInsertionPointToStart(mlirModule->getBody());
139
140 // Generate code for each of the decls within the module.
141 for (const ast::Decl *decl : module.getChildren())
142 gen(decl);
143
144 return mlirModule;
145}
146
147Location CodeGen::genLoc(llvm::SMLoc loc) {
148 unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
149
150 auto [lineNo, column] = sourceMgr.getLineAndColumn(loc);
151 auto *buffer = sourceMgr.getMemoryBuffer(fileID);
152
153 return FileLineColLoc::get(builder.getContext(),
154 buffer->getBufferIdentifier(), lineNo, column);
155}
156
157Type CodeGen::genType(ast::Type type) {
158 return TypeSwitch<ast::Type, Type>(type)
159 .Case([&](ast::AttributeType astType) -> Type {
160 return builder.getType<pdl::AttributeType>();
161 })
162 .Case([&](ast::OperationType astType) -> Type {
163 return builder.getType<pdl::OperationType>();
164 })
165 .Case([&](ast::TypeType astType) -> Type {
166 return builder.getType<pdl::TypeType>();
167 })
168 .Case([&](ast::ValueType astType) -> Type {
169 return builder.getType<pdl::ValueType>();
170 })
171 .Case([&](ast::RangeType astType) -> Type {
172 return pdl::RangeType::get(genType(astType.getElementType()));
173 });
174}
175
176void CodeGen::gen(const ast::Node *node) {
178 .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt,
179 const ast::ReplaceStmt, const ast::RewriteStmt,
180 const ast::ReturnStmt, const ast::UserConstraintDecl,
181 const ast::UserRewriteDecl, const ast::PatternDecl>(
182 [&](auto derivedNode) { this->genImpl(derivedNode); })
183 .Case([&](const ast::Expr *expr) { genExpr(expr); });
184}
185
186//===----------------------------------------------------------------------===//
187// CodeGen: Statements
188//===----------------------------------------------------------------------===//
189
190void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
191 VariableMapTy::ScopeTy varScope(variables);
192 for (const ast::Stmt *childStmt : stmt->getChildren())
193 gen(childStmt);
194}
195
196/// If the given builder is nested under a PDL PatternOp, build a rewrite
197/// operation and update the builder to nest under it. This is necessary for
198/// PDLL operation rewrite statements that are directly nested within a Pattern.
199static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
200 Location loc) {
201 if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
202 pdl::RewriteOp rewrite =
203 pdl::RewriteOp::create(builder, loc, rootExpr, /*name=*/StringAttr(),
204 /*externalArgs=*/ValueRange());
205 builder.createBlock(&rewrite.getBodyRegion());
206 }
207}
208
209void CodeGen::genImpl(const ast::EraseStmt *stmt) {
210 OpBuilder::InsertionGuard insertGuard(builder);
211 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
212 Location loc = genLoc(stmt->getLoc());
213
214 // Make sure we are nested in a RewriteOp.
215 OpBuilder::InsertionGuard guard(builder);
216 checkAndNestUnderRewriteOp(builder, rootExpr, loc);
217 pdl::EraseOp::create(builder, loc, rootExpr);
218}
219
220void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); }
221
222void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
223 OpBuilder::InsertionGuard insertGuard(builder);
224 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
225 Location loc = genLoc(stmt->getLoc());
226
227 // Make sure we are nested in a RewriteOp.
228 OpBuilder::InsertionGuard guard(builder);
229 checkAndNestUnderRewriteOp(builder, rootExpr, loc);
230
231 SmallVector<Value> replValues;
232 for (ast::Expr *replExpr : stmt->getReplExprs())
233 replValues.push_back(genSingleExpr(replExpr));
234
235 // Check to see if the statement has a replacement operation, or a range of
236 // replacement values.
237 bool usesReplOperation =
238 replValues.size() == 1 &&
239 isa<pdl::OperationType>(replValues.front().getType());
240 pdl::ReplaceOp::create(
241 builder, loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
242 usesReplOperation ? ValueRange() : ValueRange(replValues));
243}
244
245void CodeGen::genImpl(const ast::RewriteStmt *stmt) {
246 OpBuilder::InsertionGuard insertGuard(builder);
247 Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
248
249 // Make sure we are nested in a RewriteOp.
250 OpBuilder::InsertionGuard guard(builder);
251 checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc()));
252 gen(stmt->getRewriteBody());
253}
254
255void CodeGen::genImpl(const ast::ReturnStmt *stmt) {
256 // ReturnStmt generation is handled by the respective constraint or rewrite
257 // parent node.
258}
259
260//===----------------------------------------------------------------------===//
261// CodeGen: Decls
262//===----------------------------------------------------------------------===//
263
264void CodeGen::genImpl(const ast::UserConstraintDecl *decl) {
265 // All PDLL constraints get inlined when called, and the main native
266 // constraint declarations doesn't require any MLIR to be generated, only uses
267 // of it do.
268}
269
270void CodeGen::genImpl(const ast::UserRewriteDecl *decl) {
271 // All PDLL rewrites get inlined when called, and the main native
272 // rewrite declarations doesn't require any MLIR to be generated, only uses
273 // of it do.
274}
275
276void CodeGen::genImpl(const ast::PatternDecl *decl) {
277 const ast::Name *name = decl->getName();
278
279 // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it
280 // here.
281 pdl::PatternOp pattern = pdl::PatternOp::create(
282 builder, genLoc(decl->getLoc()), decl->getBenefit(),
283 name ? std::optional<StringRef>(name->getName())
284 : std::optional<StringRef>());
285
286 OpBuilder::InsertionGuard savedInsertPoint(builder);
287 builder.setInsertionPointToStart(pattern.getBody());
288 gen(decl->getBody());
289}
290
291SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) {
292 auto it = variables.begin(varDecl);
293 if (it != variables.end())
294 return *it;
295
296 // If the variable has an initial value, use that as the base value.
297 // Otherwise, generate a value using the constraint list.
298 SmallVector<Value> values;
299 if (const ast::Expr *initExpr = varDecl->getInitExpr())
300 values = genExpr(initExpr);
301 else
302 values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
303
304 // Apply the constraints of the values of the variable.
305 applyVarConstraints(varDecl, values);
306
307 variables.insert(varDecl, values);
308 return values;
309}
310
311Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
312 Location loc) {
313 // A functor used to generate expressions nested
314 auto getTypeConstraint = [&]() -> Value {
315 for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) {
316 Value typeValue =
317 TypeSwitch<const ast::Node *, Value>(constraint.constraint)
318 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
319 ast::ValueRangeConstraintDecl>(
320 [&, this](auto *cst) -> Value {
321 if (auto *typeConstraintExpr = cst->getTypeExpr())
322 return this->genSingleExpr(typeConstraintExpr);
323 return Value();
324 })
325 .Default(Value());
326 if (typeValue)
327 return typeValue;
328 }
329 return Value();
330 };
331
332 // Generate a value based on the type of the variable.
333 ast::Type type = varDecl->getType();
334 Type mlirType = genType(type);
335 if (isa<ast::ValueType>(type))
336 return pdl::OperandOp::create(builder, loc, mlirType, getTypeConstraint());
337 if (isa<ast::TypeType>(type))
338 return pdl::TypeOp::create(builder, loc, mlirType, /*type=*/TypeAttr());
339 if (isa<ast::AttributeType>(type))
340 return pdl::AttributeOp::create(builder, loc, getTypeConstraint());
341 if (ast::OperationType opType = dyn_cast<ast::OperationType>(type)) {
342 Value operands = pdl::OperandsOp::create(
343 builder, loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
344 /*type=*/Value());
345 Value results = pdl::TypesOp::create(
346 builder, loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
347 /*types=*/ArrayAttr());
348 return pdl::OperationOp::create(builder, loc, opType.getName(), operands,
349 ArrayRef<StringRef>(), ValueRange(),
350 results);
351 }
352
353 if (ast::RangeType rangeTy = dyn_cast<ast::RangeType>(type)) {
354 ast::Type eleTy = rangeTy.getElementType();
355 if (isa<ast::ValueType>(eleTy))
356 return pdl::OperandsOp::create(builder, loc, mlirType,
357 getTypeConstraint());
358 if (isa<ast::TypeType>(eleTy))
359 return pdl::TypesOp::create(builder, loc, mlirType,
360 /*types=*/ArrayAttr());
361 }
362
363 llvm_unreachable("invalid non-initialized variable type");
364}
365
366void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
367 ValueRange values) {
368 // Generate calls to any user constraints that were attached via the
369 // constraint list.
370 for (const ast::ConstraintRef &ref : varDecl->getConstraints())
371 if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
372 genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
373}
374
375//===----------------------------------------------------------------------===//
376// CodeGen: Expressions
377//===----------------------------------------------------------------------===//
378
379Value CodeGen::genSingleExpr(const ast::Expr *expr) {
381 .Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
382 const ast::OperationExpr, const ast::RangeExpr,
383 const ast::TypeExpr>(
384 [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
385 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
386 [&](auto derivedNode) {
387 return llvm::getSingleElement(this->genExprImpl(derivedNode));
388 });
389}
390
391SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) {
393 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
394 [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
395 .Default([&](const ast::Expr *expr) -> SmallVector<Value> {
396 return {genSingleExpr(expr)};
397 });
398}
399
400Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) {
401 Attribute attr = parseAttribute(expr->getValue(), builder.getContext());
402 assert(attr && "invalid MLIR attribute data");
403 return pdl::AttributeOp::create(builder, genLoc(expr->getLoc()), attr);
404}
405
406SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
407 Location loc = genLoc(expr->getLoc());
408 SmallVector<Value> arguments;
409 for (const ast::Expr *arg : expr->getArguments())
410 arguments.push_back(genSingleExpr(arg));
411
412 // Resolve the callable expression of this call.
413 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
414 assert(callableExpr && "unhandled CallExpr callable");
415
416 // Generate the PDL based on the type of callable.
417 const ast::Decl *callable = callableExpr->getDecl();
418 if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
419 return genConstraintCall(decl, loc, arguments, expr->getIsNegated());
420 if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
421 return genRewriteCall(decl, loc, arguments);
422 llvm_unreachable("unhandled CallExpr callable");
423}
424
425SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) {
426 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl()))
427 return genVar(varDecl);
428 llvm_unreachable("unknown decl reference expression");
429}
430
431Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
432 Location loc = genLoc(expr->getLoc());
433 StringRef name = expr->getMemberName();
434 SmallVector<Value> parentExprs = genExpr(expr->getParentExpr());
435 ast::Type parentType = expr->getParentExpr()->getType();
436
437 // Handle operation based member access.
438 if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
439 if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
440 Type mlirType = genType(expr->getType());
441 if (isa<pdl::ValueType>(mlirType))
442 return pdl::ResultOp::create(builder, loc, mlirType, parentExprs[0],
443 builder.getI32IntegerAttr(0));
444 return pdl::ResultsOp::create(builder, loc, mlirType, parentExprs[0]);
445 }
446
447 const ods::Operation *odsOp = opType.getODSOperation();
448 if (!odsOp) {
449 assert(llvm::isDigit(name[0]) &&
450 "unregistered op only allows numeric indexing");
451 unsigned resultIndex;
452 name.getAsInteger(/*Radix=*/10, resultIndex);
453 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
454 return pdl::ResultOp::create(builder, loc, genType(expr->getType()),
455 parentExprs[0], index);
456 }
457
458 // Find the result with the member name or by index.
459 ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
460 unsigned resultIndex = results.size();
461 if (llvm::isDigit(name[0])) {
462 name.getAsInteger(/*Radix=*/10, resultIndex);
463 } else {
464 auto findFn = [&](const ods::OperandOrResult &result) {
465 return result.getName() == name;
466 };
467 resultIndex = llvm::find_if(results, findFn) - results.begin();
468 }
469 assert(resultIndex < results.size() && "invalid result index");
470
471 // Generate the result access.
472 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
473 return pdl::ResultsOp::create(builder, loc, genType(expr->getType()),
474 parentExprs[0], index);
475 }
476
477 // Handle tuple based member access.
478 if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
479 auto elementNames = tupleType.getElementNames();
480
481 // The index is either a numeric index, or a name.
482 unsigned index = 0;
483 if (llvm::isDigit(name[0]))
484 name.getAsInteger(/*Radix=*/10, index);
485 else
486 index = llvm::find(elementNames, name) - elementNames.begin();
487
488 assert(index < parentExprs.size() && "invalid result index");
489 return parentExprs[index];
490 }
491
492 llvm_unreachable("unhandled member access expression");
493}
494
495Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
496 Location loc = genLoc(expr->getLoc());
497 std::optional<StringRef> opName = expr->getName();
498
499 // Operands.
500 SmallVector<Value> operands;
501 for (const ast::Expr *operand : expr->getOperands())
502 operands.push_back(genSingleExpr(operand));
503
504 // Attributes.
505 SmallVector<StringRef> attrNames;
506 SmallVector<Value> attrValues;
507 for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) {
508 attrNames.push_back(attr->getName().getName());
509 attrValues.push_back(genSingleExpr(attr->getValue()));
510 }
511
512 // Results.
513 SmallVector<Value> results;
514 for (const ast::Expr *result : expr->getResultTypes())
515 results.push_back(genSingleExpr(result));
516
517 return pdl::OperationOp::create(builder, loc, opName, operands, attrNames,
518 attrValues, results);
519}
520
521Value CodeGen::genExprImpl(const ast::RangeExpr *expr) {
522 SmallVector<Value> elements;
523 for (const ast::Expr *element : expr->getElements())
524 llvm::append_range(elements, genExpr(element));
525
526 return pdl::RangeOp::create(builder, genLoc(expr->getLoc()),
527 genType(expr->getType()), elements);
528}
529
530SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
531 SmallVector<Value> elements;
532 for (const ast::Expr *element : expr->getElements())
533 elements.push_back(genSingleExpr(element));
534 return elements;
535}
536
537Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
538 Type type = parseType(expr->getValue(), builder.getContext());
539 assert(type && "invalid MLIR type data");
540 return pdl::TypeOp::create(builder, genLoc(expr->getLoc()),
541 builder.getType<pdl::TypeType>(),
542 TypeAttr::get(type));
543}
544
545SmallVector<Value>
546CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
547 ValueRange inputs, bool isNegated) {
548 // Apply any constraints defined on the arguments to the input values.
549 for (auto it : llvm::zip(decl->getInputs(), inputs))
550 applyVarConstraints(std::get<0>(it), std::get<1>(it));
551
552 // Generate the constraint call.
553 SmallVector<Value> results =
554 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
555 decl, loc, inputs, isNegated);
556
557 // Apply any constraints defined on the results of the constraint.
558 for (auto it : llvm::zip(decl->getResults(), results))
559 applyVarConstraints(std::get<0>(it), std::get<1>(it));
560 return results;
561}
562
563SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
564 Location loc, ValueRange inputs) {
565 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
566 inputs);
567}
568
569template <typename PDLOpT, typename T>
570SmallVector<Value>
571CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
572 ValueRange inputs, bool isNegated) {
573 const ast::CompoundStmt *cstBody = decl->getBody();
574
575 // If the decl doesn't have a statement body, it is a native decl.
576 if (!cstBody) {
577 ast::Type declResultType = decl->getResultType();
578 SmallVector<Type> resultTypes;
579 if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
580 for (ast::Type type : tupleType.getElementTypes())
581 resultTypes.push_back(genType(type));
582 } else {
583 resultTypes.push_back(genType(declResultType));
584 }
585 PDLOpT pdlOp = PDLOpT::create(builder, loc, resultTypes,
586 decl->getName().getName(), inputs);
587 if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
588 cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
589 return pdlOp->getResults();
590 }
591
592 // Otherwise, this is a PDLL decl.
593 VariableMapTy::ScopeTy varScope(variables);
594
595 // Map the inputs of the call to the decl arguments.
596 // Note: This is only valid because we do not support recursion, meaning
597 // we don't need to worry about conflicting mappings here.
598 for (auto it : llvm::zip(inputs, decl->getInputs()))
599 variables.insert(std::get<1>(it), {std::get<0>(it)});
600
601 // Visit the body of the call as normal.
602 gen(cstBody);
603
604 // If the decl has no results, there is nothing to do.
605 if (cstBody->getChildren().empty())
606 return SmallVector<Value>();
607 auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back());
608 if (!returnStmt)
609 return SmallVector<Value>();
610
611 // Otherwise, grab the results from the return statement.
612 return genExpr(returnStmt->getResultExpr());
613}
614
615//===----------------------------------------------------------------------===//
616// MLIRGen
617//===----------------------------------------------------------------------===//
618
620 MLIRContext *mlirContext, const ast::Context &context,
621 const llvm::SourceMgr &sourceMgr, const ast::Module &module) {
622 CodeGen codegen(mlirContext, context, sourceMgr);
623 OwningOpRef<ModuleOp> mlirModule = codegen.generate(module);
624 if (failed(verify(*mlirModule)))
625 return nullptr;
626 return mlirModule;
627}
ArrayAttr()
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc)
If the given builder is nested under a PDL PatternOp, build a rewrite operation and update the builder to nest under it.
Definition MLIRGen.cpp:199
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition SCCP.cpp:67
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void loadDialect()
Load a dialect in the context.
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.
Definition OwningOpRef.h:29
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
StringRef getValue() const
Get the raw value of this expression.
Definition Nodes.h:376
Expr * getCallableExpr() const
Return the callable of this call.
Definition Nodes.h:400
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition Nodes.h:403
bool getIsNegated() const
Returns whether the result of this call is to be negated.
Definition Nodes.h:407
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition Nodes.h:185
This class represents the main context of the PDLL AST.
Definition Context.h:25
Decl * getDecl() const
Get the decl referenced by this expression.
Definition Nodes.h:438
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition Nodes.h:672
Type getType() const
Return the type of this expression.
Definition Nodes.h:351
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
Definition Nodes.h:216
const Expr * getParentExpr() const
Get the parent expression of this access.
Definition Nodes.h:461
StringRef getMemberName() const
Return the name of the member being accessed.
Definition Nodes.h:464
This class represents a top-level AST module.
Definition Nodes.h:1297
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition Nodes.h:1302
SMRange getLoc() const
Return the location of this node.
Definition Nodes.h:131
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
Definition Nodes.h:237
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition Nodes.h:532
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
Definition Nodes.h:548
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition Nodes.h:540
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition Nodes.cpp:327
const CompoundStmt * getBody() const
Return the body of this pattern.
Definition Nodes.h:1057
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition Nodes.h:1051
RangeType getType() const
Return the range result type of this expression.
Definition Nodes.h:600
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Definition Nodes.h:592
Type getElementType() const
Return the element type of this range.
Definition Types.cpp:99
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
Definition Nodes.h:277
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
Definition Nodes.h:308
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
Definition Nodes.h:625
StringRef getValue() const
Get the raw value of this expression.
Definition Nodes.h:654
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
Definition Nodes.h:927
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
Definition Nodes.h:914
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition Nodes.h:1255
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition Nodes.h:1264
Type getType() const
Return the type of the decl.
Definition Nodes.h:1270
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition Operation.h:168
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Definition MLIRGen.cpp:619
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:480
StringRef getName() const
Return the raw string name.
Definition Nodes.h:41