MLIR  16.0.0git
MLIRGen.cpp
Go to the documentation of this file.
1 //===- MLIRGen.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Verifier.h"
22 #include "llvm/ADT/ScopedHashTable.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::pdll;
28 
29 //===----------------------------------------------------------------------===//
30 // CodeGen
31 //===----------------------------------------------------------------------===//
32 
namespace {
/// Generates MLIR PDL dialect IR from a PDLL AST module.
///
/// Each PatternDecl becomes a `pdl::PatternOp`. Calls to PDLL-defined
/// constraints and rewrites are inlined at the call site, while calls to
/// native (body-less) ones become `pdl::ApplyNativeConstraintOp` /
/// `pdl::ApplyNativeRewriteOp` operations.
class CodeGen {
public:
  /// Create a generator that builds IR in `mlirContext` for the AST owned by
  /// `context`; `sourceMgr` is used to translate AST source locations. The
  /// ODS context and source manager are held by reference, so they must
  /// outlive this object. Loads the PDL dialect into `mlirContext`.
  CodeGen(MLIRContext *mlirContext, const ast::Context &context,
          const llvm::SourceMgr &sourceMgr)
      : builder(mlirContext), odsContext(context.getODSContext()),
        sourceMgr(sourceMgr) {
    // Make sure that the PDL dialect is loaded.
    mlirContext->loadDialect<pdl::PDLDialect>();
  }

  /// Generate a new MLIR module containing the PDL form of every decl in
  /// `module`. The returned module is not verified here.
  OwningOpRef<ModuleOp> generate(const ast::Module &module);

private:
  /// Generate an MLIR location from the given source location.
  Location genLoc(llvm::SMLoc loc);
  /// Generate an MLIR location from the start of the given source range.
  Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }

  /// Generate an MLIR type from the given source type.
  Type genType(ast::Type type);

  /// Generate MLIR for the given AST node.
  void gen(const ast::Node *node);

  //===--------------------------------------------------------------------===//
  // Statements
  //===--------------------------------------------------------------------===//

  // One overload per statement kind; `gen` dispatches to these.
  void genImpl(const ast::CompoundStmt *stmt);
  void genImpl(const ast::EraseStmt *stmt);
  void genImpl(const ast::LetStmt *stmt);
  void genImpl(const ast::ReplaceStmt *stmt);
  void genImpl(const ast::RewriteStmt *stmt);
  void genImpl(const ast::ReturnStmt *stmt);

  //===--------------------------------------------------------------------===//
  // Decls
  //===--------------------------------------------------------------------===//

  // One overload per top-level decl kind; `gen` dispatches to these.
  void genImpl(const ast::UserConstraintDecl *decl);
  void genImpl(const ast::UserRewriteDecl *decl);
  void genImpl(const ast::PatternDecl *decl);

  /// Generate the set of MLIR values defined for the given variable decl, and
  /// apply any attached constraints.
  SmallVector<Value> genVar(const ast::VariableDecl *varDecl);

  /// Generate the value for a variable that does not have an initializer
  /// expression, i.e. create the PDL value based on the type/constraints of the
  /// variable.
  Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);

  /// Apply the constraints of the given variable to `values`, which correspond
  /// to the MLIR values of the variable.
  void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);

  //===--------------------------------------------------------------------===//
  // Expressions
  //===--------------------------------------------------------------------===//

  // `genSingleExpr` requires the expression to produce exactly one value;
  // `genExpr` also accepts expressions that produce a list of values
  // (calls, decl references, tuples).
  Value genSingleExpr(const ast::Expr *expr);
  SmallVector<Value> genExpr(const ast::Expr *expr);
  Value genExprImpl(const ast::AttributeExpr *expr);
  SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
  SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
  Value genExprImpl(const ast::MemberAccessExpr *expr);
  Value genExprImpl(const ast::OperationExpr *expr);
  SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
  Value genExprImpl(const ast::TypeExpr *expr);

  // Call generation. Both entry points delegate to
  // `genConstraintOrRewriteCall`, which emits a native `PDLOpT` for body-less
  // decls and inlines the body of PDLL-defined ones.
  SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
                                       Location loc, ValueRange inputs);
  SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
                                    Location loc, ValueRange inputs);
  template <typename PDLOpT, typename T>
  SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
                                                ValueRange inputs);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The MLIR builder used for building the resultant IR.
  OpBuilder builder;

  /// A map from variable declarations to the MLIR equivalent. Entries inserted
  /// while a `VariableMapTy::ScopeTy` is alive are dropped when it ends.
  using VariableMapTy =
      llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
  VariableMapTy variables;

  /// A reference to the ODS context.
  /// NOTE(review): not read by any method in this listing — confirm whether it
  /// is still needed.
  const ods::Context &odsContext;

  /// The source manager of the PDLL ast.
  const llvm::SourceMgr &sourceMgr;
};
} // namespace
130 
131 OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) {
132  OwningOpRef<ModuleOp> mlirModule =
133  builder.create<ModuleOp>(genLoc(module.getLoc()));
134  builder.setInsertionPointToStart(mlirModule->getBody());
135 
136  // Generate code for each of the decls within the module.
137  for (const ast::Decl *decl : module.getChildren())
138  gen(decl);
139 
140  return mlirModule;
141 }
142 
143 Location CodeGen::genLoc(llvm::SMLoc loc) {
144  unsigned fileID = sourceMgr.FindBufferContainingLoc(loc);
145 
146  // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
147  // use it here.
148  auto &bufferInfo = sourceMgr.getBufferInfo(fileID);
149  unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer());
150  unsigned column =
151  (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1;
152  auto *buffer = sourceMgr.getMemoryBuffer(fileID);
153 
154  return FileLineColLoc::get(builder.getContext(),
155  buffer->getBufferIdentifier(), lineNo, column);
156 }
157 
158 Type CodeGen::genType(ast::Type type) {
159  return TypeSwitch<ast::Type, Type>(type)
160  .Case([&](ast::AttributeType astType) -> Type {
161  return builder.getType<pdl::AttributeType>();
162  })
163  .Case([&](ast::OperationType astType) -> Type {
164  return builder.getType<pdl::OperationType>();
165  })
166  .Case([&](ast::TypeType astType) -> Type {
167  return builder.getType<pdl::TypeType>();
168  })
169  .Case([&](ast::ValueType astType) -> Type {
170  return builder.getType<pdl::ValueType>();
171  })
172  .Case([&](ast::RangeType astType) -> Type {
173  return pdl::RangeType::get(genType(astType.getElementType()));
174  });
175 }
176 
177 void CodeGen::gen(const ast::Node *node) {
179  .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt,
180  const ast::ReplaceStmt, const ast::RewriteStmt,
183  [&](auto derivedNode) { this->genImpl(derivedNode); })
184  .Case([&](const ast::Expr *expr) { genExpr(expr); });
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // CodeGen: Statements
189 //===----------------------------------------------------------------------===//
190 
191 void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
192  VariableMapTy::ScopeTy varScope(variables);
193  for (const ast::Stmt *childStmt : stmt->getChildren())
194  gen(childStmt);
195 }
196 
197 /// If the given builder is nested under a PDL PatternOp, build a rewrite
198 /// operation and update the builder to nest under it. This is necessary for
199 /// PDLL operation rewrite statements that are directly nested within a Pattern.
200 static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
201  Location loc) {
202  if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
203  pdl::RewriteOp rewrite =
204  builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
205  /*externalArgs=*/ValueRange());
206  builder.createBlock(&rewrite.body());
207  }
208 }
209 
210 void CodeGen::genImpl(const ast::EraseStmt *stmt) {
211  OpBuilder::InsertionGuard insertGuard(builder);
212  Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
213  Location loc = genLoc(stmt->getLoc());
214 
215  // Make sure we are nested in a RewriteOp.
216  OpBuilder::InsertionGuard guard(builder);
217  checkAndNestUnderRewriteOp(builder, rootExpr, loc);
218  builder.create<pdl::EraseOp>(loc, rootExpr);
219 }
220 
221 void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); }
222 
223 void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
224  OpBuilder::InsertionGuard insertGuard(builder);
225  Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
226  Location loc = genLoc(stmt->getLoc());
227 
228  // Make sure we are nested in a RewriteOp.
229  OpBuilder::InsertionGuard guard(builder);
230  checkAndNestUnderRewriteOp(builder, rootExpr, loc);
231 
232  SmallVector<Value> replValues;
233  for (ast::Expr *replExpr : stmt->getReplExprs())
234  replValues.push_back(genSingleExpr(replExpr));
235 
236  // Check to see if the statement has a replacement operation, or a range of
237  // replacement values.
238  bool usesReplOperation =
239  replValues.size() == 1 &&
240  replValues.front().getType().isa<pdl::OperationType>();
241  builder.create<pdl::ReplaceOp>(
242  loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
243  usesReplOperation ? ValueRange() : ValueRange(replValues));
244 }
245 
246 void CodeGen::genImpl(const ast::RewriteStmt *stmt) {
247  OpBuilder::InsertionGuard insertGuard(builder);
248  Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
249 
250  // Make sure we are nested in a RewriteOp.
251  OpBuilder::InsertionGuard guard(builder);
252  checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc()));
253  gen(stmt->getRewriteBody());
254 }
255 
256 void CodeGen::genImpl(const ast::ReturnStmt *stmt) {
257  // ReturnStmt generation is handled by the respective constraint or rewrite
258  // parent node.
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // CodeGen: Decls
263 //===----------------------------------------------------------------------===//
264 
265 void CodeGen::genImpl(const ast::UserConstraintDecl *decl) {
266  // All PDLL constraints get inlined when called, and the main native
267  // constraint declarations doesn't require any MLIR to be generated, only uses
268  // of it do.
269 }
270 
271 void CodeGen::genImpl(const ast::UserRewriteDecl *decl) {
272  // All PDLL rewrites get inlined when called, and the main native
273  // rewrite declarations doesn't require any MLIR to be generated, only uses
274  // of it do.
275 }
276 
277 void CodeGen::genImpl(const ast::PatternDecl *decl) {
278  const ast::Name *name = decl->getName();
279 
280  // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it
281  // here.
282  pdl::PatternOp pattern = builder.create<pdl::PatternOp>(
283  genLoc(decl->getLoc()), decl->getBenefit(),
284  name ? Optional<StringRef>(name->getName()) : Optional<StringRef>());
285 
286  OpBuilder::InsertionGuard savedInsertPoint(builder);
287  builder.setInsertionPointToStart(pattern.getBody());
288  gen(decl->getBody());
289 }
290 
291 SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) {
292  auto it = variables.begin(varDecl);
293  if (it != variables.end())
294  return *it;
295 
296  // If the variable has an initial value, use that as the base value.
297  // Otherwise, generate a value using the constraint list.
298  SmallVector<Value> values;
299  if (const ast::Expr *initExpr = varDecl->getInitExpr())
300  values = genExpr(initExpr);
301  else
302  values.push_back(genNonInitializerVar(varDecl, genLoc(varDecl->getLoc())));
303 
304  // Apply the constraints of the values of the variable.
305  applyVarConstraints(varDecl, values);
306 
307  variables.insert(varDecl, values);
308  return values;
309 }
310 
311 Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
312  Location loc) {
313  // A functor used to generate expressions nested
314  auto getTypeConstraint = [&]() -> Value {
315  for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) {
316  Value typeValue =
317  TypeSwitch<const ast::Node *, Value>(constraint.constraint)
320  [&, this](auto *cst) -> Value {
321  if (auto *typeConstraintExpr = cst->getTypeExpr())
322  return this->genSingleExpr(typeConstraintExpr);
323  return Value();
324  })
325  .Default(Value());
326  if (typeValue)
327  return typeValue;
328  }
329  return Value();
330  };
331 
332  // Generate a value based on the type of the variable.
333  ast::Type type = varDecl->getType();
334  Type mlirType = genType(type);
335  if (type.isa<ast::ValueType>())
336  return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
337  if (type.isa<ast::TypeType>())
338  return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
339  if (type.isa<ast::AttributeType>())
340  return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
341  if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
342  Value operands = builder.create<pdl::OperandsOp>(
343  loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
344  /*type=*/Value());
345  Value results = builder.create<pdl::TypesOp>(
346  loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
347  /*types=*/ArrayAttr());
348  return builder.create<pdl::OperationOp>(loc, opType.getName(), operands,
349  llvm::None, ValueRange(), results);
350  }
351 
352  if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
353  ast::Type eleTy = rangeTy.getElementType();
354  if (eleTy.isa<ast::ValueType>())
355  return builder.create<pdl::OperandsOp>(loc, mlirType,
356  getTypeConstraint());
357  if (eleTy.isa<ast::TypeType>())
358  return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
359  }
360 
361  llvm_unreachable("invalid non-initialized variable type");
362 }
363 
364 void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
365  ValueRange values) {
366  // Generate calls to any user constraints that were attached via the
367  // constraint list.
368  for (const ast::ConstraintRef &ref : varDecl->getConstraints())
369  if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint))
370  genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // CodeGen: Expressions
375 //===----------------------------------------------------------------------===//
376 
377 Value CodeGen::genSingleExpr(const ast::Expr *expr) {
379  .Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
380  const ast::OperationExpr, const ast::TypeExpr>(
381  [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
382  .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
383  [&](auto derivedNode) {
384  SmallVector<Value> results = this->genExprImpl(derivedNode);
385  assert(results.size() == 1 && "expected single expression result");
386  return results[0];
387  });
388 }
389 
390 SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) {
392  .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
393  [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
394  .Default([&](const ast::Expr *expr) -> SmallVector<Value> {
395  return {genSingleExpr(expr)};
396  });
397 }
398 
399 Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) {
400  Attribute attr = parseAttribute(expr->getValue(), builder.getContext());
401  assert(attr && "invalid MLIR attribute data");
402  return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
403 }
404 
405 SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
406  Location loc = genLoc(expr->getLoc());
407  SmallVector<Value> arguments;
408  for (const ast::Expr *arg : expr->getArguments())
409  arguments.push_back(genSingleExpr(arg));
410 
411  // Resolve the callable expression of this call.
412  auto *callableExpr = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
413  assert(callableExpr && "unhandled CallExpr callable");
414 
415  // Generate the PDL based on the type of callable.
416  const ast::Decl *callable = callableExpr->getDecl();
417  if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
418  return genConstraintCall(decl, loc, arguments);
419  if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
420  return genRewriteCall(decl, loc, arguments);
421  llvm_unreachable("unhandled CallExpr callable");
422 }
423 
424 SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) {
425  if (const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl()))
426  return genVar(varDecl);
427  llvm_unreachable("unknown decl reference expression");
428 }
429 
430 Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
431  Location loc = genLoc(expr->getLoc());
432  StringRef name = expr->getMemberName();
433  SmallVector<Value> parentExprs = genExpr(expr->getParentExpr());
434  ast::Type parentType = expr->getParentExpr()->getType();
435 
436  // Handle operation based member access.
437  if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
438  if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
439  Type mlirType = genType(expr->getType());
440  if (mlirType.isa<pdl::ValueType>())
441  return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
442  builder.getI32IntegerAttr(0));
443  return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
444  }
445 
446  const ods::Operation *odsOp = opType.getODSOperation();
447  if (!odsOp) {
448  assert(llvm::isDigit(name[0]) &&
449  "unregistered op only allows numeric indexing");
450  unsigned resultIndex;
451  name.getAsInteger(/*Radix=*/10, resultIndex);
452  IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
453  return builder.create<pdl::ResultOp>(loc, genType(expr->getType()),
454  parentExprs[0], index);
455  }
456 
457  // Find the result with the member name or by index.
458  ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
459  unsigned resultIndex = results.size();
460  if (llvm::isDigit(name[0])) {
461  name.getAsInteger(/*Radix=*/10, resultIndex);
462  } else {
463  auto findFn = [&](const ods::OperandOrResult &result) {
464  return result.getName() == name;
465  };
466  resultIndex = llvm::find_if(results, findFn) - results.begin();
467  }
468  assert(resultIndex < results.size() && "invalid result index");
469 
470  // Generate the result access.
471  IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
472  return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
473  parentExprs[0], index);
474  }
475 
476  // Handle tuple based member access.
477  if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
478  auto elementNames = tupleType.getElementNames();
479 
480  // The index is either a numeric index, or a name.
481  unsigned index = 0;
482  if (llvm::isDigit(name[0]))
483  name.getAsInteger(/*Radix=*/10, index);
484  else
485  index = llvm::find(elementNames, name) - elementNames.begin();
486 
487  assert(index < parentExprs.size() && "invalid result index");
488  return parentExprs[index];
489  }
490 
491  llvm_unreachable("unhandled member access expression");
492 }
493 
494 Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
495  Location loc = genLoc(expr->getLoc());
496  Optional<StringRef> opName = expr->getName();
497 
498  // Operands.
499  SmallVector<Value> operands;
500  for (const ast::Expr *operand : expr->getOperands())
501  operands.push_back(genSingleExpr(operand));
502 
503  // Attributes.
504  SmallVector<StringRef> attrNames;
505  SmallVector<Value> attrValues;
506  for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) {
507  attrNames.push_back(attr->getName().getName());
508  attrValues.push_back(genSingleExpr(attr->getValue()));
509  }
510 
511  // Results.
512  SmallVector<Value> results;
513  for (const ast::Expr *result : expr->getResultTypes())
514  results.push_back(genSingleExpr(result));
515 
516  return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
517  attrValues, results);
518 }
519 
520 SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
521  SmallVector<Value> elements;
522  for (const ast::Expr *element : expr->getElements())
523  elements.push_back(genSingleExpr(element));
524  return elements;
525 }
526 
527 Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
528  Type type = parseType(expr->getValue(), builder.getContext());
529  assert(type && "invalid MLIR type data");
530  return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
531  builder.getType<pdl::TypeType>(),
532  TypeAttr::get(type));
533 }
534 
536 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
537  ValueRange inputs) {
538  // Apply any constraints defined on the arguments to the input values.
539  for (auto it : llvm::zip(decl->getInputs(), inputs))
540  applyVarConstraints(std::get<0>(it), std::get<1>(it));
541 
542  // Generate the constraint call.
543  SmallVector<Value> results =
544  genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc,
545  inputs);
546 
547  // Apply any constraints defined on the results of the constraint.
548  for (auto it : llvm::zip(decl->getResults(), results))
549  applyVarConstraints(std::get<0>(it), std::get<1>(it));
550  return results;
551 }
552 
553 SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
554  Location loc, ValueRange inputs) {
555  return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
556  inputs);
557 }
558 
559 template <typename PDLOpT, typename T>
560 SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
561  Location loc,
562  ValueRange inputs) {
563  const ast::CompoundStmt *cstBody = decl->getBody();
564 
565  // If the decl doesn't have a statement body, it is a native decl.
566  if (!cstBody) {
567  ast::Type declResultType = decl->getResultType();
568  SmallVector<Type> resultTypes;
569  if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
570  for (ast::Type type : tupleType.getElementTypes())
571  resultTypes.push_back(genType(type));
572  } else {
573  resultTypes.push_back(genType(declResultType));
574  }
575  Operation *pdlOp = builder.create<PDLOpT>(
576  loc, resultTypes, decl->getName().getName(), inputs);
577  return pdlOp->getResults();
578  }
579 
580  // Otherwise, this is a PDLL decl.
581  VariableMapTy::ScopeTy varScope(variables);
582 
583  // Map the inputs of the call to the decl arguments.
584  // Note: This is only valid because we do not support recursion, meaning
585  // we don't need to worry about conflicting mappings here.
586  for (auto it : llvm::zip(inputs, decl->getInputs()))
587  variables.insert(std::get<1>(it), {std::get<0>(it)});
588 
589  // Visit the body of the call as normal.
590  gen(cstBody);
591 
592  // If the decl has no results, there is nothing to do.
593  if (cstBody->getChildren().empty())
594  return SmallVector<Value>();
595  auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back());
596  if (!returnStmt)
597  return SmallVector<Value>();
598 
599  // Otherwise, grab the results from the return statement.
600  return genExpr(returnStmt->getResultExpr());
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // MLIRGen
605 //===----------------------------------------------------------------------===//
606 
608  MLIRContext *mlirContext, const ast::Context &context,
609  const llvm::SourceMgr &sourceMgr, const ast::Module &module) {
610  CodeGen codegen(mlirContext, context, sourceMgr);
611  OwningOpRef<ModuleOp> mlirModule = codegen.generate(module);
612  if (failed(verify(*mlirModule)))
613  return nullptr;
614  return mlirModule;
615 }
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
Parses a single MLIR attribute into the given MLIR context if it is valid.
CompoundStmt * getRewriteBody() const
Return the compound rewrite body.
Definition: Nodes.h:304
MutableArrayRef< Expr * > getArguments()
Return the arguments of this call.
Definition: Nodes.h:398
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:388
This class represents a PDLL type that corresponds to an mlir::Value.
Definition: Types.h:283
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This class represents the main context of the PDLL AST.
Definition: Context.h:25
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
Type getType() const
Return the type of the decl.
Definition: Nodes.h:1216
MutableArrayRef< Expr * > getResultTypes()
Return the result types of this operation.
Definition: Nodes.h:530
U dyn_cast() const
Definition: Types.h:75
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
This decl represents a user defined constraint.
Definition: Nodes.h:836
This class represents a PDLL tuple type, i.e.
Definition: Types.h:242
bool isa() const
Provide type casting support.
Definition: Types.h:66
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:102
Expr * getInitExpr() const
Return the initializer expression of this variable, or nullptr if there was no initializer.
Definition: Nodes.h:1210
This expression represents a literal MLIR Attribute, and contains the textual assembly format of that...
Definition: Nodes.h:366
This class contains all of the registered ODS operation classes.
Definition: Context.h:32
SMRange getLoc() const
Return the location of this node.
Definition: Nodes.h:130
Expr * getRootOpExpr() const
Return the root operation of this rewrite.
Definition: Nodes.h:236
MutableArrayRef< Expr * > getOperands()
Return the operands of this operation.
Definition: Nodes.h:522
static const Name & create(Context &ctx, StringRef name, SMRange location)
Definition: Nodes.cpp:32
Optional< uint16_t > getBenefit() const
Return the benefit of this pattern if specified, or None.
Definition: Nodes.h:998
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
This class represents the base Decl node.
Definition: Nodes.h:625
The class represents an Attribute constraint, and constrains a variable to be an Attribute.
Definition: Nodes.h:704
MutableArrayRef< VariableDecl * > getInputs()
Return the input arguments of this constraint.
Definition: Nodes.h:863
This statement represents the replace statement in PDLL.
Definition: Nodes.h:267
Decl * getDecl() const
Get the decl referenced by this expression.
Definition: Nodes.h:429
StringRef getName() const
Return the raw string name.
Definition: Nodes.h:40
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:372
Type getType() const
Return the type of this expression.
Definition: Nodes.h:347
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:49
MutableArrayRef< NamedAttributeDecl * > getAttributes()
Return the attributes of this operation.
Definition: Nodes.h:538
Optional< StringRef > getName() const
Return the name of the operation, or None if there isn't one.
Definition: Nodes.cpp:324
This statement represents the erase statement in PDLL.
Definition: Nodes.h:253
This statement represents an operation rewrite that contains a block of nested rewrite commands...
Definition: Nodes.h:298
This expression represents a reference to a Decl node.
Definition: Nodes.h:424
The class represents a ValueRange constraint, and constrains a variable to be a ValueRange.
Definition: Nodes.h:802
MutableArrayRef< Expr * > getElements()
Return the element expressions of this tuple.
Definition: Nodes.h:581
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc)
If the given builder is nested under a PDL PatternOp, build a rewrite operation and update the builde...
Definition: MLIRGen.cpp:200
Attributes are known-constant values of operations.
Definition: Attributes.h:24
const Name * getName() const
Return the name of the decl, or nullptr if it doesn't have one.
Definition: Nodes.h:628
This class represents a base AST Statement node.
Definition: Nodes.h:163
This class represents a PDLL type that corresponds to an mlir::Attribute.
Definition: Types.h:130
This class represents a reference to a constraint, and contains a constraint and the location of the ...
Definition: Nodes.h:672
const CompoundStmt * getBody() const
Return the body of this pattern.
Definition: Nodes.h:1004
This expression represents a named member or field access of a given parent expression.
Definition: Nodes.h:445
Expr * getCallableExpr() const
Return the callable of this call.
Definition: Nodes.h:395
This Decl represents the definition of a PDLL variable.
Definition: Nodes.h:1192
MutableArrayRef< Decl * > getChildren()
Return the children of this module.
Definition: Nodes.h:1248
MutableArrayRef< ConstraintRef > getConstraints()
Return the constraints of this variable.
Definition: Nodes.h:1201
This expression represents the structural form of an MLIR Operation.
Definition: Nodes.h:499
This class represents a base AST node.
Definition: Nodes.h:107
Type parseType(llvm::StringRef typeStr, MLIRContext *context)
Parses a single MLIR type into the given MLIR context if it is valid.
This decl represents a user defined rewrite.
Definition: Nodes.h:1042
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
The class represents a Value constraint, and constrains a variable to be a Value. ...
Definition: Nodes.h:780
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:62
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
MutableArrayRef< Expr * > getReplExprs()
Return the replacement values of this statement.
Definition: Nodes.h:274
static int resultIndex(int i)
Definition: Operator.cpp:308
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:294
This Decl represents a NamedAttribute, and contains a string name and attribute value.
Definition: Nodes.h:945
This class provides an ODS representation of a specific operation.
Definition: Operation.h:125
This class provides a convenient API for interacting with source names.
Definition: Nodes.h:36
This statement represents a let statement in PDLL.
Definition: Nodes.h:210
VariableDecl * getVarDecl() const
Return the variable defined by this statement.
Definition: Nodes.h:215
This expression represents a literal MLIR Type, and contains the textual assembly format of that type...
Definition: Nodes.h:604
This class represents a PDLL type that corresponds to a range of elements with a given element type...
Definition: Types.h:181
This Decl represents a single Pattern.
Definition: Nodes.h:990
StringRef getMemberName() const
Return the name of the member being accessed.
Definition: Nodes.h:455
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This expression builds a tuple from a set of element values.
Definition: Nodes.h:574
This class provides an ODS representation of a specific operation operand or result.
Definition: Operation.h:74
MutableArrayRef< Stmt * > getChildren()
Return the children of this compound statement.
Definition: Nodes.h:184
OwningOpRef< ModuleOp > codegenPDLLToMLIR(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr, const ast::Module &module)
Given a PDLL module, generate an MLIR PDL pattern module within the given MLIR context.
Definition: MLIRGen.cpp:607
This statement represents a return from a "callable" like decl, e.g.
Definition: Nodes.h:320
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
const Expr * getParentExpr() const
Get the parent expression of this access.
Definition: Nodes.h:452
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it...
Definition: Builders.cpp:377
This class represents a PDLL type that corresponds to an mlir::Type.
Definition: Types.h:270
bool isa() const
Definition: Types.h:254
StringRef getValue() const
Get the raw value of this expression.
Definition: Nodes.h:610
This statement represents a compound statement, which contains a collection of other statements...
Definition: Nodes.h:177
This class represents a top-level AST module.
Definition: Nodes.h:1242
result_range getResults()
Definition: Operation.h:332
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
This class represents a base AST Expression node.
Definition: Nodes.h:344
This class represents a PDLL type that corresponds to an mlir::Operation.
Definition: Types.h:157
This expression represents a call to a decl, such as a UserConstraintDecl/UserRewriteDecl.
Definition: Nodes.h:388