MLIR  14.0.0git
TranslateToCpp.cpp
Go to the documentation of this file.
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
12 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Operation.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 #define DEBUG_TYPE "translate-to-cpp"
28 
29 using namespace mlir;
30 using namespace mlir::emitc;
31 using llvm::formatv;
32 
33 /// Convenience functions to produce interleaved output with functions returning
34 /// a LogicalResult. This is different than those in STLExtras as functions used
35 /// on each element doesn't return a string.
36 template <typename ForwardIterator, typename UnaryFunctor,
37  typename NullaryFunctor>
38 inline LogicalResult
39 interleaveWithError(ForwardIterator begin, ForwardIterator end,
40  UnaryFunctor eachFn, NullaryFunctor betweenFn) {
41  if (begin == end)
42  return success();
43  if (failed(eachFn(*begin)))
44  return failure();
45  ++begin;
46  for (; begin != end; ++begin) {
47  betweenFn();
48  if (failed(eachFn(*begin)))
49  return failure();
50  }
51  return success();
52 }
53 
54 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
55 inline LogicalResult interleaveWithError(const Container &c,
56  UnaryFunctor eachFn,
57  NullaryFunctor betweenFn) {
58  return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
59 }
60 
61 template <typename Container, typename UnaryFunctor>
62 inline LogicalResult interleaveCommaWithError(const Container &c,
63  raw_ostream &os,
64  UnaryFunctor eachFn) {
65  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
66 }
67 
68 namespace {
69 /// Emitter that uses dialect specific emitters to emit C++ code.
70 struct CppEmitter {
71  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
72 
73  /// Emits attribute or returns failure.
74  LogicalResult emitAttribute(Location loc, Attribute attr);
75 
76  /// Emits operation 'op' with/without training semicolon or returns failure.
77  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
78 
79  /// Emits type 'type' or returns failure.
80  LogicalResult emitType(Location loc, Type type);
81 
82  /// Emits array of types as a std::tuple of the emitted types.
83  /// - emits void for an empty array;
84  /// - emits the type of the only element for arrays of size one;
85  /// - emits a std::tuple otherwise;
86  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
87 
88  /// Emits array of types as a std::tuple of the emitted types independently of
89  /// the array size.
90  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
91 
92  /// Emits an assignment for a variable which has been declared previously.
93  LogicalResult emitVariableAssignment(OpResult result);
94 
95  /// Emits a variable declaration for a result of an operation.
96  LogicalResult emitVariableDeclaration(OpResult result,
97  bool trailingSemicolon);
98 
99  /// Emits the variable declaration and assignment prefix for 'op'.
100  /// - emits separate variable followed by std::tie for multi-valued operation;
101  /// - emits single type followed by variable for single result;
102  /// - emits nothing if no value produced by op;
103  /// Emits final '=' operator where a type is produced. Returns failure if
104  /// any result type could not be converted.
105  LogicalResult emitAssignPrefix(Operation &op);
106 
107  /// Emits a label for the block.
108  LogicalResult emitLabel(Block &block);
109 
110  /// Emits the operands and atttributes of the operation. All operands are
111  /// emitted first and then all attributes in alphabetical order.
112  LogicalResult emitOperandsAndAttributes(Operation &op,
113  ArrayRef<StringRef> exclude = {});
114 
115  /// Emits the operands of the operation. All operands are emitted in order.
116  LogicalResult emitOperands(Operation &op);
117 
118  /// Return the existing or a new name for a Value.
119  StringRef getOrCreateName(Value val);
120 
121  /// Return the existing or a new label of a Block.
122  StringRef getOrCreateName(Block &block);
123 
124  /// Whether to map an mlir integer to a unsigned integer in C++.
125  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
126 
127  /// RAII helper function to manage entering/exiting C++ scopes.
128  struct Scope {
129  Scope(CppEmitter &emitter)
130  : valueMapperScope(emitter.valueMapper),
131  blockMapperScope(emitter.blockMapper), emitter(emitter) {
132  emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
133  emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
134  }
135  ~Scope() {
136  emitter.valueInScopeCount.pop();
137  emitter.labelInScopeCount.pop();
138  }
139 
140  private:
141  llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
142  llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
143  CppEmitter &emitter;
144  };
145 
146  /// Returns wether the Value is assigned to a C++ variable in the scope.
147  bool hasValueInScope(Value val);
148 
149  // Returns whether a label is assigned to the block.
150  bool hasBlockLabel(Block &block);
151 
152  /// Returns the output stream.
153  raw_indented_ostream &ostream() { return os; };
154 
155  /// Returns if all variables for op results and basic block arguments need to
156  /// be declared at the beginning of a function.
157  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
158 
159 private:
160  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
161  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
162 
163  /// Output stream to emit to.
165 
166  /// Boolean to enforce that all variables for op results and block
167  /// arguments are declared at the beginning of the function. This also
168  /// includes results from ops located in nested regions.
169  bool declareVariablesAtTop;
170 
171  /// Map from value to name of C++ variable that contain the name.
172  ValueMapper valueMapper;
173 
174  /// Map from block to name of C++ label.
175  BlockMapper blockMapper;
176 
177  /// The number of values in the current scope. This is used to declare the
178  /// names of values in a scope.
179  std::stack<int64_t> valueInScopeCount;
180  std::stack<int64_t> labelInScopeCount;
181 };
182 } // namespace
183 
184 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
185  Attribute value) {
186  OpResult result = operation->getResult(0);
187 
188  // Only emit an assignment as the variable was already declared when printing
189  // the FuncOp.
190  if (emitter.shouldDeclareVariablesAtTop()) {
191  // Skip the assignment if the emitc.constant has no value.
192  if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
193  if (oAttr.getValue().empty())
194  return success();
195  }
196 
197  if (failed(emitter.emitVariableAssignment(result)))
198  return failure();
199  return emitter.emitAttribute(operation->getLoc(), value);
200  }
201 
202  // Emit a variable declaration for an emitc.constant op without value.
203  if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
204  if (oAttr.getValue().empty())
205  // The semicolon gets printed by the emitOperation function.
206  return emitter.emitVariableDeclaration(result,
207  /*trailingSemicolon=*/false);
208  }
209 
210  // Emit a variable declaration.
211  if (failed(emitter.emitAssignPrefix(*operation)))
212  return failure();
213  return emitter.emitAttribute(operation->getLoc(), value);
214 }
215 
216 static LogicalResult printOperation(CppEmitter &emitter,
217  emitc::ConstantOp constantOp) {
218  Operation *operation = constantOp.getOperation();
219  Attribute value = constantOp.value();
220 
221  return printConstantOp(emitter, operation, value);
222 }
223 
224 static LogicalResult printOperation(CppEmitter &emitter,
225  arith::ConstantOp constantOp) {
226  Operation *operation = constantOp.getOperation();
227  Attribute value = constantOp.getValue();
228 
229  return printConstantOp(emitter, operation, value);
230 }
231 
232 static LogicalResult printOperation(CppEmitter &emitter,
233  mlir::ConstantOp constantOp) {
234  Operation *operation = constantOp.getOperation();
235  Attribute value = constantOp.getValue();
236 
237  return printConstantOp(emitter, operation, value);
238 }
239 
240 static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
241  raw_ostream &os = emitter.ostream();
242  Block &successor = *branchOp.getSuccessor();
243 
244  for (auto pair :
245  llvm::zip(branchOp.getOperands(), successor.getArguments())) {
246  Value &operand = std::get<0>(pair);
247  BlockArgument &argument = std::get<1>(pair);
248  os << emitter.getOrCreateName(argument) << " = "
249  << emitter.getOrCreateName(operand) << ";\n";
250  }
251 
252  os << "goto ";
253  if (!(emitter.hasBlockLabel(successor)))
254  return branchOp.emitOpError("unable to find label for successor block");
255  os << emitter.getOrCreateName(successor);
256  return success();
257 }
258 
259 static LogicalResult printOperation(CppEmitter &emitter,
260  CondBranchOp condBranchOp) {
261  raw_indented_ostream &os = emitter.ostream();
262  Block &trueSuccessor = *condBranchOp.getTrueDest();
263  Block &falseSuccessor = *condBranchOp.getFalseDest();
264 
265  os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
266  << ") {\n";
267 
268  os.indent();
269 
270  // If condition is true.
271  for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
272  trueSuccessor.getArguments())) {
273  Value &operand = std::get<0>(pair);
274  BlockArgument &argument = std::get<1>(pair);
275  os << emitter.getOrCreateName(argument) << " = "
276  << emitter.getOrCreateName(operand) << ";\n";
277  }
278 
279  os << "goto ";
280  if (!(emitter.hasBlockLabel(trueSuccessor))) {
281  return condBranchOp.emitOpError("unable to find label for successor block");
282  }
283  os << emitter.getOrCreateName(trueSuccessor) << ";\n";
284  os.unindent() << "} else {\n";
285  os.indent();
286  // If condition is false.
287  for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
288  falseSuccessor.getArguments())) {
289  Value &operand = std::get<0>(pair);
290  BlockArgument &argument = std::get<1>(pair);
291  os << emitter.getOrCreateName(argument) << " = "
292  << emitter.getOrCreateName(operand) << ";\n";
293  }
294 
295  os << "goto ";
296  if (!(emitter.hasBlockLabel(falseSuccessor))) {
297  return condBranchOp.emitOpError()
298  << "unable to find label for successor block";
299  }
300  os << emitter.getOrCreateName(falseSuccessor) << ";\n";
301  os.unindent() << "}";
302  return success();
303 }
304 
305 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) {
306  if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
307  return failure();
308 
309  raw_ostream &os = emitter.ostream();
310  os << callOp.getCallee() << "(";
311  if (failed(emitter.emitOperands(*callOp.getOperation())))
312  return failure();
313  os << ")";
314  return success();
315 }
316 
317 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
318  raw_ostream &os = emitter.ostream();
319  Operation &op = *callOp.getOperation();
320 
321  if (failed(emitter.emitAssignPrefix(op)))
322  return failure();
323  os << callOp.callee();
324 
325  auto emitArgs = [&](Attribute attr) -> LogicalResult {
326  if (auto t = attr.dyn_cast<IntegerAttr>()) {
327  // Index attributes are treated specially as operand index.
328  if (t.getType().isIndex()) {
329  int64_t idx = t.getInt();
330  if ((idx < 0) || (idx >= op.getNumOperands()))
331  return op.emitOpError("invalid operand index");
332  if (!emitter.hasValueInScope(op.getOperand(idx)))
333  return op.emitOpError("operand ")
334  << idx << "'s value not defined in scope";
335  os << emitter.getOrCreateName(op.getOperand(idx));
336  return success();
337  }
338  }
339  if (failed(emitter.emitAttribute(op.getLoc(), attr)))
340  return failure();
341 
342  return success();
343  };
344 
345  if (callOp.template_args()) {
346  os << "<";
347  if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs)))
348  return failure();
349  os << ">";
350  }
351 
352  os << "(";
353 
354  LogicalResult emittedArgs =
355  callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs)
356  : emitter.emitOperands(op);
357  if (failed(emittedArgs))
358  return failure();
359  os << ")";
360  return success();
361 }
362 
363 static LogicalResult printOperation(CppEmitter &emitter,
364  emitc::ApplyOp applyOp) {
365  raw_ostream &os = emitter.ostream();
366  Operation &op = *applyOp.getOperation();
367 
368  if (failed(emitter.emitAssignPrefix(op)))
369  return failure();
370  os << applyOp.applicableOperator();
371  os << emitter.getOrCreateName(applyOp.getOperand());
372 
373  return success();
374 }
375 
376 static LogicalResult printOperation(CppEmitter &emitter,
377  emitc::IncludeOp includeOp) {
378  raw_ostream &os = emitter.ostream();
379 
380  os << "#include ";
381  if (includeOp.is_standard_include())
382  os << "<" << includeOp.include() << ">";
383  else
384  os << "\"" << includeOp.include() << "\"";
385 
386  return success();
387 }
388 
389 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
390 
391  raw_indented_ostream &os = emitter.ostream();
392 
393  OperandRange operands = forOp.getIterOperands();
394  Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
395  Operation::result_range results = forOp.getResults();
396 
397  if (!emitter.shouldDeclareVariablesAtTop()) {
398  for (OpResult result : results) {
399  if (failed(emitter.emitVariableDeclaration(result,
400  /*trailingSemicolon=*/true)))
401  return failure();
402  }
403  }
404 
405  for (auto pair : llvm::zip(iterArgs, operands)) {
406  if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
407  return failure();
408  os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
409  os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
410  os << "\n";
411  }
412 
413  os << "for (";
414  if (failed(
415  emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
416  return failure();
417  os << " ";
418  os << emitter.getOrCreateName(forOp.getInductionVar());
419  os << " = ";
420  os << emitter.getOrCreateName(forOp.getLowerBound());
421  os << "; ";
422  os << emitter.getOrCreateName(forOp.getInductionVar());
423  os << " < ";
424  os << emitter.getOrCreateName(forOp.getUpperBound());
425  os << "; ";
426  os << emitter.getOrCreateName(forOp.getInductionVar());
427  os << " += ";
428  os << emitter.getOrCreateName(forOp.getStep());
429  os << ") {\n";
430  os.indent();
431 
432  Region &forRegion = forOp.getRegion();
433  auto regionOps = forRegion.getOps();
434 
435  // We skip the trailing yield op because this updates the result variables
436  // of the for op in the generated code. Instead we update the iterArgs at
437  // the end of a loop iteration and set the result variables after the for
438  // loop.
439  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
440  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
441  return failure();
442  }
443 
444  Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
445  // Copy yield operands into iterArgs at the end of a loop iteration.
446  for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
447  BlockArgument iterArg = std::get<0>(pair);
448  Value operand = std::get<1>(pair);
449  os << emitter.getOrCreateName(iterArg) << " = "
450  << emitter.getOrCreateName(operand) << ";\n";
451  }
452 
453  os.unindent() << "}";
454 
455  // Copy iterArgs into results after the for loop.
456  for (auto pair : llvm::zip(results, iterArgs)) {
457  OpResult result = std::get<0>(pair);
458  BlockArgument iterArg = std::get<1>(pair);
459  os << "\n"
460  << emitter.getOrCreateName(result) << " = "
461  << emitter.getOrCreateName(iterArg) << ";";
462  }
463 
464  return success();
465 }
466 
467 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
468  raw_indented_ostream &os = emitter.ostream();
469 
470  if (!emitter.shouldDeclareVariablesAtTop()) {
471  for (OpResult result : ifOp.getResults()) {
472  if (failed(emitter.emitVariableDeclaration(result,
473  /*trailingSemicolon=*/true)))
474  return failure();
475  }
476  }
477 
478  os << "if (";
479  if (failed(emitter.emitOperands(*ifOp.getOperation())))
480  return failure();
481  os << ") {\n";
482  os.indent();
483 
484  Region &thenRegion = ifOp.getThenRegion();
485  for (Operation &op : thenRegion.getOps()) {
486  // Note: This prints a superfluous semicolon if the terminating yield op has
487  // zero results.
488  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
489  return failure();
490  }
491 
492  os.unindent() << "}";
493 
494  Region &elseRegion = ifOp.getElseRegion();
495  if (!elseRegion.empty()) {
496  os << " else {\n";
497  os.indent();
498 
499  for (Operation &op : elseRegion.getOps()) {
500  // Note: This prints a superfluous semicolon if the terminating yield op
501  // has zero results.
502  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
503  return failure();
504  }
505 
506  os.unindent() << "}";
507  }
508 
509  return success();
510 }
511 
512 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
513  raw_ostream &os = emitter.ostream();
514  Operation &parentOp = *yieldOp.getOperation()->getParentOp();
515 
516  if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
517  return yieldOp.emitError("number of operands does not to match the number "
518  "of the parent op's results");
519  }
520 
522  llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
523  [&](auto pair) -> LogicalResult {
524  auto result = std::get<0>(pair);
525  auto operand = std::get<1>(pair);
526  os << emitter.getOrCreateName(result) << " = ";
527 
528  if (!emitter.hasValueInScope(operand))
529  return yieldOp.emitError("operand value not in scope");
530  os << emitter.getOrCreateName(operand);
531  return success();
532  },
533  [&]() { os << ";\n"; })))
534  return failure();
535 
536  return success();
537 }
538 
539 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) {
540  raw_ostream &os = emitter.ostream();
541  os << "return";
542  switch (returnOp.getNumOperands()) {
543  case 0:
544  return success();
545  case 1:
546  os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
547  return success(emitter.hasValueInScope(returnOp.getOperand(0)));
548  default:
549  os << " std::make_tuple(";
550  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
551  return failure();
552  os << ")";
553  return success();
554  }
555 }
556 
557 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
558  CppEmitter::Scope scope(emitter);
559 
560  for (Operation &op : moduleOp) {
561  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
562  return failure();
563  }
564  return success();
565 }
566 
567 static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
568  // We need to declare variables at top if the function has multiple blocks.
569  if (!emitter.shouldDeclareVariablesAtTop() &&
570  functionOp.getBlocks().size() > 1) {
571  return functionOp.emitOpError(
572  "with multiple blocks needs variables declared at top");
573  }
574 
575  CppEmitter::Scope scope(emitter);
576  raw_indented_ostream &os = emitter.ostream();
577  if (failed(emitter.emitTypes(functionOp.getLoc(),
578  functionOp.getType().getResults())))
579  return failure();
580  os << " " << functionOp.getName();
581 
582  os << "(";
584  functionOp.getArguments(), os,
585  [&](BlockArgument arg) -> LogicalResult {
586  if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
587  return failure();
588  os << " " << emitter.getOrCreateName(arg);
589  return success();
590  })))
591  return failure();
592  os << ") {\n";
593  os.indent();
594  if (emitter.shouldDeclareVariablesAtTop()) {
595  // Declare all variables that hold op results including those from nested
596  // regions.
597  WalkResult result =
598  functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
599  for (OpResult result : op->getResults()) {
600  if (failed(emitter.emitVariableDeclaration(
601  result, /*trailingSemicolon=*/true))) {
602  return WalkResult(
603  op->emitError("unable to declare result variable for op"));
604  }
605  }
606  return WalkResult::advance();
607  });
608  if (result.wasInterrupted())
609  return failure();
610  }
611 
612  Region::BlockListType &blocks = functionOp.getBlocks();
613  // Create label names for basic blocks.
614  for (Block &block : blocks) {
615  emitter.getOrCreateName(block);
616  }
617 
618  // Declare variables for basic block arguments.
619  for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
620  Block &block = *it;
621  for (BlockArgument &arg : block.getArguments()) {
622  if (emitter.hasValueInScope(arg))
623  return functionOp.emitOpError(" block argument #")
624  << arg.getArgNumber() << " is out of scope";
625  if (failed(
626  emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
627  return failure();
628  }
629  os << " " << emitter.getOrCreateName(arg) << ";\n";
630  }
631  }
632 
633  for (Block &block : blocks) {
634  // Only print a label if there is more than one block.
635  if (blocks.size() > 1) {
636  if (failed(emitter.emitLabel(block)))
637  return failure();
638  }
639  for (Operation &op : block.getOperations()) {
640  // When generating code for an scf.if or std.cond_br op no semicolon needs
641  // to be printed after the closing brace.
642  // When generating code for an scf.for op, printing a trailing semicolon
643  // is handled within the printOperation function.
644  bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op);
645 
646  if (failed(emitter.emitOperation(
647  op, /*trailingSemicolon=*/trailingSemicolon)))
648  return failure();
649  }
650  }
651  os.unindent() << "}\n";
652  return success();
653 }
654 
655 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
656  : os(os), declareVariablesAtTop(declareVariablesAtTop) {
657  valueInScopeCount.push(0);
658  labelInScopeCount.push(0);
659 }
660 
661 /// Return the existing or a new name for a Value.
662 StringRef CppEmitter::getOrCreateName(Value val) {
663  if (!valueMapper.count(val))
664  valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
665  return *valueMapper.begin(val);
666 }
667 
668 /// Return the existing or a new label for a Block.
669 StringRef CppEmitter::getOrCreateName(Block &block) {
670  if (!blockMapper.count(&block))
671  blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
672  return *blockMapper.begin(&block);
673 }
674 
675 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
676  switch (val) {
677  case IntegerType::Signless:
678  return false;
679  case IntegerType::Signed:
680  return false;
681  case IntegerType::Unsigned:
682  return true;
683  }
684  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
685 }
686 
687 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
688 
689 bool CppEmitter::hasBlockLabel(Block &block) {
690  return blockMapper.count(&block);
691 }
692 
693 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
694  auto printInt = [&](const APInt &val, bool isUnsigned) {
695  if (val.getBitWidth() == 1) {
696  if (val.getBoolValue())
697  os << "true";
698  else
699  os << "false";
700  } else {
701  SmallString<128> strValue;
702  val.toString(strValue, 10, !isUnsigned, false);
703  os << strValue;
704  }
705  };
706 
707  auto printFloat = [&](const APFloat &val) {
708  if (val.isFinite()) {
709  SmallString<128> strValue;
710  // Use default values of toString except don't truncate zeros.
711  val.toString(strValue, 0, 0, false);
712  switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
713  case llvm::APFloatBase::S_IEEEsingle:
714  os << "(float)";
715  break;
716  case llvm::APFloatBase::S_IEEEdouble:
717  os << "(double)";
718  break;
719  default:
720  break;
721  };
722  os << strValue;
723  } else if (val.isNaN()) {
724  os << "NAN";
725  } else if (val.isInfinity()) {
726  if (val.isNegative())
727  os << "-";
728  os << "INFINITY";
729  }
730  };
731 
732  // Print floating point attributes.
733  if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
734  printFloat(fAttr.getValue());
735  return success();
736  }
737  if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
738  os << '{';
739  interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
740  os << '}';
741  return success();
742  }
743 
744  // Print integer attributes.
745  if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
746  if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
747  printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
748  return success();
749  }
750  if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
751  printInt(iAttr.getValue(), false);
752  return success();
753  }
754  }
755  if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
756  if (auto iType = dense.getType()
757  .cast<TensorType>()
758  .getElementType()
759  .dyn_cast<IntegerType>()) {
760  os << '{';
761  interleaveComma(dense, os, [&](const APInt &val) {
762  printInt(val, shouldMapToUnsigned(iType.getSignedness()));
763  });
764  os << '}';
765  return success();
766  }
767  if (auto iType = dense.getType()
768  .cast<TensorType>()
769  .getElementType()
770  .dyn_cast<IndexType>()) {
771  os << '{';
772  interleaveComma(dense, os,
773  [&](const APInt &val) { printInt(val, false); });
774  os << '}';
775  return success();
776  }
777  }
778 
779  // Print opaque attributes.
780  if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
781  os << oAttr.getValue();
782  return success();
783  }
784 
785  // Print symbolic reference attributes.
786  if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
787  if (sAttr.getNestedReferences().size() > 1)
788  return emitError(loc, "attribute has more than 1 nested reference");
789  os << sAttr.getRootReference().getValue();
790  return success();
791  }
792 
793  // Print type attributes.
794  if (auto type = attr.dyn_cast<TypeAttr>())
795  return emitType(loc, type.getValue());
796 
797  return emitError(loc, "cannot emit attribute of type ") << attr.getType();
798 }
799 
800 LogicalResult CppEmitter::emitOperands(Operation &op) {
801  auto emitOperandName = [&](Value result) -> LogicalResult {
802  if (!hasValueInScope(result))
803  return op.emitOpError() << "operand value not in scope";
804  os << getOrCreateName(result);
805  return success();
806  };
807  return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
808 }
809 
811 CppEmitter::emitOperandsAndAttributes(Operation &op,
812  ArrayRef<StringRef> exclude) {
813  if (failed(emitOperands(op)))
814  return failure();
815  // Insert comma in between operands and non-filtered attributes if needed.
816  if (op.getNumOperands() > 0) {
817  for (NamedAttribute attr : op.getAttrs()) {
818  if (!llvm::is_contained(exclude, attr.getName().strref())) {
819  os << ", ";
820  break;
821  }
822  }
823  }
824  // Emit attributes.
825  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
826  if (llvm::is_contained(exclude, attr.getName().strref()))
827  return success();
828  os << "/* " << attr.getName().getValue() << " */";
829  if (failed(emitAttribute(op.getLoc(), attr.getValue())))
830  return failure();
831  return success();
832  };
833  return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
834 }
835 
836 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
837  if (!hasValueInScope(result)) {
838  return result.getDefiningOp()->emitOpError(
839  "result variable for the operation has not been declared");
840  }
841  os << getOrCreateName(result) << " = ";
842  return success();
843 }
844 
845 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
846  bool trailingSemicolon) {
847  if (hasValueInScope(result)) {
848  return result.getDefiningOp()->emitError(
849  "result variable for the operation already declared");
850  }
851  if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
852  return failure();
853  os << " " << getOrCreateName(result);
854  if (trailingSemicolon)
855  os << ";\n";
856  return success();
857 }
858 
859 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
860  switch (op.getNumResults()) {
861  case 0:
862  break;
863  case 1: {
864  OpResult result = op.getResult(0);
865  if (shouldDeclareVariablesAtTop()) {
866  if (failed(emitVariableAssignment(result)))
867  return failure();
868  } else {
869  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
870  return failure();
871  os << " = ";
872  }
873  break;
874  }
875  default:
876  if (!shouldDeclareVariablesAtTop()) {
877  for (OpResult result : op.getResults()) {
878  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
879  return failure();
880  }
881  }
882  os << "std::tie(";
883  interleaveComma(op.getResults(), os,
884  [&](Value result) { os << getOrCreateName(result); });
885  os << ") = ";
886  }
887  return success();
888 }
889 
890 LogicalResult CppEmitter::emitLabel(Block &block) {
891  if (!hasBlockLabel(block))
892  return block.getParentOp()->emitError("label for block not found");
893  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
894  // label instead of using `getOStream`.
895  os.getOStream() << getOrCreateName(block) << ":\n";
896  return success();
897 }
898 
899 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
900  LogicalResult status =
902  // EmitC ops.
903  .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp,
904  emitc::IncludeOp>(
905  [&](auto op) { return printOperation(*this, op); })
906  // SCF ops.
907  .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
908  [&](auto op) { return printOperation(*this, op); })
909  // Standard ops.
910  .Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp,
911  ModuleOp, ReturnOp>(
912  [&](auto op) { return printOperation(*this, op); })
913  // Arithmetic ops.
914  .Case<arith::ConstantOp>(
915  [&](auto op) { return printOperation(*this, op); })
916  .Default([&](Operation *) {
917  return op.emitOpError("unable to find printer for op");
918  });
919 
920  if (failed(status))
921  return failure();
922  os << (trailingSemicolon ? ";\n" : "\n");
923  return success();
924 }
925 
926 LogicalResult CppEmitter::emitType(Location loc, Type type) {
927  if (auto iType = type.dyn_cast<IntegerType>()) {
928  switch (iType.getWidth()) {
929  case 1:
930  return (os << "bool"), success();
931  case 8:
932  case 16:
933  case 32:
934  case 64:
935  if (shouldMapToUnsigned(iType.getSignedness()))
936  return (os << "uint" << iType.getWidth() << "_t"), success();
937  else
938  return (os << "int" << iType.getWidth() << "_t"), success();
939  default:
940  return emitError(loc, "cannot emit integer type ") << type;
941  }
942  }
943  if (auto fType = type.dyn_cast<FloatType>()) {
944  switch (fType.getWidth()) {
945  case 32:
946  return (os << "float"), success();
947  case 64:
948  return (os << "double"), success();
949  default:
950  return emitError(loc, "cannot emit float type ") << type;
951  }
952  }
953  if (auto iType = type.dyn_cast<IndexType>())
954  return (os << "size_t"), success();
955  if (auto tType = type.dyn_cast<TensorType>()) {
956  if (!tType.hasRank())
957  return emitError(loc, "cannot emit unranked tensor type");
958  if (!tType.hasStaticShape())
959  return emitError(loc, "cannot emit tensor type with non static shape");
960  os << "Tensor<";
961  if (failed(emitType(loc, tType.getElementType())))
962  return failure();
963  auto shape = tType.getShape();
964  for (auto dimSize : shape) {
965  os << ", ";
966  os << dimSize;
967  }
968  os << ">";
969  return success();
970  }
971  if (auto tType = type.dyn_cast<TupleType>())
972  return emitTupleType(loc, tType.getTypes());
973  if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
974  os << oType.getValue();
975  return success();
976  }
977  return emitError(loc, "cannot emit type ") << type;
978 }
979 
980 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
981  switch (types.size()) {
982  case 0:
983  os << "void";
984  return success();
985  case 1:
986  return emitType(loc, types.front());
987  default:
988  return emitTupleType(loc, types);
989  }
990 }
991 
992 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
993  os << "std::tuple<";
995  types, os, [&](Type type) { return emitType(loc, type); })))
996  return failure();
997  os << ">";
998  return success();
999 }
1000 
1002  bool declareVariablesAtTop) {
1003  CppEmitter emitter(os, declareVariablesAtTop);
1004  return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1005 }
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
Include the generated interface declarations.
Block * getSuccessor(unsigned i)
Definition: Block.cpp:240
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
An attribute that represents a reference to a dense float vector or tensor object.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
BlockListType & getBlocks()
Definition: Region.h:45
This is a value defined by a result of an operation.
Definition: Value.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:247
Block represents an ordered list of Operations.
Definition: Block.h:29
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:308
Value getOperand(unsigned idx)
Definition: Operation.h:219
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:215
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
This class implements the result iterators for the Operation class.
raw_indented_ostream & indent()
Increases the indent and returns this raw_indented_ostream.
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:432
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false)
Translates the given operation to C++ code.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:137
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
U dyn_cast() const
Definition: Types.h:244
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:117
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
static WalkResult advance()
Definition: Visitors.h:51
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
BlockArgListType getArguments()
Definition: Block.h:76
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
This class represents an argument of a Block.
Definition: Value.h:298
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getType() const
Return the type of this value.
Definition: Value.h:117
U dyn_cast() const
Definition: Attributes.h:117
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class implements the operand iterators for the Operation class.
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
raw_indented_ostream & unindent()
Decreases the indent and returns this raw_indented_ostream.
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:518
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
raw_ostream subclass that simplifies indenting a sequence of code.
result_range getResults()
Definition: Operation.h:284
llvm::iplist< Block > BlockListType
Definition: Region.h:44
An attribute that represents a reference to a dense integer vector or tensor object.