MLIR  20.0.0git
TranslateToCpp.cpp
Go to the documentation of this file.
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/ScopedHashTable.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <stack>
28 #include <utility>
29 
30 #define DEBUG_TYPE "translate-to-cpp"
31 
32 using namespace mlir;
33 using namespace mlir::emitc;
34 using llvm::formatv;
35 
36 /// Convenience functions to produce interleaved output with functions returning
37 /// a LogicalResult. This is different than those in STLExtras as functions used
38 /// on each element doesn't return a string.
39 template <typename ForwardIterator, typename UnaryFunctor,
40  typename NullaryFunctor>
41 inline LogicalResult
43  UnaryFunctor eachFn, NullaryFunctor betweenFn) {
44  if (begin == end)
45  return success();
46  if (failed(eachFn(*begin)))
47  return failure();
48  ++begin;
49  for (; begin != end; ++begin) {
50  betweenFn();
51  if (failed(eachFn(*begin)))
52  return failure();
53  }
54  return success();
55 }
56 
57 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
58 inline LogicalResult interleaveWithError(const Container &c,
59  UnaryFunctor eachFn,
60  NullaryFunctor betweenFn) {
61  return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
62 }
63 
64 template <typename Container, typename UnaryFunctor>
65 inline LogicalResult interleaveCommaWithError(const Container &c,
66  raw_ostream &os,
67  UnaryFunctor eachFn) {
68  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
69 }
70 
71 /// Return the precedence of a operator as an integer, higher values
72 /// imply higher precedence.
73 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
75  .Case<emitc::AddOp>([&](auto op) { return 12; })
76  .Case<emitc::ApplyOp>([&](auto op) { return 15; })
77  .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
78  .Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; })
79  .Case<emitc::BitwiseNotOp>([&](auto op) { return 15; })
80  .Case<emitc::BitwiseOrOp>([&](auto op) { return 5; })
81  .Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; })
82  .Case<emitc::BitwiseXorOp>([&](auto op) { return 6; })
83  .Case<emitc::CallOp>([&](auto op) { return 16; })
84  .Case<emitc::CallOpaqueOp>([&](auto op) { return 16; })
85  .Case<emitc::CastOp>([&](auto op) { return 15; })
86  .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
87  switch (op.getPredicate()) {
88  case emitc::CmpPredicate::eq:
89  case emitc::CmpPredicate::ne:
90  return 8;
91  case emitc::CmpPredicate::lt:
92  case emitc::CmpPredicate::le:
93  case emitc::CmpPredicate::gt:
94  case emitc::CmpPredicate::ge:
95  return 9;
96  case emitc::CmpPredicate::three_way:
97  return 10;
98  }
99  return op->emitError("unsupported cmp predicate");
100  })
101  .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
102  .Case<emitc::DivOp>([&](auto op) { return 13; })
103  .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
104  .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
105  .Case<emitc::LogicalOrOp>([&](auto op) { return 3; })
106  .Case<emitc::MulOp>([&](auto op) { return 13; })
107  .Case<emitc::RemOp>([&](auto op) { return 13; })
108  .Case<emitc::SubOp>([&](auto op) { return 12; })
109  .Case<emitc::UnaryMinusOp>([&](auto op) { return 15; })
110  .Case<emitc::UnaryPlusOp>([&](auto op) { return 15; })
111  .Default([](auto op) { return op->emitError("unsupported operation"); });
112 }
113 
114 namespace {
115 /// Emitter that uses dialect specific emitters to emit C++ code.
116 struct CppEmitter {
117  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
118 
119  /// Emits attribute or returns failure.
120  LogicalResult emitAttribute(Location loc, Attribute attr);
121 
122  /// Emits operation 'op' with/without training semicolon or returns failure.
123  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
124 
125  /// Emits type 'type' or returns failure.
126  LogicalResult emitType(Location loc, Type type);
127 
128  /// Emits array of types as a std::tuple of the emitted types.
129  /// - emits void for an empty array;
130  /// - emits the type of the only element for arrays of size one;
131  /// - emits a std::tuple otherwise;
132  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
133 
134  /// Emits array of types as a std::tuple of the emitted types independently of
135  /// the array size.
136  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
137 
138  /// Emits an assignment for a variable which has been declared previously.
139  LogicalResult emitVariableAssignment(OpResult result);
140 
141  /// Emits a variable declaration for a result of an operation.
142  LogicalResult emitVariableDeclaration(OpResult result,
143  bool trailingSemicolon);
144 
145  /// Emits a declaration of a variable with the given type and name.
146  LogicalResult emitVariableDeclaration(Location loc, Type type,
147  StringRef name);
148 
149  /// Emits the variable declaration and assignment prefix for 'op'.
150  /// - emits separate variable followed by std::tie for multi-valued operation;
151  /// - emits single type followed by variable for single result;
152  /// - emits nothing if no value produced by op;
153  /// Emits final '=' operator where a type is produced. Returns failure if
154  /// any result type could not be converted.
155  LogicalResult emitAssignPrefix(Operation &op);
156 
157  /// Emits a global variable declaration or definition.
158  LogicalResult emitGlobalVariable(GlobalOp op);
159 
160  /// Emits a label for the block.
161  LogicalResult emitLabel(Block &block);
162 
163  /// Emits the operands and atttributes of the operation. All operands are
164  /// emitted first and then all attributes in alphabetical order.
165  LogicalResult emitOperandsAndAttributes(Operation &op,
166  ArrayRef<StringRef> exclude = {});
167 
168  /// Emits the operands of the operation. All operands are emitted in order.
169  LogicalResult emitOperands(Operation &op);
170 
171  /// Emits value as an operands of an operation
172  LogicalResult emitOperand(Value value);
173 
174  /// Emit an expression as a C expression.
175  LogicalResult emitExpression(ExpressionOp expressionOp);
176 
177  /// Insert the expression representing the operation into the value cache.
178  void cacheDeferredOpResult(Value value, StringRef str);
179 
180  /// Return the existing or a new name for a Value.
181  StringRef getOrCreateName(Value val);
182 
183  // Returns the textual representation of a subscript operation.
184  std::string getSubscriptName(emitc::SubscriptOp op);
185 
186  // Returns the textual representation of a member (of object) operation.
187  std::string createMemberAccess(emitc::MemberOp op);
188 
189  // Returns the textual representation of a member of pointer operation.
190  std::string createMemberAccess(emitc::MemberOfPtrOp op);
191 
192  /// Return the existing or a new label of a Block.
193  StringRef getOrCreateName(Block &block);
194 
195  /// Whether to map an mlir integer to a unsigned integer in C++.
196  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
197 
198  /// RAII helper function to manage entering/exiting C++ scopes.
199  struct Scope {
200  Scope(CppEmitter &emitter)
201  : valueMapperScope(emitter.valueMapper),
202  blockMapperScope(emitter.blockMapper), emitter(emitter) {
203  emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
204  emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
205  }
206  ~Scope() {
207  emitter.valueInScopeCount.pop();
208  emitter.labelInScopeCount.pop();
209  }
210 
211  private:
212  llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
213  llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
214  CppEmitter &emitter;
215  };
216 
217  /// Returns wether the Value is assigned to a C++ variable in the scope.
218  bool hasValueInScope(Value val);
219 
220  // Returns whether a label is assigned to the block.
221  bool hasBlockLabel(Block &block);
222 
223  /// Returns the output stream.
224  raw_indented_ostream &ostream() { return os; };
225 
226  /// Returns if all variables for op results and basic block arguments need to
227  /// be declared at the beginning of a function.
228  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
229 
230  /// Get expression currently being emitted.
231  ExpressionOp getEmittedExpression() { return emittedExpression; }
232 
233  /// Determine whether given value is part of the expression potentially being
234  /// emitted.
235  bool isPartOfCurrentExpression(Value value) {
236  if (!emittedExpression)
237  return false;
238  Operation *def = value.getDefiningOp();
239  if (!def)
240  return false;
241  auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
242  return operandExpression == emittedExpression;
243  };
244 
245 private:
246  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
247  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
248 
249  /// Output stream to emit to.
251 
252  /// Boolean to enforce that all variables for op results and block
253  /// arguments are declared at the beginning of the function. This also
254  /// includes results from ops located in nested regions.
255  bool declareVariablesAtTop;
256 
257  /// Map from value to name of C++ variable that contain the name.
258  ValueMapper valueMapper;
259 
260  /// Map from block to name of C++ label.
261  BlockMapper blockMapper;
262 
263  /// The number of values in the current scope. This is used to declare the
264  /// names of values in a scope.
265  std::stack<int64_t> valueInScopeCount;
266  std::stack<int64_t> labelInScopeCount;
267 
268  /// State of the current expression being emitted.
269  ExpressionOp emittedExpression;
270  SmallVector<int> emittedExpressionPrecedence;
271 
272  void pushExpressionPrecedence(int precedence) {
273  emittedExpressionPrecedence.push_back(precedence);
274  }
275  void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
276  static int lowestPrecedence() { return 0; }
277  int getExpressionPrecedence() {
278  if (emittedExpressionPrecedence.empty())
279  return lowestPrecedence();
280  return emittedExpressionPrecedence.back();
281  }
282 };
283 } // namespace
284 
285 /// Determine whether expression \p op should be emitted in a deferred way.
286 static bool hasDeferredEmission(Operation *op) {
287  return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
288  emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
289 }
290 
291 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
292 /// as part of its user. This function recommends inlining of any expressions
293 /// that can be inlined unless it is used by another expression, under the
294 /// assumption that any expression fusion/re-materialization was taken care of
295 /// by transformations run by the backend.
296 static bool shouldBeInlined(ExpressionOp expressionOp) {
297  // Do not inline if expression is marked as such.
298  if (expressionOp.getDoNotInline())
299  return false;
300 
301  // Do not inline expressions with side effects to prevent side-effect
302  // reordering.
303  if (expressionOp.hasSideEffects())
304  return false;
305 
306  // Do not inline expressions with multiple uses.
307  Value result = expressionOp.getResult();
308  if (!result.hasOneUse())
309  return false;
310 
311  Operation *user = *result.getUsers().begin();
312 
313  // Do not inline expressions used by operations with deferred emission, since
314  // their translation requires the materialization of variables.
315  if (hasDeferredEmission(user))
316  return false;
317 
318  // Do not inline expressions used by ops with the CExpression trait. If this
319  // was intended, the user could have been merged into the expression op.
320  return !user->hasTrait<OpTrait::emitc::CExpression>();
321 }
322 
323 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
324  Attribute value) {
325  OpResult result = operation->getResult(0);
326 
327  // Only emit an assignment as the variable was already declared when printing
328  // the FuncOp.
329  if (emitter.shouldDeclareVariablesAtTop()) {
330  // Skip the assignment if the emitc.constant has no value.
331  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
332  if (oAttr.getValue().empty())
333  return success();
334  }
335 
336  if (failed(emitter.emitVariableAssignment(result)))
337  return failure();
338  return emitter.emitAttribute(operation->getLoc(), value);
339  }
340 
341  // Emit a variable declaration for an emitc.constant op without value.
342  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
343  if (oAttr.getValue().empty())
344  // The semicolon gets printed by the emitOperation function.
345  return emitter.emitVariableDeclaration(result,
346  /*trailingSemicolon=*/false);
347  }
348 
349  // Emit a variable declaration.
350  if (failed(emitter.emitAssignPrefix(*operation)))
351  return failure();
352  return emitter.emitAttribute(operation->getLoc(), value);
353 }
354 
355 static LogicalResult printOperation(CppEmitter &emitter,
356  emitc::ConstantOp constantOp) {
357  Operation *operation = constantOp.getOperation();
358  Attribute value = constantOp.getValue();
359 
360  return printConstantOp(emitter, operation, value);
361 }
362 
363 static LogicalResult printOperation(CppEmitter &emitter,
364  emitc::VariableOp variableOp) {
365  Operation *operation = variableOp.getOperation();
366  Attribute value = variableOp.getValue();
367 
368  return printConstantOp(emitter, operation, value);
369 }
370 
371 static LogicalResult printOperation(CppEmitter &emitter,
372  emitc::GlobalOp globalOp) {
373 
374  return emitter.emitGlobalVariable(globalOp);
375 }
376 
377 static LogicalResult printOperation(CppEmitter &emitter,
378  emitc::AssignOp assignOp) {
379  OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
380 
381  if (failed(emitter.emitVariableAssignment(result)))
382  return failure();
383 
384  return emitter.emitOperand(assignOp.getValue());
385 }
386 
387 static LogicalResult printBinaryOperation(CppEmitter &emitter,
388  Operation *operation,
389  StringRef binaryOperator) {
390  raw_ostream &os = emitter.ostream();
391 
392  if (failed(emitter.emitAssignPrefix(*operation)))
393  return failure();
394 
395  if (failed(emitter.emitOperand(operation->getOperand(0))))
396  return failure();
397 
398  os << " " << binaryOperator << " ";
399 
400  if (failed(emitter.emitOperand(operation->getOperand(1))))
401  return failure();
402 
403  return success();
404 }
405 
406 static LogicalResult printUnaryOperation(CppEmitter &emitter,
407  Operation *operation,
408  StringRef unaryOperator) {
409  raw_ostream &os = emitter.ostream();
410 
411  if (failed(emitter.emitAssignPrefix(*operation)))
412  return failure();
413 
414  os << unaryOperator;
415 
416  if (failed(emitter.emitOperand(operation->getOperand(0))))
417  return failure();
418 
419  return success();
420 }
421 
422 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
423  Operation *operation = addOp.getOperation();
424 
425  return printBinaryOperation(emitter, operation, "+");
426 }
427 
428 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
429  Operation *operation = divOp.getOperation();
430 
431  return printBinaryOperation(emitter, operation, "/");
432 }
433 
434 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
435  Operation *operation = mulOp.getOperation();
436 
437  return printBinaryOperation(emitter, operation, "*");
438 }
439 
440 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
441  Operation *operation = remOp.getOperation();
442 
443  return printBinaryOperation(emitter, operation, "%");
444 }
445 
446 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
447  Operation *operation = subOp.getOperation();
448 
449  return printBinaryOperation(emitter, operation, "-");
450 }
451 
452 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
453  Operation *operation = cmpOp.getOperation();
454 
455  StringRef binaryOperator;
456 
457  switch (cmpOp.getPredicate()) {
458  case emitc::CmpPredicate::eq:
459  binaryOperator = "==";
460  break;
461  case emitc::CmpPredicate::ne:
462  binaryOperator = "!=";
463  break;
464  case emitc::CmpPredicate::lt:
465  binaryOperator = "<";
466  break;
467  case emitc::CmpPredicate::le:
468  binaryOperator = "<=";
469  break;
470  case emitc::CmpPredicate::gt:
471  binaryOperator = ">";
472  break;
473  case emitc::CmpPredicate::ge:
474  binaryOperator = ">=";
475  break;
476  case emitc::CmpPredicate::three_way:
477  binaryOperator = "<=>";
478  break;
479  }
480 
481  return printBinaryOperation(emitter, operation, binaryOperator);
482 }
483 
484 static LogicalResult printOperation(CppEmitter &emitter,
485  emitc::ConditionalOp conditionalOp) {
486  raw_ostream &os = emitter.ostream();
487 
488  if (failed(emitter.emitAssignPrefix(*conditionalOp)))
489  return failure();
490 
491  if (failed(emitter.emitOperand(conditionalOp.getCondition())))
492  return failure();
493 
494  os << " ? ";
495 
496  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
497  return failure();
498 
499  os << " : ";
500 
501  if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
502  return failure();
503 
504  return success();
505 }
506 
507 static LogicalResult printOperation(CppEmitter &emitter,
508  emitc::VerbatimOp verbatimOp) {
509  raw_ostream &os = emitter.ostream();
510 
511  os << verbatimOp.getValue();
512 
513  return success();
514 }
515 
516 static LogicalResult printOperation(CppEmitter &emitter,
517  cf::BranchOp branchOp) {
518  raw_ostream &os = emitter.ostream();
519  Block &successor = *branchOp.getSuccessor();
520 
521  for (auto pair :
522  llvm::zip(branchOp.getOperands(), successor.getArguments())) {
523  Value &operand = std::get<0>(pair);
524  BlockArgument &argument = std::get<1>(pair);
525  os << emitter.getOrCreateName(argument) << " = "
526  << emitter.getOrCreateName(operand) << ";\n";
527  }
528 
529  os << "goto ";
530  if (!(emitter.hasBlockLabel(successor)))
531  return branchOp.emitOpError("unable to find label for successor block");
532  os << emitter.getOrCreateName(successor);
533  return success();
534 }
535 
536 static LogicalResult printOperation(CppEmitter &emitter,
537  cf::CondBranchOp condBranchOp) {
538  raw_indented_ostream &os = emitter.ostream();
539  Block &trueSuccessor = *condBranchOp.getTrueDest();
540  Block &falseSuccessor = *condBranchOp.getFalseDest();
541 
542  os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
543  << ") {\n";
544 
545  os.indent();
546 
547  // If condition is true.
548  for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
549  trueSuccessor.getArguments())) {
550  Value &operand = std::get<0>(pair);
551  BlockArgument &argument = std::get<1>(pair);
552  os << emitter.getOrCreateName(argument) << " = "
553  << emitter.getOrCreateName(operand) << ";\n";
554  }
555 
556  os << "goto ";
557  if (!(emitter.hasBlockLabel(trueSuccessor))) {
558  return condBranchOp.emitOpError("unable to find label for successor block");
559  }
560  os << emitter.getOrCreateName(trueSuccessor) << ";\n";
561  os.unindent() << "} else {\n";
562  os.indent();
563  // If condition is false.
564  for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
565  falseSuccessor.getArguments())) {
566  Value &operand = std::get<0>(pair);
567  BlockArgument &argument = std::get<1>(pair);
568  os << emitter.getOrCreateName(argument) << " = "
569  << emitter.getOrCreateName(operand) << ";\n";
570  }
571 
572  os << "goto ";
573  if (!(emitter.hasBlockLabel(falseSuccessor))) {
574  return condBranchOp.emitOpError()
575  << "unable to find label for successor block";
576  }
577  os << emitter.getOrCreateName(falseSuccessor) << ";\n";
578  os.unindent() << "}";
579  return success();
580 }
581 
582 static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
583  StringRef callee) {
584  if (failed(emitter.emitAssignPrefix(*callOp)))
585  return failure();
586 
587  raw_ostream &os = emitter.ostream();
588  os << callee << "(";
589  if (failed(emitter.emitOperands(*callOp)))
590  return failure();
591  os << ")";
592  return success();
593 }
594 
595 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
596  Operation *operation = callOp.getOperation();
597  StringRef callee = callOp.getCallee();
598 
599  return printCallOperation(emitter, operation, callee);
600 }
601 
602 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
603  Operation *operation = callOp.getOperation();
604  StringRef callee = callOp.getCallee();
605 
606  return printCallOperation(emitter, operation, callee);
607 }
608 
609 static LogicalResult printOperation(CppEmitter &emitter,
610  emitc::CallOpaqueOp callOpaqueOp) {
611  raw_ostream &os = emitter.ostream();
612  Operation &op = *callOpaqueOp.getOperation();
613 
614  if (failed(emitter.emitAssignPrefix(op)))
615  return failure();
616  os << callOpaqueOp.getCallee();
617 
618  auto emitArgs = [&](Attribute attr) -> LogicalResult {
619  if (auto t = dyn_cast<IntegerAttr>(attr)) {
620  // Index attributes are treated specially as operand index.
621  if (t.getType().isIndex()) {
622  int64_t idx = t.getInt();
623  Value operand = op.getOperand(idx);
624  if (!emitter.hasValueInScope(operand))
625  return op.emitOpError("operand ")
626  << idx << "'s value not defined in scope";
627  os << emitter.getOrCreateName(operand);
628  return success();
629  }
630  }
631  if (failed(emitter.emitAttribute(op.getLoc(), attr)))
632  return failure();
633 
634  return success();
635  };
636 
637  if (callOpaqueOp.getTemplateArgs()) {
638  os << "<";
639  if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
640  emitArgs)))
641  return failure();
642  os << ">";
643  }
644 
645  os << "(";
646 
647  LogicalResult emittedArgs =
648  callOpaqueOp.getArgs()
649  ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
650  : emitter.emitOperands(op);
651  if (failed(emittedArgs))
652  return failure();
653  os << ")";
654  return success();
655 }
656 
657 static LogicalResult printOperation(CppEmitter &emitter,
658  emitc::ApplyOp applyOp) {
659  raw_ostream &os = emitter.ostream();
660  Operation &op = *applyOp.getOperation();
661 
662  if (failed(emitter.emitAssignPrefix(op)))
663  return failure();
664  os << applyOp.getApplicableOperator();
665  os << emitter.getOrCreateName(applyOp.getOperand());
666 
667  return success();
668 }
669 
670 static LogicalResult printOperation(CppEmitter &emitter,
671  emitc::BitwiseAndOp bitwiseAndOp) {
672  Operation *operation = bitwiseAndOp.getOperation();
673  return printBinaryOperation(emitter, operation, "&");
674 }
675 
676 static LogicalResult
677 printOperation(CppEmitter &emitter,
678  emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) {
679  Operation *operation = bitwiseLeftShiftOp.getOperation();
680  return printBinaryOperation(emitter, operation, "<<");
681 }
682 
683 static LogicalResult printOperation(CppEmitter &emitter,
684  emitc::BitwiseNotOp bitwiseNotOp) {
685  Operation *operation = bitwiseNotOp.getOperation();
686  return printUnaryOperation(emitter, operation, "~");
687 }
688 
689 static LogicalResult printOperation(CppEmitter &emitter,
690  emitc::BitwiseOrOp bitwiseOrOp) {
691  Operation *operation = bitwiseOrOp.getOperation();
692  return printBinaryOperation(emitter, operation, "|");
693 }
694 
695 static LogicalResult
696 printOperation(CppEmitter &emitter,
697  emitc::BitwiseRightShiftOp bitwiseRightShiftOp) {
698  Operation *operation = bitwiseRightShiftOp.getOperation();
699  return printBinaryOperation(emitter, operation, ">>");
700 }
701 
702 static LogicalResult printOperation(CppEmitter &emitter,
703  emitc::BitwiseXorOp bitwiseXorOp) {
704  Operation *operation = bitwiseXorOp.getOperation();
705  return printBinaryOperation(emitter, operation, "^");
706 }
707 
708 static LogicalResult printOperation(CppEmitter &emitter,
709  emitc::UnaryPlusOp unaryPlusOp) {
710  Operation *operation = unaryPlusOp.getOperation();
711  return printUnaryOperation(emitter, operation, "+");
712 }
713 
714 static LogicalResult printOperation(CppEmitter &emitter,
715  emitc::UnaryMinusOp unaryMinusOp) {
716  Operation *operation = unaryMinusOp.getOperation();
717  return printUnaryOperation(emitter, operation, "-");
718 }
719 
720 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
721  raw_ostream &os = emitter.ostream();
722  Operation &op = *castOp.getOperation();
723 
724  if (failed(emitter.emitAssignPrefix(op)))
725  return failure();
726  os << "(";
727  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
728  return failure();
729  os << ") ";
730  return emitter.emitOperand(castOp.getOperand());
731 }
732 
733 static LogicalResult printOperation(CppEmitter &emitter,
734  emitc::ExpressionOp expressionOp) {
735  if (shouldBeInlined(expressionOp))
736  return success();
737 
738  Operation &op = *expressionOp.getOperation();
739 
740  if (failed(emitter.emitAssignPrefix(op)))
741  return failure();
742 
743  return emitter.emitExpression(expressionOp);
744 }
745 
746 static LogicalResult printOperation(CppEmitter &emitter,
747  emitc::IncludeOp includeOp) {
748  raw_ostream &os = emitter.ostream();
749 
750  os << "#include ";
751  if (includeOp.getIsStandardInclude())
752  os << "<" << includeOp.getInclude() << ">";
753  else
754  os << "\"" << includeOp.getInclude() << "\"";
755 
756  return success();
757 }
758 
759 static LogicalResult printOperation(CppEmitter &emitter,
760  emitc::LogicalAndOp logicalAndOp) {
761  Operation *operation = logicalAndOp.getOperation();
762  return printBinaryOperation(emitter, operation, "&&");
763 }
764 
765 static LogicalResult printOperation(CppEmitter &emitter,
766  emitc::LogicalNotOp logicalNotOp) {
767  Operation *operation = logicalNotOp.getOperation();
768  return printUnaryOperation(emitter, operation, "!");
769 }
770 
771 static LogicalResult printOperation(CppEmitter &emitter,
772  emitc::LogicalOrOp logicalOrOp) {
773  Operation *operation = logicalOrOp.getOperation();
774  return printBinaryOperation(emitter, operation, "||");
775 }
776 
777 static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
778 
779  raw_indented_ostream &os = emitter.ostream();
780 
781  // Utility function to determine whether a value is an expression that will be
782  // inlined, and as such should be wrapped in parentheses in order to guarantee
783  // its precedence and associativity.
784  auto requiresParentheses = [&](Value value) {
785  auto expressionOp =
786  dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
787  if (!expressionOp)
788  return false;
789  return shouldBeInlined(expressionOp);
790  };
791 
792  os << "for (";
793  if (failed(
794  emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
795  return failure();
796  os << " ";
797  os << emitter.getOrCreateName(forOp.getInductionVar());
798  os << " = ";
799  if (failed(emitter.emitOperand(forOp.getLowerBound())))
800  return failure();
801  os << "; ";
802  os << emitter.getOrCreateName(forOp.getInductionVar());
803  os << " < ";
804  Value upperBound = forOp.getUpperBound();
805  bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
806  if (upperBoundRequiresParentheses)
807  os << "(";
808  if (failed(emitter.emitOperand(upperBound)))
809  return failure();
810  if (upperBoundRequiresParentheses)
811  os << ")";
812  os << "; ";
813  os << emitter.getOrCreateName(forOp.getInductionVar());
814  os << " += ";
815  if (failed(emitter.emitOperand(forOp.getStep())))
816  return failure();
817  os << ") {\n";
818  os.indent();
819 
820  Region &forRegion = forOp.getRegion();
821  auto regionOps = forRegion.getOps();
822 
823  // We skip the trailing yield op.
824  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
825  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
826  return failure();
827  }
828 
829  os.unindent() << "}";
830 
831  return success();
832 }
833 
834 static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
835  raw_indented_ostream &os = emitter.ostream();
836 
837  // Helper function to emit all ops except the last one, expected to be
838  // emitc::yield.
839  auto emitAllExceptLast = [&emitter](Region &region) {
840  Region::OpIterator it = region.op_begin(), end = region.op_end();
841  for (; std::next(it) != end; ++it) {
842  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
843  return failure();
844  }
845  assert(isa<emitc::YieldOp>(*it) &&
846  "Expected last operation in the region to be emitc::yield");
847  return success();
848  };
849 
850  os << "if (";
851  if (failed(emitter.emitOperand(ifOp.getCondition())))
852  return failure();
853  os << ") {\n";
854  os.indent();
855  if (failed(emitAllExceptLast(ifOp.getThenRegion())))
856  return failure();
857  os.unindent() << "}";
858 
859  Region &elseRegion = ifOp.getElseRegion();
860  if (!elseRegion.empty()) {
861  os << " else {\n";
862  os.indent();
863  if (failed(emitAllExceptLast(elseRegion)))
864  return failure();
865  os.unindent() << "}";
866  }
867 
868  return success();
869 }
870 
871 static LogicalResult printOperation(CppEmitter &emitter,
872  func::ReturnOp returnOp) {
873  raw_ostream &os = emitter.ostream();
874  os << "return";
875  switch (returnOp.getNumOperands()) {
876  case 0:
877  return success();
878  case 1:
879  os << " ";
880  if (failed(emitter.emitOperand(returnOp.getOperand(0))))
881  return failure();
882  return success();
883  default:
884  os << " std::make_tuple(";
885  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
886  return failure();
887  os << ")";
888  return success();
889  }
890 }
891 
892 static LogicalResult printOperation(CppEmitter &emitter,
893  emitc::ReturnOp returnOp) {
894  raw_ostream &os = emitter.ostream();
895  os << "return";
896  if (returnOp.getNumOperands() == 0)
897  return success();
898 
899  os << " ";
900  if (failed(emitter.emitOperand(returnOp.getOperand())))
901  return failure();
902  return success();
903 }
904 
905 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
906  CppEmitter::Scope scope(emitter);
907 
908  for (Operation &op : moduleOp) {
909  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
910  return failure();
911  }
912  return success();
913 }
914 
915 static LogicalResult printFunctionArgs(CppEmitter &emitter,
916  Operation *functionOp,
917  ArrayRef<Type> arguments) {
918  raw_indented_ostream &os = emitter.ostream();
919 
920  return (
921  interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult {
922  return emitter.emitType(functionOp->getLoc(), arg);
923  }));
924 }
925 
926 static LogicalResult printFunctionArgs(CppEmitter &emitter,
927  Operation *functionOp,
928  Region::BlockArgListType arguments) {
929  raw_indented_ostream &os = emitter.ostream();
930 
931  return (interleaveCommaWithError(
932  arguments, os, [&](BlockArgument arg) -> LogicalResult {
933  return emitter.emitVariableDeclaration(
934  functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
935  }));
936 }
937 
938 static LogicalResult printFunctionBody(CppEmitter &emitter,
939  Operation *functionOp,
940  Region::BlockListType &blocks) {
941  raw_indented_ostream &os = emitter.ostream();
942  os.indent();
943 
944  if (emitter.shouldDeclareVariablesAtTop()) {
945  // Declare all variables that hold op results including those from nested
946  // regions.
947  WalkResult result =
948  functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
949  if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
950  (isa<emitc::ExpressionOp>(op) &&
951  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
952  return WalkResult::skip();
953  for (OpResult result : op->getResults()) {
954  if (failed(emitter.emitVariableDeclaration(
955  result, /*trailingSemicolon=*/true))) {
956  return WalkResult(
957  op->emitError("unable to declare result variable for op"));
958  }
959  }
960  return WalkResult::advance();
961  });
962  if (result.wasInterrupted())
963  return failure();
964  }
965 
966  // Create label names for basic blocks.
967  for (Block &block : blocks) {
968  emitter.getOrCreateName(block);
969  }
970 
971  // Declare variables for basic block arguments.
972  for (Block &block : llvm::drop_begin(blocks)) {
973  for (BlockArgument &arg : block.getArguments()) {
974  if (emitter.hasValueInScope(arg))
975  return functionOp->emitOpError(" block argument #")
976  << arg.getArgNumber() << " is out of scope";
977  if (isa<ArrayType>(arg.getType()))
978  return functionOp->emitOpError("cannot emit block argument #")
979  << arg.getArgNumber() << " with array type";
980  if (failed(
981  emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
982  return failure();
983  }
984  os << " " << emitter.getOrCreateName(arg) << ";\n";
985  }
986  }
987 
988  for (Block &block : blocks) {
989  // Only print a label if the block has predecessors.
990  if (!block.hasNoPredecessors()) {
991  if (failed(emitter.emitLabel(block)))
992  return failure();
993  }
994  for (Operation &op : block.getOperations()) {
995  // When generating code for an emitc.if or cf.cond_br op no semicolon
996  // needs to be printed after the closing brace.
997  // When generating code for an emitc.for and emitc.verbatim op, printing a
998  // trailing semicolon is handled within the printOperation function.
999  bool trailingSemicolon =
1000  !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
1001  emitc::IfOp, emitc::VerbatimOp>(op);
1002 
1003  if (failed(emitter.emitOperation(
1004  op, /*trailingSemicolon=*/trailingSemicolon)))
1005  return failure();
1006  }
1007  }
1008 
1009  os.unindent();
1010 
1011  return success();
1012 }
1013 
1014 static LogicalResult printOperation(CppEmitter &emitter,
1015  func::FuncOp functionOp) {
1016  // We need to declare variables at top if the function has multiple blocks.
1017  if (!emitter.shouldDeclareVariablesAtTop() &&
1018  functionOp.getBlocks().size() > 1) {
1019  return functionOp.emitOpError(
1020  "with multiple blocks needs variables declared at top");
1021  }
1022 
1023  if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
1024  return functionOp.emitOpError() << "cannot emit array type as result type";
1025  }
1026 
1027  CppEmitter::Scope scope(emitter);
1028  raw_indented_ostream &os = emitter.ostream();
1029  if (failed(emitter.emitTypes(functionOp.getLoc(),
1030  functionOp.getFunctionType().getResults())))
1031  return failure();
1032  os << " " << functionOp.getName();
1033 
1034  os << "(";
1035  Operation *operation = functionOp.getOperation();
1036  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1037  return failure();
1038  os << ") {\n";
1039  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1040  return failure();
1041  os << "}\n";
1042 
1043  return success();
1044 }
1045 
1046 static LogicalResult printOperation(CppEmitter &emitter,
1047  emitc::FuncOp functionOp) {
1048  // We need to declare variables at top if the function has multiple blocks.
1049  if (!emitter.shouldDeclareVariablesAtTop() &&
1050  functionOp.getBlocks().size() > 1) {
1051  return functionOp.emitOpError(
1052  "with multiple blocks needs variables declared at top");
1053  }
1054 
1055  CppEmitter::Scope scope(emitter);
1056  raw_indented_ostream &os = emitter.ostream();
1057  if (functionOp.getSpecifiers()) {
1058  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1059  os << cast<StringAttr>(specifier).str() << " ";
1060  }
1061  }
1062 
1063  if (failed(emitter.emitTypes(functionOp.getLoc(),
1064  functionOp.getFunctionType().getResults())))
1065  return failure();
1066  os << " " << functionOp.getName();
1067 
1068  os << "(";
1069  Operation *operation = functionOp.getOperation();
1070  if (functionOp.isExternal()) {
1071  if (failed(printFunctionArgs(emitter, operation,
1072  functionOp.getArgumentTypes())))
1073  return failure();
1074  os << ");";
1075  return success();
1076  }
1077  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1078  return failure();
1079  os << ") {\n";
1080  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1081  return failure();
1082  os << "}\n";
1083 
1084  return success();
1085 }
1086 
1087 static LogicalResult printOperation(CppEmitter &emitter,
1088  DeclareFuncOp declareFuncOp) {
1089  CppEmitter::Scope scope(emitter);
1090  raw_indented_ostream &os = emitter.ostream();
1091 
1092  auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
1093  declareFuncOp, declareFuncOp.getSymNameAttr());
1094 
1095  if (!functionOp)
1096  return failure();
1097 
1098  if (functionOp.getSpecifiers()) {
1099  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1100  os << cast<StringAttr>(specifier).str() << " ";
1101  }
1102  }
1103 
1104  if (failed(emitter.emitTypes(functionOp.getLoc(),
1105  functionOp.getFunctionType().getResults())))
1106  return failure();
1107  os << " " << functionOp.getName();
1108 
1109  os << "(";
1110  Operation *operation = functionOp.getOperation();
1111  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1112  return failure();
1113  os << ");";
1114 
1115  return success();
1116 }
1117 
1118 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
1119  : os(os), declareVariablesAtTop(declareVariablesAtTop) {
1120  valueInScopeCount.push(0);
1121  labelInScopeCount.push(0);
1122 }
1123 
1124 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1125  std::string out;
1126  llvm::raw_string_ostream ss(out);
1127  ss << getOrCreateName(op.getValue());
1128  for (auto index : op.getIndices()) {
1129  ss << "[" << getOrCreateName(index) << "]";
1130  }
1131  return out;
1132 }
1133 
1134 std::string CppEmitter::createMemberAccess(emitc::MemberOp op) {
1135  std::string out;
1136  llvm::raw_string_ostream ss(out);
1137  ss << getOrCreateName(op.getOperand());
1138  ss << "." << op.getMember();
1139  return out;
1140 }
1141 
1142 std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) {
1143  std::string out;
1144  llvm::raw_string_ostream ss(out);
1145  ss << getOrCreateName(op.getOperand());
1146  ss << "->" << op.getMember();
1147  return out;
1148 }
1149 
1150 void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
1151  if (!valueMapper.count(value))
1152  valueMapper.insert(value, str.str());
1153 }
1154 
1155 /// Return the existing or a new name for a Value.
1156 StringRef CppEmitter::getOrCreateName(Value val) {
1157  if (!valueMapper.count(val)) {
1158  assert(!hasDeferredEmission(val.getDefiningOp()) &&
1159  "cacheDeferredOpResult should have been called on this value, "
1160  "update the emitOperation function.");
1161  valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1162  }
1163  return *valueMapper.begin(val);
1164 }
1165 
1166 /// Return the existing or a new label for a Block.
1167 StringRef CppEmitter::getOrCreateName(Block &block) {
1168  if (!blockMapper.count(&block))
1169  blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
1170  return *blockMapper.begin(&block);
1171 }
1172 
1173 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
1174  switch (val) {
1175  case IntegerType::Signless:
1176  return false;
1177  case IntegerType::Signed:
1178  return false;
1179  case IntegerType::Unsigned:
1180  return true;
1181  }
1182  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
1183 }
1184 
1185 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
1186 
1187 bool CppEmitter::hasBlockLabel(Block &block) {
1188  return blockMapper.count(&block);
1189 }
1190 
1191 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
1192  auto printInt = [&](const APInt &val, bool isUnsigned) {
1193  if (val.getBitWidth() == 1) {
1194  if (val.getBoolValue())
1195  os << "true";
1196  else
1197  os << "false";
1198  } else {
1199  SmallString<128> strValue;
1200  val.toString(strValue, 10, !isUnsigned, false);
1201  os << strValue;
1202  }
1203  };
1204 
1205  auto printFloat = [&](const APFloat &val) {
1206  if (val.isFinite()) {
1207  SmallString<128> strValue;
1208  // Use default values of toString except don't truncate zeros.
1209  val.toString(strValue, 0, 0, false);
1210  os << strValue;
1211  switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
1212  case llvm::APFloatBase::S_IEEEsingle:
1213  os << "f";
1214  break;
1215  case llvm::APFloatBase::S_IEEEdouble:
1216  break;
1217  default:
1218  llvm_unreachable("unsupported floating point type");
1219  };
1220  } else if (val.isNaN()) {
1221  os << "NAN";
1222  } else if (val.isInfinity()) {
1223  if (val.isNegative())
1224  os << "-";
1225  os << "INFINITY";
1226  }
1227  };
1228 
1229  // Print floating point attributes.
1230  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1231  if (!isa<Float32Type, Float64Type>(fAttr.getType())) {
1232  return emitError(loc,
1233  "expected floating point attribute to be f32 or f64");
1234  }
1235  printFloat(fAttr.getValue());
1236  return success();
1237  }
1238  if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1239  if (!isa<Float32Type, Float64Type>(dense.getElementType())) {
1240  return emitError(loc,
1241  "expected floating point attribute to be f32 or f64");
1242  }
1243  os << '{';
1244  interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
1245  os << '}';
1246  return success();
1247  }
1248 
1249  // Print integer attributes.
1250  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
1251  if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
1252  printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
1253  return success();
1254  }
1255  if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
1256  printInt(iAttr.getValue(), false);
1257  return success();
1258  }
1259  }
1260  if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1261  if (auto iType = dyn_cast<IntegerType>(
1262  cast<TensorType>(dense.getType()).getElementType())) {
1263  os << '{';
1264  interleaveComma(dense, os, [&](const APInt &val) {
1265  printInt(val, shouldMapToUnsigned(iType.getSignedness()));
1266  });
1267  os << '}';
1268  return success();
1269  }
1270  if (auto iType = dyn_cast<IndexType>(
1271  cast<TensorType>(dense.getType()).getElementType())) {
1272  os << '{';
1273  interleaveComma(dense, os,
1274  [&](const APInt &val) { printInt(val, false); });
1275  os << '}';
1276  return success();
1277  }
1278  }
1279 
1280  // Print opaque attributes.
1281  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
1282  os << oAttr.getValue();
1283  return success();
1284  }
1285 
1286  // Print symbolic reference attributes.
1287  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
1288  if (sAttr.getNestedReferences().size() > 1)
1289  return emitError(loc, "attribute has more than 1 nested reference");
1290  os << sAttr.getRootReference().getValue();
1291  return success();
1292  }
1293 
1294  // Print type attributes.
1295  if (auto type = dyn_cast<TypeAttr>(attr))
1296  return emitType(loc, type.getValue());
1297 
1298  return emitError(loc, "cannot emit attribute: ") << attr;
1299 }
1300 
1301 LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
1302  assert(emittedExpressionPrecedence.empty() &&
1303  "Expected precedence stack to be empty");
1304  Operation *rootOp = expressionOp.getRootOp();
1305 
1306  emittedExpression = expressionOp;
1307  FailureOr<int> precedence = getOperatorPrecedence(rootOp);
1308  if (failed(precedence))
1309  return failure();
1310  pushExpressionPrecedence(precedence.value());
1311 
1312  if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false)))
1313  return failure();
1314 
1315  popExpressionPrecedence();
1316  assert(emittedExpressionPrecedence.empty() &&
1317  "Expected precedence stack to be empty");
1318  emittedExpression = nullptr;
1319 
1320  return success();
1321 }
1322 
1323 LogicalResult CppEmitter::emitOperand(Value value) {
1324  if (isPartOfCurrentExpression(value)) {
1325  Operation *def = value.getDefiningOp();
1326  assert(def && "Expected operand to be defined by an operation");
1327  FailureOr<int> precedence = getOperatorPrecedence(def);
1328  if (failed(precedence))
1329  return failure();
1330 
1331  // Sub-expressions with equal or lower precedence need to be parenthesized,
1332  // as they might be evaluated in the wrong order depending on the shape of
1333  // the expression tree.
1334  bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence();
1335  if (encloseInParenthesis) {
1336  os << "(";
1337  pushExpressionPrecedence(lowestPrecedence());
1338  } else
1339  pushExpressionPrecedence(precedence.value());
1340 
1341  if (failed(emitOperation(*def, /*trailingSemicolon=*/false)))
1342  return failure();
1343 
1344  if (encloseInParenthesis)
1345  os << ")";
1346 
1347  popExpressionPrecedence();
1348  return success();
1349  }
1350 
1351  auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
1352  if (expressionOp && shouldBeInlined(expressionOp))
1353  return emitExpression(expressionOp);
1354 
1355  os << getOrCreateName(value);
1356  return success();
1357 }
1358 
1359 LogicalResult CppEmitter::emitOperands(Operation &op) {
1360  return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) {
1361  // If an expression is being emitted, push lowest precedence as these
1362  // operands are either wrapped by parenthesis.
1363  if (getEmittedExpression())
1364  pushExpressionPrecedence(lowestPrecedence());
1365  if (failed(emitOperand(operand)))
1366  return failure();
1367  if (getEmittedExpression())
1368  popExpressionPrecedence();
1369  return success();
1370  });
1371 }
1372 
1373 LogicalResult
1374 CppEmitter::emitOperandsAndAttributes(Operation &op,
1375  ArrayRef<StringRef> exclude) {
1376  if (failed(emitOperands(op)))
1377  return failure();
1378  // Insert comma in between operands and non-filtered attributes if needed.
1379  if (op.getNumOperands() > 0) {
1380  for (NamedAttribute attr : op.getAttrs()) {
1381  if (!llvm::is_contained(exclude, attr.getName().strref())) {
1382  os << ", ";
1383  break;
1384  }
1385  }
1386  }
1387  // Emit attributes.
1388  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
1389  if (llvm::is_contained(exclude, attr.getName().strref()))
1390  return success();
1391  os << "/* " << attr.getName().getValue() << " */";
1392  if (failed(emitAttribute(op.getLoc(), attr.getValue())))
1393  return failure();
1394  return success();
1395  };
1396  return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
1397 }
1398 
1399 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
1400  if (!hasValueInScope(result)) {
1401  return result.getDefiningOp()->emitOpError(
1402  "result variable for the operation has not been declared");
1403  }
1404  os << getOrCreateName(result) << " = ";
1405  return success();
1406 }
1407 
1408 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
1409  bool trailingSemicolon) {
1410  if (hasDeferredEmission(result.getDefiningOp()))
1411  return success();
1412  if (hasValueInScope(result)) {
1413  return result.getDefiningOp()->emitError(
1414  "result variable for the operation already declared");
1415  }
1416  if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
1417  result.getType(),
1418  getOrCreateName(result))))
1419  return failure();
1420  if (trailingSemicolon)
1421  os << ";\n";
1422  return success();
1423 }
1424 
1425 LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
1426  if (op.getExternSpecifier())
1427  os << "extern ";
1428  else if (op.getStaticSpecifier())
1429  os << "static ";
1430  if (op.getConstSpecifier())
1431  os << "const ";
1432 
1433  if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
1434  op.getSymName()))) {
1435  return failure();
1436  }
1437 
1438  std::optional<Attribute> initialValue = op.getInitialValue();
1439  if (initialValue) {
1440  os << " = ";
1441  if (failed(emitAttribute(op->getLoc(), *initialValue)))
1442  return failure();
1443  }
1444 
1445  os << ";";
1446  return success();
1447 }
1448 
1449 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
1450  // If op is being emitted as part of an expression, bail out.
1451  if (getEmittedExpression())
1452  return success();
1453 
1454  switch (op.getNumResults()) {
1455  case 0:
1456  break;
1457  case 1: {
1458  OpResult result = op.getResult(0);
1459  if (shouldDeclareVariablesAtTop()) {
1460  if (failed(emitVariableAssignment(result)))
1461  return failure();
1462  } else {
1463  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
1464  return failure();
1465  os << " = ";
1466  }
1467  break;
1468  }
1469  default:
1470  if (!shouldDeclareVariablesAtTop()) {
1471  for (OpResult result : op.getResults()) {
1472  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
1473  return failure();
1474  }
1475  }
1476  os << "std::tie(";
1477  interleaveComma(op.getResults(), os,
1478  [&](Value result) { os << getOrCreateName(result); });
1479  os << ") = ";
1480  }
1481  return success();
1482 }
1483 
1484 LogicalResult CppEmitter::emitLabel(Block &block) {
1485  if (!hasBlockLabel(block))
1486  return block.getParentOp()->emitError("label for block not found");
1487  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
1488  // label instead of using `getOStream`.
1489  os.getOStream() << getOrCreateName(block) << ":\n";
1490  return success();
1491 }
1492 
1493 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1494  LogicalResult status =
1496  // Builtin ops.
1497  .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
1498  // CF ops.
1499  .Case<cf::BranchOp, cf::CondBranchOp>(
1500  [&](auto op) { return printOperation(*this, op); })
1501  // EmitC ops.
1502  .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
1503  emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
1504  emitc::BitwiseNotOp, emitc::BitwiseOrOp,
1505  emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
1506  emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1507  emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1508  emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1509  emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp,
1510  emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
1511  emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1512  emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1513  emitc::VerbatimOp>(
1514  [&](auto op) { return printOperation(*this, op); })
1515  // Func ops.
1516  .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
1517  [&](auto op) { return printOperation(*this, op); })
1518  .Case<emitc::GetGlobalOp>([&](auto op) {
1519  cacheDeferredOpResult(op.getResult(), op.getName());
1520  return success();
1521  })
1522  .Case<emitc::LiteralOp>([&](auto op) {
1523  cacheDeferredOpResult(op.getResult(), op.getValue());
1524  return success();
1525  })
1526  .Case<emitc::MemberOp>([&](auto op) {
1527  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1528  return success();
1529  })
1530  .Case<emitc::MemberOfPtrOp>([&](auto op) {
1531  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1532  return success();
1533  })
1534  .Case<emitc::SubscriptOp>([&](auto op) {
1535  cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
1536  return success();
1537  })
1538  .Default([&](Operation *) {
1539  return op.emitOpError("unable to find printer for op");
1540  });
1541 
1542  if (failed(status))
1543  return failure();
1544 
1545  if (hasDeferredEmission(&op))
1546  return success();
1547 
1548  if (getEmittedExpression() ||
1549  (isa<emitc::ExpressionOp>(op) &&
1550  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1551  return success();
1552 
1553  os << (trailingSemicolon ? ";\n" : "\n");
1554 
1555  return success();
1556 }
1557 
1558 LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1559  StringRef name) {
1560  if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1561  if (failed(emitType(loc, arrType.getElementType())))
1562  return failure();
1563  os << " " << name;
1564  for (auto dim : arrType.getShape()) {
1565  os << "[" << dim << "]";
1566  }
1567  return success();
1568  }
1569  if (failed(emitType(loc, type)))
1570  return failure();
1571  os << " " << name;
1572  return success();
1573 }
1574 
1575 LogicalResult CppEmitter::emitType(Location loc, Type type) {
1576  if (auto iType = dyn_cast<IntegerType>(type)) {
1577  switch (iType.getWidth()) {
1578  case 1:
1579  return (os << "bool"), success();
1580  case 8:
1581  case 16:
1582  case 32:
1583  case 64:
1584  if (shouldMapToUnsigned(iType.getSignedness()))
1585  return (os << "uint" << iType.getWidth() << "_t"), success();
1586  else
1587  return (os << "int" << iType.getWidth() << "_t"), success();
1588  default:
1589  return emitError(loc, "cannot emit integer type ") << type;
1590  }
1591  }
1592  if (auto fType = dyn_cast<FloatType>(type)) {
1593  switch (fType.getWidth()) {
1594  case 32:
1595  return (os << "float"), success();
1596  case 64:
1597  return (os << "double"), success();
1598  default:
1599  return emitError(loc, "cannot emit float type ") << type;
1600  }
1601  }
1602  if (auto iType = dyn_cast<IndexType>(type))
1603  return (os << "size_t"), success();
1604  if (auto sType = dyn_cast<emitc::SizeTType>(type))
1605  return (os << "size_t"), success();
1606  if (auto sType = dyn_cast<emitc::SignedSizeTType>(type))
1607  return (os << "ssize_t"), success();
1608  if (auto pType = dyn_cast<emitc::PtrDiffTType>(type))
1609  return (os << "ptrdiff_t"), success();
1610  if (auto tType = dyn_cast<TensorType>(type)) {
1611  if (!tType.hasRank())
1612  return emitError(loc, "cannot emit unranked tensor type");
1613  if (!tType.hasStaticShape())
1614  return emitError(loc, "cannot emit tensor type with non static shape");
1615  os << "Tensor<";
1616  if (isa<ArrayType>(tType.getElementType()))
1617  return emitError(loc, "cannot emit tensor of array type ") << type;
1618  if (failed(emitType(loc, tType.getElementType())))
1619  return failure();
1620  auto shape = tType.getShape();
1621  for (auto dimSize : shape) {
1622  os << ", ";
1623  os << dimSize;
1624  }
1625  os << ">";
1626  return success();
1627  }
1628  if (auto tType = dyn_cast<TupleType>(type))
1629  return emitTupleType(loc, tType.getTypes());
1630  if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1631  os << oType.getValue();
1632  return success();
1633  }
1634  if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
1635  if (failed(emitType(loc, aType.getElementType())))
1636  return failure();
1637  for (auto dim : aType.getShape())
1638  os << "[" << dim << "]";
1639  return success();
1640  }
1641  if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1642  if (isa<ArrayType>(pType.getPointee()))
1643  return emitError(loc, "cannot emit pointer to array type ") << type;
1644  if (failed(emitType(loc, pType.getPointee())))
1645  return failure();
1646  os << "*";
1647  return success();
1648  }
1649  return emitError(loc, "cannot emit type ") << type;
1650 }
1651 
1652 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1653  switch (types.size()) {
1654  case 0:
1655  os << "void";
1656  return success();
1657  case 1:
1658  return emitType(loc, types.front());
1659  default:
1660  return emitTupleType(loc, types);
1661  }
1662 }
1663 
1664 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1665  if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
1666  return emitError(loc, "cannot emit tuple of array type");
1667  }
1668  os << "std::tuple<";
1669  if (failed(interleaveCommaWithError(
1670  types, os, [&](Type type) { return emitType(loc, type); })))
1671  return failure();
1672  os << ">";
1673  return success();
1674 }
1675 
1676 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1677  bool declareVariablesAtTop) {
1678  CppEmitter emitter(os, declareVariablesAtTop);
1679  return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1680 }
static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, StringRef callee)
static bool shouldBeInlined(ExpressionOp expressionOp)
Determine whether expression expressionOp should be emitted inline, i.e.
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
static FailureOr< int > getOperatorPrecedence(Operation *operation)
Return the precedence of an operator as an integer; higher values imply higher precedence.
static LogicalResult printFunctionArgs(CppEmitter &emitter, Operation *functionOp, ArrayRef< Type > arguments)
static LogicalResult printFunctionBody(CppEmitter &emitter, Operation *functionOp, Region::BlockListType &blocks)
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator)
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
static LogicalResult printUnaryOperation(CppEmitter &emitter, Operation *operation, StringRef unaryOperator)
static bool hasDeferredEmission(Operation *op)
Determine whether expression op should be emitted in a deferred way.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgListType getArguments()
Definition: Block.h:85
Block * getSuccessor(unsigned i)
Definition: Block.cpp:258
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class provides iteration over the held operations of blocks directly within a region.
Definition: Region.h:134
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
llvm::iplist< Block > BlockListType
Definition: Region.h:44
bool empty()
Definition: Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & unindent()
Decreases the indent and returning this raw_indented_ostream.
raw_indented_ostream & indent()
Increases the indent and returning this raw_indented_ostream.
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false)
Translates the given operation to C++ code.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This iterator enumerates the elements in "forward" order.
Definition: Visitors.h:65