MLIR  19.0.0git
MLIRGen.cpp
Go to the documentation of this file.
1 //===- MLIRGen.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Verifier.h"
22 #include "llvm/ADT/ScopedHashTable.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::pdll;
29 
30 //===----------------------------------------------------------------------===//
31 // CodeGen
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 class CodeGen {
36 public:
37  CodeGen(MLIRContext *mlirContext, const ast::Context &context,
38  const llvm::SourceMgr &sourceMgr)
39  : builder(mlirContext), odsContext(context.getODSContext()),
40  sourceMgr(sourceMgr) {
41  // Make sure that the PDL dialect is loaded.
42  mlirContext->loadDialect<pdl::PDLDialect>();
43  }
44 
45  OwningOpRef<ModuleOp> generate(const ast::Module &module);
46 
47 private:
48  /// Generate an MLIR location from the given source location.
49  Location genLoc(llvm::SMLoc loc);
50  Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }
51 
52  /// Generate an MLIR type from the given source type.
53  Type genType(ast::Type type);
54 
55  /// Generate MLIR for the given AST node.
56  void gen(const ast::Node *node);
57 
58  //===--------------------------------------------------------------------===//
59  // Statements
60  //===--------------------------------------------------------------------===//
61 
62  void genImpl(const ast::CompoundStmt *stmt);
63  void genImpl(const ast::EraseStmt *stmt);
64  void genImpl(const ast::LetStmt *stmt);
65  void genImpl(const ast::ReplaceStmt *stmt);
66  void genImpl(const ast::RewriteStmt *stmt);
67  void genImpl(const ast::ReturnStmt *stmt);
68 
69  //===--------------------------------------------------------------------===//
70  // Decls
71  //===--------------------------------------------------------------------===//
72 
73  void genImpl(const ast::UserConstraintDecl *decl);
74  void genImpl(const ast::UserRewriteDecl *decl);
75  void genImpl(const ast::PatternDecl *decl);
76 
77  /// Generate the set of MLIR values defined for the given variable decl, and
78  /// apply any attached constraints.
79  SmallVector<Value> genVar(const ast::VariableDecl *varDecl);
80 
81  /// Generate the value for a variable that does not have an initializer
82  /// expression, i.e. create the PDL value based on the type/constraints of the
83  /// variable.
84  Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);
85 
86  /// Apply the constraints of the given variable to `values`, which correspond
87  /// to the MLIR values of the variable.
88  void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);
89 
90  //===--------------------------------------------------------------------===//
91  // Expressions
92  //===--------------------------------------------------------------------===//
93 
94  Value genSingleExpr(const ast::Expr *expr);
95  SmallVector<Value> genExpr(const ast::Expr *expr);
96  Value genExprImpl(const ast::AttributeExpr *expr);
97  SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
98  SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
99  Value genExprImpl(const ast::MemberAccessExpr *expr);
100  Value genExprImpl(const ast::OperationExpr *expr);
101  Value genExprImpl(const ast::RangeExpr *expr);
102  SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
103  Value genExprImpl(const ast::TypeExpr *expr);
104 
105  SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
106  Location loc, ValueRange inputs,
107  bool isNegated = false);
108  SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
109  Location loc, ValueRange inputs);
110  template <typename PDLOpT, typename T>
111  SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
112  ValueRange inputs,
113  bool isNegated = false);
114 
115  //===--------------------------------------------------------------------===//
116  // Fields
117  //===--------------------------------------------------------------------===//
118 
119  /// The MLIR builder used for building the resultant IR.
120  OpBuilder builder;
121 
122  /// A map from variable declarations to the MLIR equivalent.
123  using VariableMapTy =
124  llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
125  VariableMapTy variables;
126 
127  /// A reference to the ODS context.
128  const ods::Context &odsContext;
129 
130  /// The source manager of the PDLL ast.
131  const llvm::SourceMgr &sourceMgr;
132 };
133 } // namespace
134 
135 OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) {
136  OwningOpRef<ModuleOp> mlirModule =
137  builder.create<ModuleOp>(genLoc(module.getLoc()));
138  builder.setInsertionPointToStart(mlirModule->getBody());
139 
140  // Generate code for each of the decls within the module.
141  for (const ast::Decl *decl : module.getChildren())
142  gen(decl);
143 
144  return mlirModule;
145 }
146 
147 Location CodeGen::genLoc(llvm::SMLoc loc) {
148  unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
149 
150  // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
151  // use it here.
152  auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
153  unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
154  unsigned column =
155  (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
156  auto *buffer = sourceMgr.getMemoryBuffer(fileID);
157 
158  return FileLineColLoc::get(builder.getContext(),
159  buffer->getBufferIdentifier(), lineNo, column);
160 }
161 
162 Type CodeGen::genType(ast::Type type) {
163  return TypeSwitch<ast::Type, Type>(type)
164  .Case([&](ast::AttributeType astType) -> Type {
165  return builder.getType<pdl::AttributeType>();
166  })
167  .Case([&](ast::OperationType astType) -> Type {
168  return builder.getType<pdl::OperationType>();
169  })
170  .Case([&](ast::TypeType astType) -> Type {
171  return builder.getType<pdl::TypeType>();
172  })
173  .Case([&](ast::ValueType astType) -> Type {
174  return builder.getType<pdl::ValueType>();
175  })
176  .Case([&](ast::RangeType astType) -> Type {
177  return pdl::RangeType::get(genType(astType.getElementType()));
178  });
179 }
180 
181 void CodeGen::gen(const ast::Node *node) {
183  .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt,
184  const ast::ReplaceStmt, const ast::RewriteStmt,
187  [&](auto derivedNode) { this->genImpl(derivedNode); })
188  .Case([&](const ast::Expr *expr) { genExpr(expr); });
189 }
190 
191 //===----------------------------------------------------------------------===//
192 // CodeGen: Statements
193 //===----------------------------------------------------------------------===//
194 
195 void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
196  VariableMapTy::ScopeTy varScope(variables);
197  for (const ast::Stmt *childStmt : stmt->getChildren())
198  gen(childStmt);
199 }
200 
201 /// If the given builder is nested under a PDL PatternOp, build a rewrite
202 /// operation and update the builder to nest under it. This is necessary for
203 /// PDLL operation rewrite statements that are directly nested within a Pattern.
204 static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
205  Location loc) {
206  if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
207  pdl::RewriteOp rewrite =
208  builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
209  /*externalArgs=*/ValueRange());
210  builder.createBlock(&rewrite.getBodyRegion());
211  }
212 }
213 
214 void CodeGen::genImpl(const ast::EraseStmt *stmt) {
215  OpBuilder::InsertionGuard insertGuard(builder);
216  Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
217  Location loc = genLoc(stmt->getLoc());
218 
219  // Make sure we are nested in a RewriteOp.
220  OpBuilder::InsertionGuard guard(builder);
221  checkAndNestUnderRewriteOp(builder, rootExpr, loc);
222  builder.create<pdl::EraseOp>(loc, rootExpr);
223 }
224 
225 void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); }
226 
227 void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
228  OpBuilder::InsertionGuard insertGuard(builder);
229  Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
230  Location loc = genLoc(stmt->getLoc());
231 
232  // Make sure we are nested in a RewriteOp.
233  OpBuilder::InsertionGuard guard(builder);
234  checkAndNestUnderRewriteOp(builder, rootExpr, loc);
235 
236  SmallVector<Value> replValues;
237  for (ast::Expr *replExpr : stmt->getReplExprs())
238  replValues.push_back(genSingleExpr(replExpr));
239 
240  // Check to see if the statement has a replacement operation, or a range of
241  // replacement values.
242  bool usesReplOperation =
243  replValues.size() == 1 &&
244  isa<pdl::OperationType>(replValues.front().getType());
245  builder.create<pdl::ReplaceOp>(
246  loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
247  usesReplOperation ? ValueRange() : ValueRange(replValues));
248 }
249 
250 void CodeGen::genImpl(const ast::RewriteStmt *stmt) {
251  OpBuilder::InsertionGuard insertGuard(builder);
252  Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
253 
254  // Make sure we are nested in a RewriteOp.
255  OpBuilder::InsertionGuard guard(builder);
256  checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc()));
257  gen(stmt->getRewriteBody());
258 }
259 
260 void CodeGen::genImpl(const ast::ReturnStmt *stmt) {
261  // ReturnStmt generation is handled by the respective constraint or rewrite
262  // parent node.
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // CodeGen: Decls
267 //===----------------------------------------------------------------------===//
268 
269 void CodeGen::genImpl(const ast::UserConstraintDecl *decl) {
270  // All PDLL constraints get inlined when called, and the main native
271  // constraint declarations doesn't require any MLIR to be generated, only uses
272  // of it do.
273 }
274 
275 void CodeGen::genImpl(const ast::UserRewriteDecl *decl) {
276  // All PDLL rewrites get inlined when called, and the main native
277  // rewrite declarations doesn't require any MLIR to be generated, only uses
278  // of it do.
279 }
280 
281 void CodeGen::genImpl(const ast::PatternDecl *decl) {
282  const ast::Name *name = decl->getName();
283 
284  // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it
285  // here.
286  pdl::PatternOp pattern = builder.create<pdl::PatternOp>(
287  genLoc(decl->getLoc()), decl->getBenefit(),
288  name ? std::optional<StringRef>(name->getName())
289  : std::optional<StringRef>());
290 
291  OpBuilder::InsertionGuard savedInsertPoint(builder);
292  builder.setInsertionPointToStart(pattern.getBody());
293  gen(decl->getBody());
294 }
295 
296 SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) {
297  auto it = variables.begin(varDecl);
298  if (it != variables.end())
299  return *it;
300 
301  // If the variable has an initial value, use that as the base value.
302  // Otherwise, generate a value using the constraint list.
303  SmallVector<Value> values;
304  if (const ast::Expr *initExpr = varDecl->getInitExpr())
305  values = genExpr(initExpr);
306  else
307  values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
308 
309  // Apply the constraints of the values of the variable.
310  applyVarConstraints(varDecl, values);
311 
312  variables.insert(varDecl, values);
313  return values;
314 }
315 
316 Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
317  Location loc) {
318  // A functor used to generate expressions nested
319  auto getTypeConstraint = [&]() -> Value {
320  for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) {
321  Value typeValue =
322  TypeSwitch<const ast::Node *, Value>(constraint.constraint)
325  [&, this](auto *cst) -> Value {
326  if (auto *typeConstraintExpr = cst->getTypeExpr())
327  return this->genSingleExpr(typeConstraintExpr);
328  return Value();
329  })
330  .Default(Value());
331  if (typeValue)
332  return typeValue;
333  }
334  return Value();
335  };
336 
337  // Generate a value based on the type of the variable.
338  ast::Type type = varDecl->getType();
339  Type mlirType = genType(type);
340  if (type.isa<ast::ValueType>())
341  return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
342  if (type.isa<ast::TypeType>())
343  return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
344  if (type.isa<ast::AttributeType>())
345  return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
346  if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
347  Value operands = builder.create<pdl::OperandsOp>(
348  loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
349  /*type=*/Value());
350  Value results = builder.create<pdl::TypesOp>(
351  loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
352  /*types=*/ArrayAttr());
353  return builder.create<pdl::OperationOp>(
354  loc, opType.getName(), operands, std::nullopt, ValueRange(), results);
355  }
356 
357  if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
358  ast::Type eleTy = rangeTy.getElementType();
359  if (eleTy.isa<ast::ValueType>())
360  return builder.create<pdl::OperandsOp>(loc, mlirType,
361  getTypeConstraint());
362  if (eleTy.isa<ast::TypeType>())
363  return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
364  }
365 
366  llvm_unreachable("invalid non-initialized variable type");
367 }
368 
369 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
370  ValueRange values) {
371  // Generate calls to any user constraints that were attached via the
372  // constraint list.
373  for (const ast::ConstraintRef &ref : varDecl->getConstraints())
374  if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
375  genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // CodeGen: Expressions
380 //===----------------------------------------------------------------------===//
381 
382 Value CodeGen::genSingleExpr(const ast::Expr *expr) {
384  .Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
385  const ast::OperationExpr, const ast::RangeExpr,
386  const ast::TypeExpr>(
387  [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
388  .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
389  [&](auto derivedNode) {
390  SmallVector<Value> results = this->genExprImpl(derivedNode);
391  assert(results.size() == 1 && "expected single expression result");
392  return results[0];
393  });
394 }
395 
396 SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) {
398  .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
399  [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
400  .Default([&](const ast::Expr *expr) -> SmallVector<Value> {
401  return {genSingleExpr(expr)};
402  });
403 }
404 
405 Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) {
406  Attribute attr = parseAttribute(expr->getValue(), builder.getContext());
407  assert(attr && "invalid MLIR attribute data");
408  return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
409 }
410 
411 SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
412  Location loc = genLoc(expr->getLoc());
413  SmallVector<Value> arguments;
414  for (const ast::Expr *arg : expr->getArguments())
415  arguments.push_back(genSingleExpr(arg));
416 
417  // Resolve the callable expression of this call.
418  auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
419  assert(callableExpr && "unhandled CallExpr callable");
420 
421  // Generate the PDL based on the type of callable.
422  const ast::Decl *callable = callableExpr->getDecl();
423  if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
424  return genConstraintCall(decl, loc, arguments, expr->getIsNegated());
425  if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
426  return genRewriteCall(decl, loc, arguments);
427  llvm_unreachable("unhandled CallExpr callable");
428 }
429 
430 SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) {
431  if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl()))
432  return genVar(varDecl);
433  llvm_unreachable("unknown decl reference expression");
434 }
435 
436 Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
437  Location loc = genLoc(expr->getLoc());
438  StringRef name = expr->getMemberName();
439  SmallVector<Value> parentExprs = genExpr(expr->getParentExpr());
440  ast::Type parentType = expr->getParentExpr()->getType();
441 
442  // Handle operation based member access.
443  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
444  if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
445  Type mlirType = genType(expr->getType());
446  if (isa<pdl::ValueType>(mlirType))
447  return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
448  builder.getI32IntegerAttr(0));
449  return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
450  }
451 
452  const ods::Operation *odsOp = opType.getODSOperation();
453  if (!odsOp) {
454  assert(llvm::isDigit(name[0]) &&
455  "unregistered op only allows numeric indexing");
456  unsigned resultIndex;
457  name.getAsInteger(/*Radix=*/10, resultIndex);
458  IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
459  return builder.create<pdl::ResultOp>(loc, genType(expr->getType()),
460  parentExprs[0], index);
461  }
462 
463  // Find the result with the member name or by index.
464  ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
465  unsigned resultIndex = results.size();
466  if (llvm::isDigit(name[0])) {
467  name.getAsInteger(/*Radix=*/10, resultIndex);
468  } else {
469  auto findFn = [&](const ods::OperandOrResult &result) {
470  return result.getName() == name;
471  };
472  resultIndex = llvm::find_if(results, findFn) - results.begin();
473  }
474  assert(resultIndex < results.size() && "invalid result index");
475 
476  // Generate the result access.
477  IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
478  return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
479  parentExprs[0], index);
480  }
481 
482  // Handle tuple based member access.
483  if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
484  auto elementNames = tupleType.getElementNames();
485 
486  // The index is either a numeric index, or a name.
487  unsigned index = 0;
488  if (llvm::isDigit(name[0]))
489  name.getAsInteger(/*Radix=*/10, index);
490  else
491  index = llvm::find(elementNames, name) - elementNames.begin();
492 
493  assert(index < parentExprs.size() && "invalid result index");
494  return parentExprs[index];
495  }
496 
497  llvm_unreachable("unhandled member access expression");
498 }
499 
500 Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
501  Location loc = genLoc(expr->getLoc());
502  std::optional<StringRef> opName = expr->getName();
503 
504  // Operands.
505  SmallVector<Value> operands;
506  for (const ast::Expr *operand : expr->getOperands())
507  operands.push_back(genSingleExpr(operand));
508 
509  // Attributes.
510  SmallVector<StringRef> attrNames;
511  SmallVector<Value> attrValues;
512  for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) {
513  attrNames.push_back(attr->getName().getName());
514  attrValues.push_back(genSingleExpr(attr->getValue()));
515  }
516 
517  // Results.
518  SmallVector<Value> results;
519  for (const ast::Expr *result : expr->getResultTypes())
520  results.push_back(genSingleExpr(result));
521 
522  return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
523  attrValues, results);
524 }
525 
526 Value CodeGen::genExprImpl(const ast::RangeExpr *expr) {
527  SmallVector<Value> elements;
528  for (const ast::Expr *element : expr->getElements())
529  llvm::append_range(elements, genExpr(element));
530 
531  return builder.create<pdl::RangeOp>(genLoc(expr->getLoc()),
532  genType(expr->getType()), elements);
533 }
534 
535 SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
536  SmallVector<Value> elements;
537  for (const ast::Expr *element : expr->getElements())
538  elements.push_back(genSingleExpr(element));
539  return elements;
540 }
541 
542 Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
543  Type type = parseType(expr->getValue(), builder.getContext());
544  assert(type && "invalid MLIR type data");
545  return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
546  builder.getType<pdl::TypeType>(),
547  TypeAttr::get(type));
548 }
549 
551 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
552  ValueRange inputs, bool isNegated) {
553  // Apply any constraints defined on the arguments to the input values.
554  for (auto it : llvm::zip(decl->getInputs(), inputs))
555  applyVarConstraints(std::get<0>(it), std::get<1>(it));
556 
557  // Generate the constraint call.
558  SmallVector<Value> results =
559  genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
560  decl, loc, inputs, isNegated);
561 
562  // Apply any constraints defined on the results of the constraint.
563  for (auto it : llvm::zip(decl->getResults(), results))
564  applyVarConstraints(std::get<0>(it), std::get<1>(it));
565  return results;
566 }
567 
568 SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
569  Location loc, ValueRange inputs) {
570  return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
571  inputs);
572 }
573 
574 template <typename PDLOpT, typename T>
576 CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
577  ValueRange inputs, bool isNegated) {
578  const ast::CompoundStmt *cstBody = decl->getBody();
579 
580  // If the decl doesn't have a statement body, it is a native decl.
581  if (!cstBody) {
582  ast::Type declResultType = decl->getResultType();
583  SmallVector<Type> resultTypes;
584  if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
585  for (ast::Type type : tupleType.getElementTypes())
586  resultTypes.push_back(genType(type));
587  } else {
588  resultTypes.push_back(genType(declResultType));
589  }
590  PDLOpT pdlOp = builder.create<PDLOpT>(
591  loc, resultTypes, decl->getName().getName(), inputs);
592  if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
593  cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
594  return pdlOp->getResults();
595  }
596 
597  // Otherwise, this is a PDLL decl.
598  VariableMapTy::ScopeTy varScope(variables);
599 
600  // Map the inputs of the call to the decl arguments.
601  // Note: This is only valid because we do not support recursion, meaning
602  // we don't need to worry about conflicting mappings here.
603  for (auto it : llvm::zip(inputs, decl->getInputs()))
604  variables.insert(std::get<1>(it), {std::get<0>(it)});
605 
606  // Visit the body of the call as normal.
607  gen(cstBody);
608 
609  // If the decl has no results, there is nothing to do.
610  if (cstBody->getChildren().empty())
611  return SmallVector<Value>();
612  auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back());
613  if (!returnStmt)
614  return SmallVector<Value>();
615 
616  // Otherwise, grab the results from the return statement.
617  return genExpr(returnStmt->getResultExpr());
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // MLIRGen
622 //===----------------------------------------------------------------------===//
623 
625  MLIRContext *mlirContext, const ast::Context &context,
626  const llvm::SourceMgr &sourceMgr, const ast::Module &module) {
627  CodeGen codegen(mlirContext, context, sourceMgr);
628  OwningOpRef<ModuleOp> mlirModule = codegen.generate(module);
629  if (failed(verify(*mlirModule)))
630  return nullptr;
631  return mlirModule;
632 }
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc)
If the given builder is nested under a PDL PatternOp, build a rewrite operation and update the builder to nest under it.
Definition: MLIRGen.cpp:204
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:107
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:749
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that attribute.
Definition: Nodes.h:367
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:373
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:131
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Definition: Nodes.h:390
Expr * getCallableExpr() const
Return the callable of this call.
Definition: Nodes.h:397
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition: Nodes.h:400
bool getIsNegated() const
Returns whether the result of this call is to be negated.
Definition: Nodes.h:408
This statement represents a compound statement, which contains a collection of other statements.
Definition: Nodes.h:179
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:185
This class represents the main context of the PDLL AST.
Definition: Context.h:25
This expression represents a reference to a Decl node.
Definition: Nodes.h:434
Decl * getDecl() const
Get the decl referenced by this expression.
Definition: Nodes.h:439
This class represents the base Decl node.
Definition: Nodes.h:669
This statement represents the erase statement in PDLL.
Definition: Nodes.h:254
This class represents a base AST Expression node.
Definition: Nodes.h:345
Type getType() const
Return the type of this expression.
Definition: Nodes.h:348
This statement represents a let statement in PDLL.
Definition: Nodes.h:211
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
Definition: Nodes.h:216
This expression represents a named member or field access of a given parent expression.
Definition: Nodes.h:455
const Expr * getParentExpr() const
Get the parent expression of this access.
Definition: Nodes.h:462
StringRef getMemberName() const
Return the name of the member being accessed.
Definition: Nodes.h:465
This class represents a top-level AST module.
Definition: Nodes.h:1291
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition: Nodes.h:1296
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:992
This class represents a base AST node.
Definition: Nodes.h:108
This expression represents the structural form of an MLIR Operation.
Definition: Nodes.h:512
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition: Nodes.h:540
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
Definition: Nodes.h:548
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition: Nodes.h:532
std::optional< StringRef > getName() const
Return the name of the operation, or std::nullopt if there isn't one.
Definition: Nodes.cpp:330
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:158
This Decl represents a single Pattern.
Definition: Nodes.h:1037
const CompoundStmt * getBody() const
Return the body of this pattern.
Definition: Nodes.h:1051
static PatternDecl * create(Context &ctx, SMRange location, const Name *name, std::optional< uint16_t > benefit, bool hasBoundedRecursion, const CompoundStmt *body)
Definition: Nodes.cpp:520
std::optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or std::nullopt.
Definition: Nodes.h:1045
This expression builds a range from a set of element values (which may be ranges themselves).
Definition: Nodes.h:586
MutableArrayRef< Expr * > getElements()
Return the element expressions of this range.
Definition: Nodes.h:592
RangeType getType() const
Return the range result type of this expression.
Definition: Nodes.h:600
This class represents a PDLL type that corresponds to a range of elements with a given element type.
Definition: Types.h:183
Type getElementType() const
Return the element type of this range.
Definition: Types.cpp:100
This statement represents the replace statement in PDLL.
Definition: Nodes.h:269
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
Definition: Nodes.h:275
This statement represents a return from a "callable"-like decl, e.g. a user constraint or a user rewrite.
Definition: Nodes.h:321
This statement represents an operation rewrite that contains a block of nested rewrite commands.
Definition: Nodes.h:299
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
Definition: Nodes.h:305
This class represents a base AST Statement node.
Definition: Nodes.h:164
This expression builds a tuple from a set of element values.
Definition: Nodes.h:619
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
Definition: Nodes.h:625
This class represents a PDLL tuple type, i.e. an ordered collection of element types whose elements can be accessed by index or by element name.
Definition: Types.h:244
This expression represents a literal MLIR Type, and contains the textual assembly format of that type.
Definition: Nodes.h:648
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:654
This class represents a PDLL type that corresponds to an mlir::Type.
Definition: Types.h:272
U dyn_cast() const
Definition: Types.h:76
bool isa() const
Provide type casting support.
Definition: Types.h:67
This decl represents a user defined constraint.
Definition: Nodes.h:882
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
Definition: Nodes.h:908
MutableArrayRef< VariableDecl * > getResults()
Return the explicit results of the constraint declaration.
Definition: Nodes.h:921
This decl represents a user defined rewrite.
Definition: Nodes.h:1092
This class represents a Value constraint, and constrains a variable to be a Value.
Definition: Nodes.h:825
This class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:847
This class represents a PDLL type that corresponds to an mlir::Value.
Definition: Types.h:285
This Decl represents the definition of a PDLL variable.
Definition: Nodes.h:1242
Expr * getInitExpr() const
Return the initializer expression of this statement, or nullptr if there was no initializer.
Definition: Nodes.h:1258
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition: Nodes.h:1249
Type getType() const
Return the type of the decl.
Definition: Nodes.h:1264
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
ArrayRef< OperandOrResult > getResults() const
Returns the results of this operation.
Definition: Operation.h:168
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Definition: MLIRGen.cpp:624
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents a reference to a constraint, and contains the referenced constraint and the location of the reference.
Definition: Nodes.h:716
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:37
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:41