MLIR  16.0.0git
TranslateToCpp.cpp
Go to the documentation of this file.
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
#include <utility>
27 
28 #define DEBUG_TYPE "translate-to-cpp"
29 
30 using namespace mlir;
31 using namespace mlir::emitc;
32 using llvm::formatv;
33 
34 /// Convenience functions to produce interleaved output with functions returning
35 /// a LogicalResult. This is different than those in STLExtras as functions used
36 /// on each element doesn't return a string.
37 template <typename ForwardIterator, typename UnaryFunctor,
38  typename NullaryFunctor>
39 inline LogicalResult
40 interleaveWithError(ForwardIterator begin, ForwardIterator end,
41  UnaryFunctor eachFn, NullaryFunctor betweenFn) {
42  if (begin == end)
43  return success();
44  if (failed(eachFn(*begin)))
45  return failure();
46  ++begin;
47  for (; begin != end; ++begin) {
48  betweenFn();
49  if (failed(eachFn(*begin)))
50  return failure();
51  }
52  return success();
53 }
54 
55 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
56 inline LogicalResult interleaveWithError(const Container &c,
57  UnaryFunctor eachFn,
58  NullaryFunctor betweenFn) {
59  return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
60 }
61 
62 template <typename Container, typename UnaryFunctor>
63 inline LogicalResult interleaveCommaWithError(const Container &c,
64  raw_ostream &os,
65  UnaryFunctor eachFn) {
66  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
67 }
68 
69 namespace {
70 /// Emitter that uses dialect specific emitters to emit C++ code.
71 struct CppEmitter {
72  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
73 
74  /// Emits attribute or returns failure.
75  LogicalResult emitAttribute(Location loc, Attribute attr);
76 
77  /// Emits operation 'op' with/without training semicolon or returns failure.
78  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
79 
80  /// Emits type 'type' or returns failure.
81  LogicalResult emitType(Location loc, Type type);
82 
83  /// Emits array of types as a std::tuple of the emitted types.
84  /// - emits void for an empty array;
85  /// - emits the type of the only element for arrays of size one;
86  /// - emits a std::tuple otherwise;
87  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
88 
89  /// Emits array of types as a std::tuple of the emitted types independently of
90  /// the array size.
91  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
92 
93  /// Emits an assignment for a variable which has been declared previously.
94  LogicalResult emitVariableAssignment(OpResult result);
95 
96  /// Emits a variable declaration for a result of an operation.
97  LogicalResult emitVariableDeclaration(OpResult result,
98  bool trailingSemicolon);
99 
100  /// Emits the variable declaration and assignment prefix for 'op'.
101  /// - emits separate variable followed by std::tie for multi-valued operation;
102  /// - emits single type followed by variable for single result;
103  /// - emits nothing if no value produced by op;
104  /// Emits final '=' operator where a type is produced. Returns failure if
105  /// any result type could not be converted.
106  LogicalResult emitAssignPrefix(Operation &op);
107 
108  /// Emits a label for the block.
109  LogicalResult emitLabel(Block &block);
110 
111  /// Emits the operands and atttributes of the operation. All operands are
112  /// emitted first and then all attributes in alphabetical order.
113  LogicalResult emitOperandsAndAttributes(Operation &op,
114  ArrayRef<StringRef> exclude = {});
115 
116  /// Emits the operands of the operation. All operands are emitted in order.
117  LogicalResult emitOperands(Operation &op);
118 
119  /// Return the existing or a new name for a Value.
120  StringRef getOrCreateName(Value val);
121 
122  /// Return the existing or a new label of a Block.
123  StringRef getOrCreateName(Block &block);
124 
125  /// Whether to map an mlir integer to a unsigned integer in C++.
126  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
127 
128  /// RAII helper function to manage entering/exiting C++ scopes.
129  struct Scope {
130  Scope(CppEmitter &emitter)
131  : valueMapperScope(emitter.valueMapper),
132  blockMapperScope(emitter.blockMapper), emitter(emitter) {
133  emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
134  emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
135  }
136  ~Scope() {
137  emitter.valueInScopeCount.pop();
138  emitter.labelInScopeCount.pop();
139  }
140 
141  private:
142  llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
143  llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
144  CppEmitter &emitter;
145  };
146 
147  /// Returns wether the Value is assigned to a C++ variable in the scope.
148  bool hasValueInScope(Value val);
149 
150  // Returns whether a label is assigned to the block.
151  bool hasBlockLabel(Block &block);
152 
153  /// Returns the output stream.
154  raw_indented_ostream &ostream() { return os; };
155 
156  /// Returns if all variables for op results and basic block arguments need to
157  /// be declared at the beginning of a function.
158  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
159 
160 private:
161  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
162  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
163 
164  /// Output stream to emit to.
166 
167  /// Boolean to enforce that all variables for op results and block
168  /// arguments are declared at the beginning of the function. This also
169  /// includes results from ops located in nested regions.
170  bool declareVariablesAtTop;
171 
172  /// Map from value to name of C++ variable that contain the name.
173  ValueMapper valueMapper;
174 
175  /// Map from block to name of C++ label.
176  BlockMapper blockMapper;
177 
178  /// The number of values in the current scope. This is used to declare the
179  /// names of values in a scope.
180  std::stack<int64_t> valueInScopeCount;
181  std::stack<int64_t> labelInScopeCount;
182 };
183 } // namespace
184 
185 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
186  Attribute value) {
187  OpResult result = operation->getResult(0);
188 
189  // Only emit an assignment as the variable was already declared when printing
190  // the FuncOp.
191  if (emitter.shouldDeclareVariablesAtTop()) {
192  // Skip the assignment if the emitc.constant has no value.
193  if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
194  if (oAttr.getValue().empty())
195  return success();
196  }
197 
198  if (failed(emitter.emitVariableAssignment(result)))
199  return failure();
200  return emitter.emitAttribute(operation->getLoc(), value);
201  }
202 
203  // Emit a variable declaration for an emitc.constant op without value.
204  if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
205  if (oAttr.getValue().empty())
206  // The semicolon gets printed by the emitOperation function.
207  return emitter.emitVariableDeclaration(result,
208  /*trailingSemicolon=*/false);
209  }
210 
211  // Emit a variable declaration.
212  if (failed(emitter.emitAssignPrefix(*operation)))
213  return failure();
214  return emitter.emitAttribute(operation->getLoc(), value);
215 }
216 
217 static LogicalResult printOperation(CppEmitter &emitter,
218  emitc::ConstantOp constantOp) {
219  Operation *operation = constantOp.getOperation();
220  Attribute value = constantOp.getValue();
221 
222  return printConstantOp(emitter, operation, value);
223 }
224 
225 static LogicalResult printOperation(CppEmitter &emitter,
226  emitc::VariableOp variableOp) {
227  Operation *operation = variableOp.getOperation();
228  Attribute value = variableOp.getValue();
229 
230  return printConstantOp(emitter, operation, value);
231 }
232 
233 static LogicalResult printOperation(CppEmitter &emitter,
234  arith::ConstantOp constantOp) {
235  Operation *operation = constantOp.getOperation();
236  Attribute value = constantOp.getValue();
237 
238  return printConstantOp(emitter, operation, value);
239 }
240 
241 static LogicalResult printOperation(CppEmitter &emitter,
242  func::ConstantOp constantOp) {
243  Operation *operation = constantOp.getOperation();
244  Attribute value = constantOp.getValueAttr();
245 
246  return printConstantOp(emitter, operation, value);
247 }
248 
249 static LogicalResult printOperation(CppEmitter &emitter,
250  cf::BranchOp branchOp) {
251  raw_ostream &os = emitter.ostream();
252  Block &successor = *branchOp.getSuccessor();
253 
254  for (auto pair :
255  llvm::zip(branchOp.getOperands(), successor.getArguments())) {
256  Value &operand = std::get<0>(pair);
257  BlockArgument &argument = std::get<1>(pair);
258  os << emitter.getOrCreateName(argument) << " = "
259  << emitter.getOrCreateName(operand) << ";\n";
260  }
261 
262  os << "goto ";
263  if (!(emitter.hasBlockLabel(successor)))
264  return branchOp.emitOpError("unable to find label for successor block");
265  os << emitter.getOrCreateName(successor);
266  return success();
267 }
268 
269 static LogicalResult printOperation(CppEmitter &emitter,
270  cf::CondBranchOp condBranchOp) {
271  raw_indented_ostream &os = emitter.ostream();
272  Block &trueSuccessor = *condBranchOp.getTrueDest();
273  Block &falseSuccessor = *condBranchOp.getFalseDest();
274 
275  os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
276  << ") {\n";
277 
278  os.indent();
279 
280  // If condition is true.
281  for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
282  trueSuccessor.getArguments())) {
283  Value &operand = std::get<0>(pair);
284  BlockArgument &argument = std::get<1>(pair);
285  os << emitter.getOrCreateName(argument) << " = "
286  << emitter.getOrCreateName(operand) << ";\n";
287  }
288 
289  os << "goto ";
290  if (!(emitter.hasBlockLabel(trueSuccessor))) {
291  return condBranchOp.emitOpError("unable to find label for successor block");
292  }
293  os << emitter.getOrCreateName(trueSuccessor) << ";\n";
294  os.unindent() << "} else {\n";
295  os.indent();
296  // If condition is false.
297  for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
298  falseSuccessor.getArguments())) {
299  Value &operand = std::get<0>(pair);
300  BlockArgument &argument = std::get<1>(pair);
301  os << emitter.getOrCreateName(argument) << " = "
302  << emitter.getOrCreateName(operand) << ";\n";
303  }
304 
305  os << "goto ";
306  if (!(emitter.hasBlockLabel(falseSuccessor))) {
307  return condBranchOp.emitOpError()
308  << "unable to find label for successor block";
309  }
310  os << emitter.getOrCreateName(falseSuccessor) << ";\n";
311  os.unindent() << "}";
312  return success();
313 }
314 
315 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
316  if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
317  return failure();
318 
319  raw_ostream &os = emitter.ostream();
320  os << callOp.getCallee() << "(";
321  if (failed(emitter.emitOperands(*callOp.getOperation())))
322  return failure();
323  os << ")";
324  return success();
325 }
326 
327 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
328  raw_ostream &os = emitter.ostream();
329  Operation &op = *callOp.getOperation();
330 
331  if (failed(emitter.emitAssignPrefix(op)))
332  return failure();
333  os << callOp.getCallee();
334 
335  auto emitArgs = [&](Attribute attr) -> LogicalResult {
336  if (auto t = attr.dyn_cast<IntegerAttr>()) {
337  // Index attributes are treated specially as operand index.
338  if (t.getType().isIndex()) {
339  int64_t idx = t.getInt();
340  if ((idx < 0) || (idx >= op.getNumOperands()))
341  return op.emitOpError("invalid operand index");
342  if (!emitter.hasValueInScope(op.getOperand(idx)))
343  return op.emitOpError("operand ")
344  << idx << "'s value not defined in scope";
345  os << emitter.getOrCreateName(op.getOperand(idx));
346  return success();
347  }
348  }
349  if (failed(emitter.emitAttribute(op.getLoc(), attr)))
350  return failure();
351 
352  return success();
353  };
354 
355  if (callOp.getTemplateArgs()) {
356  os << "<";
357  if (failed(
358  interleaveCommaWithError(*callOp.getTemplateArgs(), os, emitArgs)))
359  return failure();
360  os << ">";
361  }
362 
363  os << "(";
364 
365  LogicalResult emittedArgs =
366  callOp.getArgs()
367  ? interleaveCommaWithError(*callOp.getArgs(), os, emitArgs)
368  : emitter.emitOperands(op);
369  if (failed(emittedArgs))
370  return failure();
371  os << ")";
372  return success();
373 }
374 
375 static LogicalResult printOperation(CppEmitter &emitter,
376  emitc::ApplyOp applyOp) {
377  raw_ostream &os = emitter.ostream();
378  Operation &op = *applyOp.getOperation();
379 
380  if (failed(emitter.emitAssignPrefix(op)))
381  return failure();
382  os << applyOp.getApplicableOperator();
383  os << emitter.getOrCreateName(applyOp.getOperand());
384 
385  return success();
386 }
387 
388 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
389  raw_ostream &os = emitter.ostream();
390  Operation &op = *castOp.getOperation();
391 
392  if (failed(emitter.emitAssignPrefix(op)))
393  return failure();
394  os << "(";
395  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
396  return failure();
397  os << ") ";
398  os << emitter.getOrCreateName(castOp.getOperand());
399 
400  return success();
401 }
402 
403 static LogicalResult printOperation(CppEmitter &emitter,
404  emitc::IncludeOp includeOp) {
405  raw_ostream &os = emitter.ostream();
406 
407  os << "#include ";
408  if (includeOp.getIsStandardInclude())
409  os << "<" << includeOp.getInclude() << ">";
410  else
411  os << "\"" << includeOp.getInclude() << "\"";
412 
413  return success();
414 }
415 
416 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
417 
418  raw_indented_ostream &os = emitter.ostream();
419 
420  OperandRange operands = forOp.getIterOperands();
421  Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
422  Operation::result_range results = forOp.getResults();
423 
424  if (!emitter.shouldDeclareVariablesAtTop()) {
425  for (OpResult result : results) {
426  if (failed(emitter.emitVariableDeclaration(result,
427  /*trailingSemicolon=*/true)))
428  return failure();
429  }
430  }
431 
432  for (auto pair : llvm::zip(iterArgs, operands)) {
433  if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
434  return failure();
435  os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
436  os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
437  os << "\n";
438  }
439 
440  os << "for (";
441  if (failed(
442  emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
443  return failure();
444  os << " ";
445  os << emitter.getOrCreateName(forOp.getInductionVar());
446  os << " = ";
447  os << emitter.getOrCreateName(forOp.getLowerBound());
448  os << "; ";
449  os << emitter.getOrCreateName(forOp.getInductionVar());
450  os << " < ";
451  os << emitter.getOrCreateName(forOp.getUpperBound());
452  os << "; ";
453  os << emitter.getOrCreateName(forOp.getInductionVar());
454  os << " += ";
455  os << emitter.getOrCreateName(forOp.getStep());
456  os << ") {\n";
457  os.indent();
458 
459  Region &forRegion = forOp.getRegion();
460  auto regionOps = forRegion.getOps();
461 
462  // We skip the trailing yield op because this updates the result variables
463  // of the for op in the generated code. Instead we update the iterArgs at
464  // the end of a loop iteration and set the result variables after the for
465  // loop.
466  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
467  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
468  return failure();
469  }
470 
471  Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
472  // Copy yield operands into iterArgs at the end of a loop iteration.
473  for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
474  BlockArgument iterArg = std::get<0>(pair);
475  Value operand = std::get<1>(pair);
476  os << emitter.getOrCreateName(iterArg) << " = "
477  << emitter.getOrCreateName(operand) << ";\n";
478  }
479 
480  os.unindent() << "}";
481 
482  // Copy iterArgs into results after the for loop.
483  for (auto pair : llvm::zip(results, iterArgs)) {
484  OpResult result = std::get<0>(pair);
485  BlockArgument iterArg = std::get<1>(pair);
486  os << "\n"
487  << emitter.getOrCreateName(result) << " = "
488  << emitter.getOrCreateName(iterArg) << ";";
489  }
490 
491  return success();
492 }
493 
494 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
495  raw_indented_ostream &os = emitter.ostream();
496 
497  if (!emitter.shouldDeclareVariablesAtTop()) {
498  for (OpResult result : ifOp.getResults()) {
499  if (failed(emitter.emitVariableDeclaration(result,
500  /*trailingSemicolon=*/true)))
501  return failure();
502  }
503  }
504 
505  os << "if (";
506  if (failed(emitter.emitOperands(*ifOp.getOperation())))
507  return failure();
508  os << ") {\n";
509  os.indent();
510 
511  Region &thenRegion = ifOp.getThenRegion();
512  for (Operation &op : thenRegion.getOps()) {
513  // Note: This prints a superfluous semicolon if the terminating yield op has
514  // zero results.
515  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
516  return failure();
517  }
518 
519  os.unindent() << "}";
520 
521  Region &elseRegion = ifOp.getElseRegion();
522  if (!elseRegion.empty()) {
523  os << " else {\n";
524  os.indent();
525 
526  for (Operation &op : elseRegion.getOps()) {
527  // Note: This prints a superfluous semicolon if the terminating yield op
528  // has zero results.
529  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
530  return failure();
531  }
532 
533  os.unindent() << "}";
534  }
535 
536  return success();
537 }
538 
539 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
540  raw_ostream &os = emitter.ostream();
541  Operation &parentOp = *yieldOp.getOperation()->getParentOp();
542 
543  if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
544  return yieldOp.emitError("number of operands does not to match the number "
545  "of the parent op's results");
546  }
547 
549  llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
550  [&](auto pair) -> LogicalResult {
551  auto result = std::get<0>(pair);
552  auto operand = std::get<1>(pair);
553  os << emitter.getOrCreateName(result) << " = ";
554 
555  if (!emitter.hasValueInScope(operand))
556  return yieldOp.emitError("operand value not in scope");
557  os << emitter.getOrCreateName(operand);
558  return success();
559  },
560  [&]() { os << ";\n"; })))
561  return failure();
562 
563  return success();
564 }
565 
566 static LogicalResult printOperation(CppEmitter &emitter,
567  func::ReturnOp returnOp) {
568  raw_ostream &os = emitter.ostream();
569  os << "return";
570  switch (returnOp.getNumOperands()) {
571  case 0:
572  return success();
573  case 1:
574  os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
575  return success(emitter.hasValueInScope(returnOp.getOperand(0)));
576  default:
577  os << " std::make_tuple(";
578  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
579  return failure();
580  os << ")";
581  return success();
582  }
583 }
584 
585 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
586  CppEmitter::Scope scope(emitter);
587 
588  for (Operation &op : moduleOp) {
589  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
590  return failure();
591  }
592  return success();
593 }
594 
595 static LogicalResult printOperation(CppEmitter &emitter,
596  func::FuncOp functionOp) {
597  // We need to declare variables at top if the function has multiple blocks.
598  if (!emitter.shouldDeclareVariablesAtTop() &&
599  functionOp.getBlocks().size() > 1) {
600  return functionOp.emitOpError(
601  "with multiple blocks needs variables declared at top");
602  }
603 
604  CppEmitter::Scope scope(emitter);
605  raw_indented_ostream &os = emitter.ostream();
606  if (failed(emitter.emitTypes(functionOp.getLoc(),
607  functionOp.getFunctionType().getResults())))
608  return failure();
609  os << " " << functionOp.getName();
610 
611  os << "(";
613  functionOp.getArguments(), os,
614  [&](BlockArgument arg) -> LogicalResult {
615  if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
616  return failure();
617  os << " " << emitter.getOrCreateName(arg);
618  return success();
619  })))
620  return failure();
621  os << ") {\n";
622  os.indent();
623  if (emitter.shouldDeclareVariablesAtTop()) {
624  // Declare all variables that hold op results including those from nested
625  // regions.
626  WalkResult result =
627  functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
628  for (OpResult result : op->getResults()) {
629  if (failed(emitter.emitVariableDeclaration(
630  result, /*trailingSemicolon=*/true))) {
631  return WalkResult(
632  op->emitError("unable to declare result variable for op"));
633  }
634  }
635  return WalkResult::advance();
636  });
637  if (result.wasInterrupted())
638  return failure();
639  }
640 
641  Region::BlockListType &blocks = functionOp.getBlocks();
642  // Create label names for basic blocks.
643  for (Block &block : blocks) {
644  emitter.getOrCreateName(block);
645  }
646 
647  // Declare variables for basic block arguments.
648  for (Block &block : llvm::drop_begin(blocks)) {
649  for (BlockArgument &arg : block.getArguments()) {
650  if (emitter.hasValueInScope(arg))
651  return functionOp.emitOpError(" block argument #")
652  << arg.getArgNumber() << " is out of scope";
653  if (failed(
654  emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
655  return failure();
656  }
657  os << " " << emitter.getOrCreateName(arg) << ";\n";
658  }
659  }
660 
661  for (Block &block : blocks) {
662  // Only print a label if the block has predecessors.
663  if (!block.hasNoPredecessors()) {
664  if (failed(emitter.emitLabel(block)))
665  return failure();
666  }
667  for (Operation &op : block.getOperations()) {
668  // When generating code for an scf.if or cf.cond_br op no semicolon needs
669  // to be printed after the closing brace.
670  // When generating code for an scf.for op, printing a trailing semicolon
671  // is handled within the printOperation function.
672  bool trailingSemicolon =
673  !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
674 
675  if (failed(emitter.emitOperation(
676  op, /*trailingSemicolon=*/trailingSemicolon)))
677  return failure();
678  }
679  }
680  os.unindent() << "}\n";
681  return success();
682 }
683 
684 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
685  : os(os), declareVariablesAtTop(declareVariablesAtTop) {
686  valueInScopeCount.push(0);
687  labelInScopeCount.push(0);
688 }
689 
690 /// Return the existing or a new name for a Value.
691 StringRef CppEmitter::getOrCreateName(Value val) {
692  if (!valueMapper.count(val))
693  valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
694  return *valueMapper.begin(val);
695 }
696 
697 /// Return the existing or a new label for a Block.
698 StringRef CppEmitter::getOrCreateName(Block &block) {
699  if (!blockMapper.count(&block))
700  blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
701  return *blockMapper.begin(&block);
702 }
703 
704 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
705  switch (val) {
706  case IntegerType::Signless:
707  return false;
708  case IntegerType::Signed:
709  return false;
710  case IntegerType::Unsigned:
711  return true;
712  }
713  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
714 }
715 
716 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
717 
718 bool CppEmitter::hasBlockLabel(Block &block) {
719  return blockMapper.count(&block);
720 }
721 
722 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
723  auto printInt = [&](const APInt &val, bool isUnsigned) {
724  if (val.getBitWidth() == 1) {
725  if (val.getBoolValue())
726  os << "true";
727  else
728  os << "false";
729  } else {
730  SmallString<128> strValue;
731  val.toString(strValue, 10, !isUnsigned, false);
732  os << strValue;
733  }
734  };
735 
736  auto printFloat = [&](const APFloat &val) {
737  if (val.isFinite()) {
738  SmallString<128> strValue;
739  // Use default values of toString except don't truncate zeros.
740  val.toString(strValue, 0, 0, false);
741  switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
742  case llvm::APFloatBase::S_IEEEsingle:
743  os << "(float)";
744  break;
745  case llvm::APFloatBase::S_IEEEdouble:
746  os << "(double)";
747  break;
748  default:
749  break;
750  };
751  os << strValue;
752  } else if (val.isNaN()) {
753  os << "NAN";
754  } else if (val.isInfinity()) {
755  if (val.isNegative())
756  os << "-";
757  os << "INFINITY";
758  }
759  };
760 
761  // Print floating point attributes.
762  if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
763  printFloat(fAttr.getValue());
764  return success();
765  }
766  if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
767  os << '{';
768  interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
769  os << '}';
770  return success();
771  }
772 
773  // Print integer attributes.
774  if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
775  if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
776  printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
777  return success();
778  }
779  if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
780  printInt(iAttr.getValue(), false);
781  return success();
782  }
783  }
784  if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
785  if (auto iType = dense.getType()
786  .cast<TensorType>()
787  .getElementType()
788  .dyn_cast<IntegerType>()) {
789  os << '{';
790  interleaveComma(dense, os, [&](const APInt &val) {
791  printInt(val, shouldMapToUnsigned(iType.getSignedness()));
792  });
793  os << '}';
794  return success();
795  }
796  if (auto iType = dense.getType()
797  .cast<TensorType>()
798  .getElementType()
799  .dyn_cast<IndexType>()) {
800  os << '{';
801  interleaveComma(dense, os,
802  [&](const APInt &val) { printInt(val, false); });
803  os << '}';
804  return success();
805  }
806  }
807 
808  // Print opaque attributes.
809  if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
810  os << oAttr.getValue();
811  return success();
812  }
813 
814  // Print symbolic reference attributes.
815  if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
816  if (sAttr.getNestedReferences().size() > 1)
817  return emitError(loc, "attribute has more than 1 nested reference");
818  os << sAttr.getRootReference().getValue();
819  return success();
820  }
821 
822  // Print type attributes.
823  if (auto type = attr.dyn_cast<TypeAttr>())
824  return emitType(loc, type.getValue());
825 
826  return emitError(loc, "cannot emit attribute: ") << attr;
827 }
828 
829 LogicalResult CppEmitter::emitOperands(Operation &op) {
830  auto emitOperandName = [&](Value result) -> LogicalResult {
831  if (!hasValueInScope(result))
832  return op.emitOpError() << "operand value not in scope";
833  os << getOrCreateName(result);
834  return success();
835  };
836  return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
837 }
838 
840 CppEmitter::emitOperandsAndAttributes(Operation &op,
841  ArrayRef<StringRef> exclude) {
842  if (failed(emitOperands(op)))
843  return failure();
844  // Insert comma in between operands and non-filtered attributes if needed.
845  if (op.getNumOperands() > 0) {
846  for (NamedAttribute attr : op.getAttrs()) {
847  if (!llvm::is_contained(exclude, attr.getName().strref())) {
848  os << ", ";
849  break;
850  }
851  }
852  }
853  // Emit attributes.
854  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
855  if (llvm::is_contained(exclude, attr.getName().strref()))
856  return success();
857  os << "/* " << attr.getName().getValue() << " */";
858  if (failed(emitAttribute(op.getLoc(), attr.getValue())))
859  return failure();
860  return success();
861  };
862  return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
863 }
864 
865 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
866  if (!hasValueInScope(result)) {
867  return result.getDefiningOp()->emitOpError(
868  "result variable for the operation has not been declared");
869  }
870  os << getOrCreateName(result) << " = ";
871  return success();
872 }
873 
874 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
875  bool trailingSemicolon) {
876  if (hasValueInScope(result)) {
877  return result.getDefiningOp()->emitError(
878  "result variable for the operation already declared");
879  }
880  if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
881  return failure();
882  os << " " << getOrCreateName(result);
883  if (trailingSemicolon)
884  os << ";\n";
885  return success();
886 }
887 
888 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
889  switch (op.getNumResults()) {
890  case 0:
891  break;
892  case 1: {
893  OpResult result = op.getResult(0);
894  if (shouldDeclareVariablesAtTop()) {
895  if (failed(emitVariableAssignment(result)))
896  return failure();
897  } else {
898  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
899  return failure();
900  os << " = ";
901  }
902  break;
903  }
904  default:
905  if (!shouldDeclareVariablesAtTop()) {
906  for (OpResult result : op.getResults()) {
907  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
908  return failure();
909  }
910  }
911  os << "std::tie(";
912  interleaveComma(op.getResults(), os,
913  [&](Value result) { os << getOrCreateName(result); });
914  os << ") = ";
915  }
916  return success();
917 }
918 
919 LogicalResult CppEmitter::emitLabel(Block &block) {
920  if (!hasBlockLabel(block))
921  return block.getParentOp()->emitError("label for block not found");
922  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
923  // label instead of using `getOStream`.
924  os.getOStream() << getOrCreateName(block) << ":\n";
925  return success();
926 }
927 
928 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
929  LogicalResult status =
931  // Builtin ops.
932  .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
933  // CF ops.
934  .Case<cf::BranchOp, cf::CondBranchOp>(
935  [&](auto op) { return printOperation(*this, op); })
936  // EmitC ops.
937  .Case<emitc::ApplyOp, emitc::CallOp, emitc::CastOp, emitc::ConstantOp,
938  emitc::IncludeOp, emitc::VariableOp>(
939  [&](auto op) { return printOperation(*this, op); })
940  // Func ops.
941  .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
942  [&](auto op) { return printOperation(*this, op); })
943  // SCF ops.
944  .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
945  [&](auto op) { return printOperation(*this, op); })
946  // Arithmetic ops.
947  .Case<arith::ConstantOp>(
948  [&](auto op) { return printOperation(*this, op); })
949  .Default([&](Operation *) {
950  return op.emitOpError("unable to find printer for op");
951  });
952 
953  if (failed(status))
954  return failure();
955  os << (trailingSemicolon ? ";\n" : "\n");
956  return success();
957 }
958 
959 LogicalResult CppEmitter::emitType(Location loc, Type type) {
960  if (auto iType = type.dyn_cast<IntegerType>()) {
961  switch (iType.getWidth()) {
962  case 1:
963  return (os << "bool"), success();
964  case 8:
965  case 16:
966  case 32:
967  case 64:
968  if (shouldMapToUnsigned(iType.getSignedness()))
969  return (os << "uint" << iType.getWidth() << "_t"), success();
970  else
971  return (os << "int" << iType.getWidth() << "_t"), success();
972  default:
973  return emitError(loc, "cannot emit integer type ") << type;
974  }
975  }
976  if (auto fType = type.dyn_cast<FloatType>()) {
977  switch (fType.getWidth()) {
978  case 32:
979  return (os << "float"), success();
980  case 64:
981  return (os << "double"), success();
982  default:
983  return emitError(loc, "cannot emit float type ") << type;
984  }
985  }
986  if (auto iType = type.dyn_cast<IndexType>())
987  return (os << "size_t"), success();
988  if (auto tType = type.dyn_cast<TensorType>()) {
989  if (!tType.hasRank())
990  return emitError(loc, "cannot emit unranked tensor type");
991  if (!tType.hasStaticShape())
992  return emitError(loc, "cannot emit tensor type with non static shape");
993  os << "Tensor<";
994  if (failed(emitType(loc, tType.getElementType())))
995  return failure();
996  auto shape = tType.getShape();
997  for (auto dimSize : shape) {
998  os << ", ";
999  os << dimSize;
1000  }
1001  os << ">";
1002  return success();
1003  }
1004  if (auto tType = type.dyn_cast<TupleType>())
1005  return emitTupleType(loc, tType.getTypes());
1006  if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
1007  os << oType.getValue();
1008  return success();
1009  }
1010  if (auto pType = type.dyn_cast<emitc::PointerType>()) {
1011  if (failed(emitType(loc, pType.getPointee())))
1012  return failure();
1013  os << "*";
1014  return success();
1015  }
1016  return emitError(loc, "cannot emit type ") << type;
1017 }
1018 
1019 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1020  switch (types.size()) {
1021  case 0:
1022  os << "void";
1023  return success();
1024  case 1:
1025  return emitType(loc, types.front());
1026  default:
1027  return emitTupleType(loc, types);
1028  }
1029 }
1030 
1031 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1032  os << "std::tuple<";
1034  types, os, [&](Type type) { return emitType(loc, type); })))
1035  return failure();
1036  os << ">";
1037  return success();
1038 }
1039 
1041  bool declareVariablesAtTop) {
1042  CppEmitter emitter(os, declareVariablesAtTop);
1043  return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1044 }
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
Include the generated interface declarations.
Block * getSuccessor(unsigned i)
Definition: Block.cpp:240
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
An attribute that represents a reference to a dense float vector or tensor object.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
BlockListType & getBlocks()
Definition: Region.h:45
This is a value defined by a result of an operation.
Definition: Value.h:425
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
Block represents an ordered list of Operations.
Definition: Block.h:29
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
Value getOperand(unsigned idx)
Definition: Operation.h:267
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:263
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:685
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:228
raw_indented_ostream & indent()
Increases the indent and returns this raw_indented_ostream.
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:434
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false)
Translates the given operation to C++ code.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:149
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
U dyn_cast() const
Definition: Types.h:270
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
static WalkResult advance()
Definition: Visitors.h:51
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
BlockArgListType getArguments()
Definition: Block.h:76
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
This class represents an argument of a Block.
Definition: Value.h:300
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:76
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getType() const
Return the type of this value.
Definition: Value.h:118
U dyn_cast() const
Definition: Attributes.h:127
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:40
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
raw_indented_ostream & unindent()
Decreases the indent and returns this raw_indented_ostream.
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:508
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:221
raw_ostream subclass that simplifies indenting a sequence of code.
result_range getResults()
Definition: Operation.h:332
llvm::iplist< Block > BlockListType
Definition: Region.h:44
An attribute that represents a reference to a dense integer vector or tensor object.