MLIR  21.0.0git
TranslateToCpp.cpp
Go to the documentation of this file.
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/ScopedHashTable.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <stack>
28 #include <utility>
29 
30 #define DEBUG_TYPE "translate-to-cpp"
31 
32 using namespace mlir;
33 using namespace mlir::emitc;
34 using llvm::formatv;
35 
36 /// Convenience functions to produce interleaved output with functions returning
37 /// a LogicalResult. This is different than those in STLExtras as functions used
38 /// on each element doesn't return a string.
39 template <typename ForwardIterator, typename UnaryFunctor,
40  typename NullaryFunctor>
41 inline LogicalResult
43  UnaryFunctor eachFn, NullaryFunctor betweenFn) {
44  if (begin == end)
45  return success();
46  if (failed(eachFn(*begin)))
47  return failure();
48  ++begin;
49  for (; begin != end; ++begin) {
50  betweenFn();
51  if (failed(eachFn(*begin)))
52  return failure();
53  }
54  return success();
55 }
56 
57 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
58 inline LogicalResult interleaveWithError(const Container &c,
59  UnaryFunctor eachFn,
60  NullaryFunctor betweenFn) {
61  return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
62 }
63 
64 template <typename Container, typename UnaryFunctor>
65 inline LogicalResult interleaveCommaWithError(const Container &c,
66  raw_ostream &os,
67  UnaryFunctor eachFn) {
68  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
69 }
70 
71 /// Return the precedence of a operator as an integer, higher values
72 /// imply higher precedence.
73 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
75  .Case<emitc::AddOp>([&](auto op) { return 12; })
76  .Case<emitc::ApplyOp>([&](auto op) { return 15; })
77  .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
78  .Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; })
79  .Case<emitc::BitwiseNotOp>([&](auto op) { return 15; })
80  .Case<emitc::BitwiseOrOp>([&](auto op) { return 5; })
81  .Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; })
82  .Case<emitc::BitwiseXorOp>([&](auto op) { return 6; })
83  .Case<emitc::CallOp>([&](auto op) { return 16; })
84  .Case<emitc::CallOpaqueOp>([&](auto op) { return 16; })
85  .Case<emitc::CastOp>([&](auto op) { return 15; })
86  .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
87  switch (op.getPredicate()) {
88  case emitc::CmpPredicate::eq:
89  case emitc::CmpPredicate::ne:
90  return 8;
91  case emitc::CmpPredicate::lt:
92  case emitc::CmpPredicate::le:
93  case emitc::CmpPredicate::gt:
94  case emitc::CmpPredicate::ge:
95  return 9;
96  case emitc::CmpPredicate::three_way:
97  return 10;
98  }
99  return op->emitError("unsupported cmp predicate");
100  })
101  .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
102  .Case<emitc::DivOp>([&](auto op) { return 13; })
103  .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
104  .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
105  .Case<emitc::LogicalOrOp>([&](auto op) { return 3; })
106  .Case<emitc::MulOp>([&](auto op) { return 13; })
107  .Case<emitc::RemOp>([&](auto op) { return 13; })
108  .Case<emitc::SubOp>([&](auto op) { return 12; })
109  .Case<emitc::UnaryMinusOp>([&](auto op) { return 15; })
110  .Case<emitc::UnaryPlusOp>([&](auto op) { return 15; })
111  .Default([](auto op) { return op->emitError("unsupported operation"); });
112 }
113 
114 namespace {
115 /// Emitter that uses dialect specific emitters to emit C++ code.
116 struct CppEmitter {
117  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
118 
119  /// Emits attribute or returns failure.
120  LogicalResult emitAttribute(Location loc, Attribute attr);
121 
122  /// Emits operation 'op' with/without training semicolon or returns failure.
123  ///
124  /// For operations that should never be followed by a semicolon, like ForOp,
125  /// the `trailingSemicolon` argument is ignored and a semicolon is not
126  /// emitted.
127  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
128 
129  /// Emits type 'type' or returns failure.
130  LogicalResult emitType(Location loc, Type type);
131 
132  /// Emits array of types as a std::tuple of the emitted types.
133  /// - emits void for an empty array;
134  /// - emits the type of the only element for arrays of size one;
135  /// - emits a std::tuple otherwise;
136  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
137 
138  /// Emits array of types as a std::tuple of the emitted types independently of
139  /// the array size.
140  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
141 
142  /// Emits an assignment for a variable which has been declared previously.
143  LogicalResult emitVariableAssignment(OpResult result);
144 
145  /// Emits a variable declaration for a result of an operation.
146  LogicalResult emitVariableDeclaration(OpResult result,
147  bool trailingSemicolon);
148 
149  /// Emits a declaration of a variable with the given type and name.
150  LogicalResult emitVariableDeclaration(Location loc, Type type,
151  StringRef name);
152 
153  /// Emits the variable declaration and assignment prefix for 'op'.
154  /// - emits separate variable followed by std::tie for multi-valued operation;
155  /// - emits single type followed by variable for single result;
156  /// - emits nothing if no value produced by op;
157  /// Emits final '=' operator where a type is produced. Returns failure if
158  /// any result type could not be converted.
159  LogicalResult emitAssignPrefix(Operation &op);
160 
161  /// Emits a global variable declaration or definition.
162  LogicalResult emitGlobalVariable(GlobalOp op);
163 
164  /// Emits a label for the block.
165  LogicalResult emitLabel(Block &block);
166 
167  /// Emits the operands and atttributes of the operation. All operands are
168  /// emitted first and then all attributes in alphabetical order.
169  LogicalResult emitOperandsAndAttributes(Operation &op,
170  ArrayRef<StringRef> exclude = {});
171 
172  /// Emits the operands of the operation. All operands are emitted in order.
173  LogicalResult emitOperands(Operation &op);
174 
175  /// Emits value as an operands of an operation
176  LogicalResult emitOperand(Value value);
177 
178  /// Emit an expression as a C expression.
179  LogicalResult emitExpression(ExpressionOp expressionOp);
180 
181  /// Insert the expression representing the operation into the value cache.
182  void cacheDeferredOpResult(Value value, StringRef str);
183 
184  /// Return the existing or a new name for a Value.
185  StringRef getOrCreateName(Value val);
186 
187  // Returns the textual representation of a subscript operation.
188  std::string getSubscriptName(emitc::SubscriptOp op);
189 
190  // Returns the textual representation of a member (of object) operation.
191  std::string createMemberAccess(emitc::MemberOp op);
192 
193  // Returns the textual representation of a member of pointer operation.
194  std::string createMemberAccess(emitc::MemberOfPtrOp op);
195 
196  /// Return the existing or a new label of a Block.
197  StringRef getOrCreateName(Block &block);
198 
199  /// Whether to map an mlir integer to a unsigned integer in C++.
200  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
201 
202  /// RAII helper function to manage entering/exiting C++ scopes.
203  struct Scope {
204  Scope(CppEmitter &emitter)
205  : valueMapperScope(emitter.valueMapper),
206  blockMapperScope(emitter.blockMapper), emitter(emitter) {
207  emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
208  emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
209  }
210  ~Scope() {
211  emitter.valueInScopeCount.pop();
212  emitter.labelInScopeCount.pop();
213  }
214 
215  private:
216  llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
217  llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
218  CppEmitter &emitter;
219  };
220 
221  /// Returns wether the Value is assigned to a C++ variable in the scope.
222  bool hasValueInScope(Value val);
223 
224  // Returns whether a label is assigned to the block.
225  bool hasBlockLabel(Block &block);
226 
227  /// Returns the output stream.
228  raw_indented_ostream &ostream() { return os; };
229 
230  /// Returns if all variables for op results and basic block arguments need to
231  /// be declared at the beginning of a function.
232  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
233 
234  /// Get expression currently being emitted.
235  ExpressionOp getEmittedExpression() { return emittedExpression; }
236 
237  /// Determine whether given value is part of the expression potentially being
238  /// emitted.
239  bool isPartOfCurrentExpression(Value value) {
240  if (!emittedExpression)
241  return false;
242  Operation *def = value.getDefiningOp();
243  if (!def)
244  return false;
245  auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
246  return operandExpression == emittedExpression;
247  };
248 
249 private:
250  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
251  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
252 
253  /// Output stream to emit to.
255 
256  /// Boolean to enforce that all variables for op results and block
257  /// arguments are declared at the beginning of the function. This also
258  /// includes results from ops located in nested regions.
259  bool declareVariablesAtTop;
260 
261  /// Map from value to name of C++ variable that contain the name.
262  ValueMapper valueMapper;
263 
264  /// Map from block to name of C++ label.
265  BlockMapper blockMapper;
266 
267  /// The number of values in the current scope. This is used to declare the
268  /// names of values in a scope.
269  std::stack<int64_t> valueInScopeCount;
270  std::stack<int64_t> labelInScopeCount;
271 
272  /// State of the current expression being emitted.
273  ExpressionOp emittedExpression;
274  SmallVector<int> emittedExpressionPrecedence;
275 
276  void pushExpressionPrecedence(int precedence) {
277  emittedExpressionPrecedence.push_back(precedence);
278  }
279  void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
280  static int lowestPrecedence() { return 0; }
281  int getExpressionPrecedence() {
282  if (emittedExpressionPrecedence.empty())
283  return lowestPrecedence();
284  return emittedExpressionPrecedence.back();
285  }
286 };
287 } // namespace
288 
289 /// Determine whether expression \p op should be emitted in a deferred way.
290 static bool hasDeferredEmission(Operation *op) {
291  return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
292  emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
293 }
294 
295 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
296 /// as part of its user. This function recommends inlining of any expressions
297 /// that can be inlined unless it is used by another expression, under the
298 /// assumption that any expression fusion/re-materialization was taken care of
299 /// by transformations run by the backend.
300 static bool shouldBeInlined(ExpressionOp expressionOp) {
301  // Do not inline if expression is marked as such.
302  if (expressionOp.getDoNotInline())
303  return false;
304 
305  // Do not inline expressions with side effects to prevent side-effect
306  // reordering.
307  if (expressionOp.hasSideEffects())
308  return false;
309 
310  // Do not inline expressions with multiple uses.
311  Value result = expressionOp.getResult();
312  if (!result.hasOneUse())
313  return false;
314 
315  Operation *user = *result.getUsers().begin();
316 
317  // Do not inline expressions used by operations with deferred emission, since
318  // their translation requires the materialization of variables.
319  if (hasDeferredEmission(user))
320  return false;
321 
322  // Do not inline expressions used by ops with the CExpression trait. If this
323  // was intended, the user could have been merged into the expression op.
324  return !user->hasTrait<OpTrait::emitc::CExpression>();
325 }
326 
327 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
328  Attribute value) {
329  OpResult result = operation->getResult(0);
330 
331  // Only emit an assignment as the variable was already declared when printing
332  // the FuncOp.
333  if (emitter.shouldDeclareVariablesAtTop()) {
334  // Skip the assignment if the emitc.constant has no value.
335  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
336  if (oAttr.getValue().empty())
337  return success();
338  }
339 
340  if (failed(emitter.emitVariableAssignment(result)))
341  return failure();
342  return emitter.emitAttribute(operation->getLoc(), value);
343  }
344 
345  // Emit a variable declaration for an emitc.constant op without value.
346  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
347  if (oAttr.getValue().empty())
348  // The semicolon gets printed by the emitOperation function.
349  return emitter.emitVariableDeclaration(result,
350  /*trailingSemicolon=*/false);
351  }
352 
353  // Emit a variable declaration.
354  if (failed(emitter.emitAssignPrefix(*operation)))
355  return failure();
356  return emitter.emitAttribute(operation->getLoc(), value);
357 }
358 
359 static LogicalResult printOperation(CppEmitter &emitter,
360  emitc::ConstantOp constantOp) {
361  Operation *operation = constantOp.getOperation();
362  Attribute value = constantOp.getValue();
363 
364  return printConstantOp(emitter, operation, value);
365 }
366 
367 static LogicalResult printOperation(CppEmitter &emitter,
368  emitc::VariableOp variableOp) {
369  Operation *operation = variableOp.getOperation();
370  Attribute value = variableOp.getValue();
371 
372  return printConstantOp(emitter, operation, value);
373 }
374 
375 static LogicalResult printOperation(CppEmitter &emitter,
376  emitc::GlobalOp globalOp) {
377 
378  return emitter.emitGlobalVariable(globalOp);
379 }
380 
381 static LogicalResult printOperation(CppEmitter &emitter,
382  emitc::AssignOp assignOp) {
383  OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
384 
385  if (failed(emitter.emitVariableAssignment(result)))
386  return failure();
387 
388  return emitter.emitOperand(assignOp.getValue());
389 }
390 
391 static LogicalResult printOperation(CppEmitter &emitter, emitc::LoadOp loadOp) {
392  if (failed(emitter.emitAssignPrefix(*loadOp)))
393  return failure();
394 
395  return emitter.emitOperand(loadOp.getOperand());
396 }
397 
398 static LogicalResult printBinaryOperation(CppEmitter &emitter,
399  Operation *operation,
400  StringRef binaryOperator) {
401  raw_ostream &os = emitter.ostream();
402 
403  if (failed(emitter.emitAssignPrefix(*operation)))
404  return failure();
405 
406  if (failed(emitter.emitOperand(operation->getOperand(0))))
407  return failure();
408 
409  os << " " << binaryOperator << " ";
410 
411  if (failed(emitter.emitOperand(operation->getOperand(1))))
412  return failure();
413 
414  return success();
415 }
416 
417 static LogicalResult printUnaryOperation(CppEmitter &emitter,
418  Operation *operation,
419  StringRef unaryOperator) {
420  raw_ostream &os = emitter.ostream();
421 
422  if (failed(emitter.emitAssignPrefix(*operation)))
423  return failure();
424 
425  os << unaryOperator;
426 
427  if (failed(emitter.emitOperand(operation->getOperand(0))))
428  return failure();
429 
430  return success();
431 }
432 
433 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
434  Operation *operation = addOp.getOperation();
435 
436  return printBinaryOperation(emitter, operation, "+");
437 }
438 
439 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
440  Operation *operation = divOp.getOperation();
441 
442  return printBinaryOperation(emitter, operation, "/");
443 }
444 
445 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
446  Operation *operation = mulOp.getOperation();
447 
448  return printBinaryOperation(emitter, operation, "*");
449 }
450 
451 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
452  Operation *operation = remOp.getOperation();
453 
454  return printBinaryOperation(emitter, operation, "%");
455 }
456 
457 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
458  Operation *operation = subOp.getOperation();
459 
460  return printBinaryOperation(emitter, operation, "-");
461 }
462 
463 static LogicalResult emitSwitchCase(CppEmitter &emitter,
464  raw_indented_ostream &os, Region &region) {
465  for (Region::OpIterator iteratorOp = region.op_begin(), end = region.op_end();
466  std::next(iteratorOp) != end; ++iteratorOp) {
467  if (failed(emitter.emitOperation(*iteratorOp, /*trailingSemicolon=*/true)))
468  return failure();
469  }
470  os << "break;\n";
471  return success();
472 }
473 
474 static LogicalResult printOperation(CppEmitter &emitter,
475  emitc::SwitchOp switchOp) {
476  raw_indented_ostream &os = emitter.ostream();
477 
478  os << "\nswitch (";
479  if (failed(emitter.emitOperand(switchOp.getArg())))
480  return failure();
481  os << ") {";
482 
483  for (auto pair : llvm::zip(switchOp.getCases(), switchOp.getCaseRegions())) {
484  os << "\ncase " << std::get<0>(pair) << ": {\n";
485  os.indent();
486 
487  if (failed(emitSwitchCase(emitter, os, std::get<1>(pair))))
488  return failure();
489 
490  os.unindent() << "}";
491  }
492 
493  os << "\ndefault: {\n";
494  os.indent();
495 
496  if (failed(emitSwitchCase(emitter, os, switchOp.getDefaultRegion())))
497  return failure();
498 
499  os.unindent() << "}\n}";
500  return success();
501 }
502 
503 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
504  Operation *operation = cmpOp.getOperation();
505 
506  StringRef binaryOperator;
507 
508  switch (cmpOp.getPredicate()) {
509  case emitc::CmpPredicate::eq:
510  binaryOperator = "==";
511  break;
512  case emitc::CmpPredicate::ne:
513  binaryOperator = "!=";
514  break;
515  case emitc::CmpPredicate::lt:
516  binaryOperator = "<";
517  break;
518  case emitc::CmpPredicate::le:
519  binaryOperator = "<=";
520  break;
521  case emitc::CmpPredicate::gt:
522  binaryOperator = ">";
523  break;
524  case emitc::CmpPredicate::ge:
525  binaryOperator = ">=";
526  break;
527  case emitc::CmpPredicate::three_way:
528  binaryOperator = "<=>";
529  break;
530  }
531 
532  return printBinaryOperation(emitter, operation, binaryOperator);
533 }
534 
535 static LogicalResult printOperation(CppEmitter &emitter,
536  emitc::ConditionalOp conditionalOp) {
537  raw_ostream &os = emitter.ostream();
538 
539  if (failed(emitter.emitAssignPrefix(*conditionalOp)))
540  return failure();
541 
542  if (failed(emitter.emitOperand(conditionalOp.getCondition())))
543  return failure();
544 
545  os << " ? ";
546 
547  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
548  return failure();
549 
550  os << " : ";
551 
552  if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
553  return failure();
554 
555  return success();
556 }
557 
558 static LogicalResult printOperation(CppEmitter &emitter,
559  emitc::VerbatimOp verbatimOp) {
560  raw_ostream &os = emitter.ostream();
561 
562  os << verbatimOp.getValue();
563 
564  return success();
565 }
566 
567 static LogicalResult printOperation(CppEmitter &emitter,
568  cf::BranchOp branchOp) {
569  raw_ostream &os = emitter.ostream();
570  Block &successor = *branchOp.getSuccessor();
571 
572  for (auto pair :
573  llvm::zip(branchOp.getOperands(), successor.getArguments())) {
574  Value &operand = std::get<0>(pair);
575  BlockArgument &argument = std::get<1>(pair);
576  os << emitter.getOrCreateName(argument) << " = "
577  << emitter.getOrCreateName(operand) << ";\n";
578  }
579 
580  os << "goto ";
581  if (!(emitter.hasBlockLabel(successor)))
582  return branchOp.emitOpError("unable to find label for successor block");
583  os << emitter.getOrCreateName(successor);
584  return success();
585 }
586 
587 static LogicalResult printOperation(CppEmitter &emitter,
588  cf::CondBranchOp condBranchOp) {
589  raw_indented_ostream &os = emitter.ostream();
590  Block &trueSuccessor = *condBranchOp.getTrueDest();
591  Block &falseSuccessor = *condBranchOp.getFalseDest();
592 
593  os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
594  << ") {\n";
595 
596  os.indent();
597 
598  // If condition is true.
599  for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
600  trueSuccessor.getArguments())) {
601  Value &operand = std::get<0>(pair);
602  BlockArgument &argument = std::get<1>(pair);
603  os << emitter.getOrCreateName(argument) << " = "
604  << emitter.getOrCreateName(operand) << ";\n";
605  }
606 
607  os << "goto ";
608  if (!(emitter.hasBlockLabel(trueSuccessor))) {
609  return condBranchOp.emitOpError("unable to find label for successor block");
610  }
611  os << emitter.getOrCreateName(trueSuccessor) << ";\n";
612  os.unindent() << "} else {\n";
613  os.indent();
614  // If condition is false.
615  for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
616  falseSuccessor.getArguments())) {
617  Value &operand = std::get<0>(pair);
618  BlockArgument &argument = std::get<1>(pair);
619  os << emitter.getOrCreateName(argument) << " = "
620  << emitter.getOrCreateName(operand) << ";\n";
621  }
622 
623  os << "goto ";
624  if (!(emitter.hasBlockLabel(falseSuccessor))) {
625  return condBranchOp.emitOpError()
626  << "unable to find label for successor block";
627  }
628  os << emitter.getOrCreateName(falseSuccessor) << ";\n";
629  os.unindent() << "}";
630  return success();
631 }
632 
633 static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
634  StringRef callee) {
635  if (failed(emitter.emitAssignPrefix(*callOp)))
636  return failure();
637 
638  raw_ostream &os = emitter.ostream();
639  os << callee << "(";
640  if (failed(emitter.emitOperands(*callOp)))
641  return failure();
642  os << ")";
643  return success();
644 }
645 
646 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
647  Operation *operation = callOp.getOperation();
648  StringRef callee = callOp.getCallee();
649 
650  return printCallOperation(emitter, operation, callee);
651 }
652 
653 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
654  Operation *operation = callOp.getOperation();
655  StringRef callee = callOp.getCallee();
656 
657  return printCallOperation(emitter, operation, callee);
658 }
659 
660 static LogicalResult printOperation(CppEmitter &emitter,
661  emitc::CallOpaqueOp callOpaqueOp) {
662  raw_ostream &os = emitter.ostream();
663  Operation &op = *callOpaqueOp.getOperation();
664 
665  if (failed(emitter.emitAssignPrefix(op)))
666  return failure();
667  os << callOpaqueOp.getCallee();
668 
669  auto emitArgs = [&](Attribute attr) -> LogicalResult {
670  if (auto t = dyn_cast<IntegerAttr>(attr)) {
671  // Index attributes are treated specially as operand index.
672  if (t.getType().isIndex()) {
673  int64_t idx = t.getInt();
674  Value operand = op.getOperand(idx);
675  if (!emitter.hasValueInScope(operand))
676  return op.emitOpError("operand ")
677  << idx << "'s value not defined in scope";
678  os << emitter.getOrCreateName(operand);
679  return success();
680  }
681  }
682  if (failed(emitter.emitAttribute(op.getLoc(), attr)))
683  return failure();
684 
685  return success();
686  };
687 
688  if (callOpaqueOp.getTemplateArgs()) {
689  os << "<";
690  if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
691  emitArgs)))
692  return failure();
693  os << ">";
694  }
695 
696  os << "(";
697 
698  LogicalResult emittedArgs =
699  callOpaqueOp.getArgs()
700  ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
701  : emitter.emitOperands(op);
702  if (failed(emittedArgs))
703  return failure();
704  os << ")";
705  return success();
706 }
707 
708 static LogicalResult printOperation(CppEmitter &emitter,
709  emitc::ApplyOp applyOp) {
710  raw_ostream &os = emitter.ostream();
711  Operation &op = *applyOp.getOperation();
712 
713  if (failed(emitter.emitAssignPrefix(op)))
714  return failure();
715  os << applyOp.getApplicableOperator();
716  os << emitter.getOrCreateName(applyOp.getOperand());
717 
718  return success();
719 }
720 
721 static LogicalResult printOperation(CppEmitter &emitter,
722  emitc::BitwiseAndOp bitwiseAndOp) {
723  Operation *operation = bitwiseAndOp.getOperation();
724  return printBinaryOperation(emitter, operation, "&");
725 }
726 
727 static LogicalResult
728 printOperation(CppEmitter &emitter,
729  emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) {
730  Operation *operation = bitwiseLeftShiftOp.getOperation();
731  return printBinaryOperation(emitter, operation, "<<");
732 }
733 
734 static LogicalResult printOperation(CppEmitter &emitter,
735  emitc::BitwiseNotOp bitwiseNotOp) {
736  Operation *operation = bitwiseNotOp.getOperation();
737  return printUnaryOperation(emitter, operation, "~");
738 }
739 
740 static LogicalResult printOperation(CppEmitter &emitter,
741  emitc::BitwiseOrOp bitwiseOrOp) {
742  Operation *operation = bitwiseOrOp.getOperation();
743  return printBinaryOperation(emitter, operation, "|");
744 }
745 
746 static LogicalResult
747 printOperation(CppEmitter &emitter,
748  emitc::BitwiseRightShiftOp bitwiseRightShiftOp) {
749  Operation *operation = bitwiseRightShiftOp.getOperation();
750  return printBinaryOperation(emitter, operation, ">>");
751 }
752 
753 static LogicalResult printOperation(CppEmitter &emitter,
754  emitc::BitwiseXorOp bitwiseXorOp) {
755  Operation *operation = bitwiseXorOp.getOperation();
756  return printBinaryOperation(emitter, operation, "^");
757 }
758 
759 static LogicalResult printOperation(CppEmitter &emitter,
760  emitc::UnaryPlusOp unaryPlusOp) {
761  Operation *operation = unaryPlusOp.getOperation();
762  return printUnaryOperation(emitter, operation, "+");
763 }
764 
765 static LogicalResult printOperation(CppEmitter &emitter,
766  emitc::UnaryMinusOp unaryMinusOp) {
767  Operation *operation = unaryMinusOp.getOperation();
768  return printUnaryOperation(emitter, operation, "-");
769 }
770 
771 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
772  raw_ostream &os = emitter.ostream();
773  Operation &op = *castOp.getOperation();
774 
775  if (failed(emitter.emitAssignPrefix(op)))
776  return failure();
777  os << "(";
778  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
779  return failure();
780  os << ") ";
781  return emitter.emitOperand(castOp.getOperand());
782 }
783 
784 static LogicalResult printOperation(CppEmitter &emitter,
785  emitc::ExpressionOp expressionOp) {
786  if (shouldBeInlined(expressionOp))
787  return success();
788 
789  Operation &op = *expressionOp.getOperation();
790 
791  if (failed(emitter.emitAssignPrefix(op)))
792  return failure();
793 
794  return emitter.emitExpression(expressionOp);
795 }
796 
797 static LogicalResult printOperation(CppEmitter &emitter,
798  emitc::IncludeOp includeOp) {
799  raw_ostream &os = emitter.ostream();
800 
801  os << "#include ";
802  if (includeOp.getIsStandardInclude())
803  os << "<" << includeOp.getInclude() << ">";
804  else
805  os << "\"" << includeOp.getInclude() << "\"";
806 
807  return success();
808 }
809 
810 static LogicalResult printOperation(CppEmitter &emitter,
811  emitc::LogicalAndOp logicalAndOp) {
812  Operation *operation = logicalAndOp.getOperation();
813  return printBinaryOperation(emitter, operation, "&&");
814 }
815 
816 static LogicalResult printOperation(CppEmitter &emitter,
817  emitc::LogicalNotOp logicalNotOp) {
818  Operation *operation = logicalNotOp.getOperation();
819  return printUnaryOperation(emitter, operation, "!");
820 }
821 
822 static LogicalResult printOperation(CppEmitter &emitter,
823  emitc::LogicalOrOp logicalOrOp) {
824  Operation *operation = logicalOrOp.getOperation();
825  return printBinaryOperation(emitter, operation, "||");
826 }
827 
828 static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
829 
830  raw_indented_ostream &os = emitter.ostream();
831 
832  // Utility function to determine whether a value is an expression that will be
833  // inlined, and as such should be wrapped in parentheses in order to guarantee
834  // its precedence and associativity.
835  auto requiresParentheses = [&](Value value) {
836  auto expressionOp =
837  dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
838  if (!expressionOp)
839  return false;
840  return shouldBeInlined(expressionOp);
841  };
842 
843  os << "for (";
844  if (failed(
845  emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
846  return failure();
847  os << " ";
848  os << emitter.getOrCreateName(forOp.getInductionVar());
849  os << " = ";
850  if (failed(emitter.emitOperand(forOp.getLowerBound())))
851  return failure();
852  os << "; ";
853  os << emitter.getOrCreateName(forOp.getInductionVar());
854  os << " < ";
855  Value upperBound = forOp.getUpperBound();
856  bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
857  if (upperBoundRequiresParentheses)
858  os << "(";
859  if (failed(emitter.emitOperand(upperBound)))
860  return failure();
861  if (upperBoundRequiresParentheses)
862  os << ")";
863  os << "; ";
864  os << emitter.getOrCreateName(forOp.getInductionVar());
865  os << " += ";
866  if (failed(emitter.emitOperand(forOp.getStep())))
867  return failure();
868  os << ") {\n";
869  os.indent();
870 
871  Region &forRegion = forOp.getRegion();
872  auto regionOps = forRegion.getOps();
873 
874  // We skip the trailing yield op.
875  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
876  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
877  return failure();
878  }
879 
880  os.unindent() << "}";
881 
882  return success();
883 }
884 
885 static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
886  raw_indented_ostream &os = emitter.ostream();
887 
888  // Helper function to emit all ops except the last one, expected to be
889  // emitc::yield.
890  auto emitAllExceptLast = [&emitter](Region &region) {
891  Region::OpIterator it = region.op_begin(), end = region.op_end();
892  for (; std::next(it) != end; ++it) {
893  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
894  return failure();
895  }
896  assert(isa<emitc::YieldOp>(*it) &&
897  "Expected last operation in the region to be emitc::yield");
898  return success();
899  };
900 
901  os << "if (";
902  if (failed(emitter.emitOperand(ifOp.getCondition())))
903  return failure();
904  os << ") {\n";
905  os.indent();
906  if (failed(emitAllExceptLast(ifOp.getThenRegion())))
907  return failure();
908  os.unindent() << "}";
909 
910  Region &elseRegion = ifOp.getElseRegion();
911  if (!elseRegion.empty()) {
912  os << " else {\n";
913  os.indent();
914  if (failed(emitAllExceptLast(elseRegion)))
915  return failure();
916  os.unindent() << "}";
917  }
918 
919  return success();
920 }
921 
922 static LogicalResult printOperation(CppEmitter &emitter,
923  func::ReturnOp returnOp) {
924  raw_ostream &os = emitter.ostream();
925  os << "return";
926  switch (returnOp.getNumOperands()) {
927  case 0:
928  return success();
929  case 1:
930  os << " ";
931  if (failed(emitter.emitOperand(returnOp.getOperand(0))))
932  return failure();
933  return success();
934  default:
935  os << " std::make_tuple(";
936  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
937  return failure();
938  os << ")";
939  return success();
940  }
941 }
942 
943 static LogicalResult printOperation(CppEmitter &emitter,
944  emitc::ReturnOp returnOp) {
945  raw_ostream &os = emitter.ostream();
946  os << "return";
947  if (returnOp.getNumOperands() == 0)
948  return success();
949 
950  os << " ";
951  if (failed(emitter.emitOperand(returnOp.getOperand())))
952  return failure();
953  return success();
954 }
955 
956 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
957  CppEmitter::Scope scope(emitter);
958 
959  for (Operation &op : moduleOp) {
960  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
961  return failure();
962  }
963  return success();
964 }
965 
966 static LogicalResult printFunctionArgs(CppEmitter &emitter,
967  Operation *functionOp,
968  ArrayRef<Type> arguments) {
969  raw_indented_ostream &os = emitter.ostream();
970 
971  return (
972  interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult {
973  return emitter.emitType(functionOp->getLoc(), arg);
974  }));
975 }
976 
977 static LogicalResult printFunctionArgs(CppEmitter &emitter,
978  Operation *functionOp,
979  Region::BlockArgListType arguments) {
980  raw_indented_ostream &os = emitter.ostream();
981 
982  return (interleaveCommaWithError(
983  arguments, os, [&](BlockArgument arg) -> LogicalResult {
984  return emitter.emitVariableDeclaration(
985  functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
986  }));
987 }
988 
989 static LogicalResult printFunctionBody(CppEmitter &emitter,
990  Operation *functionOp,
991  Region::BlockListType &blocks) {
992  raw_indented_ostream &os = emitter.ostream();
993  os.indent();
994 
995  if (emitter.shouldDeclareVariablesAtTop()) {
996  // Declare all variables that hold op results including those from nested
997  // regions.
998  WalkResult result =
999  functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
1000  if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
1001  (isa<emitc::ExpressionOp>(op) &&
1002  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1003  return WalkResult::skip();
1004  for (OpResult result : op->getResults()) {
1005  if (failed(emitter.emitVariableDeclaration(
1006  result, /*trailingSemicolon=*/true))) {
1007  return WalkResult(
1008  op->emitError("unable to declare result variable for op"));
1009  }
1010  }
1011  return WalkResult::advance();
1012  });
1013  if (result.wasInterrupted())
1014  return failure();
1015  }
1016 
1017  // Create label names for basic blocks.
1018  for (Block &block : blocks) {
1019  emitter.getOrCreateName(block);
1020  }
1021 
1022  // Declare variables for basic block arguments.
1023  for (Block &block : llvm::drop_begin(blocks)) {
1024  for (BlockArgument &arg : block.getArguments()) {
1025  if (emitter.hasValueInScope(arg))
1026  return functionOp->emitOpError(" block argument #")
1027  << arg.getArgNumber() << " is out of scope";
1028  if (isa<ArrayType, LValueType>(arg.getType()))
1029  return functionOp->emitOpError("cannot emit block argument #")
1030  << arg.getArgNumber() << " with type " << arg.getType();
1031  if (failed(
1032  emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
1033  return failure();
1034  }
1035  os << " " << emitter.getOrCreateName(arg) << ";\n";
1036  }
1037  }
1038 
1039  for (Block &block : blocks) {
1040  // Only print a label if the block has predecessors.
1041  if (!block.hasNoPredecessors()) {
1042  if (failed(emitter.emitLabel(block)))
1043  return failure();
1044  }
1045  for (Operation &op : block.getOperations()) {
1046  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
1047  return failure();
1048  }
1049  }
1050 
1051  os.unindent();
1052 
1053  return success();
1054 }
1055 
1056 static LogicalResult printOperation(CppEmitter &emitter,
1057  func::FuncOp functionOp) {
1058  // We need to declare variables at top if the function has multiple blocks.
1059  if (!emitter.shouldDeclareVariablesAtTop() &&
1060  functionOp.getBlocks().size() > 1) {
1061  return functionOp.emitOpError(
1062  "with multiple blocks needs variables declared at top");
1063  }
1064 
1065  if (llvm::any_of(functionOp.getArgumentTypes(), llvm::IsaPred<LValueType>)) {
1066  return functionOp.emitOpError()
1067  << "cannot emit lvalue type as argument type";
1068  }
1069 
1070  if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
1071  return functionOp.emitOpError() << "cannot emit array type as result type";
1072  }
1073 
1074  CppEmitter::Scope scope(emitter);
1075  raw_indented_ostream &os = emitter.ostream();
1076  if (failed(emitter.emitTypes(functionOp.getLoc(),
1077  functionOp.getFunctionType().getResults())))
1078  return failure();
1079  os << " " << functionOp.getName();
1080 
1081  os << "(";
1082  Operation *operation = functionOp.getOperation();
1083  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1084  return failure();
1085  os << ") {\n";
1086  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1087  return failure();
1088  os << "}\n";
1089 
1090  return success();
1091 }
1092 
1093 static LogicalResult printOperation(CppEmitter &emitter,
1094  emitc::FuncOp functionOp) {
1095  // We need to declare variables at top if the function has multiple blocks.
1096  if (!emitter.shouldDeclareVariablesAtTop() &&
1097  functionOp.getBlocks().size() > 1) {
1098  return functionOp.emitOpError(
1099  "with multiple blocks needs variables declared at top");
1100  }
1101 
1102  CppEmitter::Scope scope(emitter);
1103  raw_indented_ostream &os = emitter.ostream();
1104  if (functionOp.getSpecifiers()) {
1105  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1106  os << cast<StringAttr>(specifier).str() << " ";
1107  }
1108  }
1109 
1110  if (failed(emitter.emitTypes(functionOp.getLoc(),
1111  functionOp.getFunctionType().getResults())))
1112  return failure();
1113  os << " " << functionOp.getName();
1114 
1115  os << "(";
1116  Operation *operation = functionOp.getOperation();
1117  if (functionOp.isExternal()) {
1118  if (failed(printFunctionArgs(emitter, operation,
1119  functionOp.getArgumentTypes())))
1120  return failure();
1121  os << ");";
1122  return success();
1123  }
1124  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1125  return failure();
1126  os << ") {\n";
1127  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1128  return failure();
1129  os << "}\n";
1130 
1131  return success();
1132 }
1133 
1134 static LogicalResult printOperation(CppEmitter &emitter,
1135  DeclareFuncOp declareFuncOp) {
1136  CppEmitter::Scope scope(emitter);
1137  raw_indented_ostream &os = emitter.ostream();
1138 
1139  auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
1140  declareFuncOp, declareFuncOp.getSymNameAttr());
1141 
1142  if (!functionOp)
1143  return failure();
1144 
1145  if (functionOp.getSpecifiers()) {
1146  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1147  os << cast<StringAttr>(specifier).str() << " ";
1148  }
1149  }
1150 
1151  if (failed(emitter.emitTypes(functionOp.getLoc(),
1152  functionOp.getFunctionType().getResults())))
1153  return failure();
1154  os << " " << functionOp.getName();
1155 
1156  os << "(";
1157  Operation *operation = functionOp.getOperation();
1158  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1159  return failure();
1160  os << ");";
1161 
1162  return success();
1163 }
1164 
1165 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
1166  : os(os), declareVariablesAtTop(declareVariablesAtTop) {
1167  valueInScopeCount.push(0);
1168  labelInScopeCount.push(0);
1169 }
1170 
1171 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1172  std::string out;
1173  llvm::raw_string_ostream ss(out);
1174  ss << getOrCreateName(op.getValue());
1175  for (auto index : op.getIndices()) {
1176  ss << "[" << getOrCreateName(index) << "]";
1177  }
1178  return out;
1179 }
1180 
1181 std::string CppEmitter::createMemberAccess(emitc::MemberOp op) {
1182  std::string out;
1183  llvm::raw_string_ostream ss(out);
1184  ss << getOrCreateName(op.getOperand());
1185  ss << "." << op.getMember();
1186  return out;
1187 }
1188 
1189 std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) {
1190  std::string out;
1191  llvm::raw_string_ostream ss(out);
1192  ss << getOrCreateName(op.getOperand());
1193  ss << "->" << op.getMember();
1194  return out;
1195 }
1196 
1197 void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
1198  if (!valueMapper.count(value))
1199  valueMapper.insert(value, str.str());
1200 }
1201 
1202 /// Return the existing or a new name for a Value.
1203 StringRef CppEmitter::getOrCreateName(Value val) {
1204  if (!valueMapper.count(val)) {
1205  assert(!hasDeferredEmission(val.getDefiningOp()) &&
1206  "cacheDeferredOpResult should have been called on this value, "
1207  "update the emitOperation function.");
1208  valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1209  }
1210  return *valueMapper.begin(val);
1211 }
1212 
1213 /// Return the existing or a new label for a Block.
1214 StringRef CppEmitter::getOrCreateName(Block &block) {
1215  if (!blockMapper.count(&block))
1216  blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
1217  return *blockMapper.begin(&block);
1218 }
1219 
1220 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
1221  switch (val) {
1222  case IntegerType::Signless:
1223  return false;
1224  case IntegerType::Signed:
1225  return false;
1226  case IntegerType::Unsigned:
1227  return true;
1228  }
1229  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
1230 }
1231 
1232 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
1233 
1234 bool CppEmitter::hasBlockLabel(Block &block) {
1235  return blockMapper.count(&block);
1236 }
1237 
1238 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
1239  auto printInt = [&](const APInt &val, bool isUnsigned) {
1240  if (val.getBitWidth() == 1) {
1241  if (val.getBoolValue())
1242  os << "true";
1243  else
1244  os << "false";
1245  } else {
1246  SmallString<128> strValue;
1247  val.toString(strValue, 10, !isUnsigned, false);
1248  os << strValue;
1249  }
1250  };
1251 
1252  auto printFloat = [&](const APFloat &val) {
1253  if (val.isFinite()) {
1254  SmallString<128> strValue;
1255  // Use default values of toString except don't truncate zeros.
1256  val.toString(strValue, 0, 0, false);
1257  os << strValue;
1258  switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
1259  case llvm::APFloatBase::S_IEEEhalf:
1260  os << "f16";
1261  break;
1262  case llvm::APFloatBase::S_BFloat:
1263  os << "bf16";
1264  break;
1265  case llvm::APFloatBase::S_IEEEsingle:
1266  os << "f";
1267  break;
1268  case llvm::APFloatBase::S_IEEEdouble:
1269  break;
1270  default:
1271  llvm_unreachable("unsupported floating point type");
1272  };
1273  } else if (val.isNaN()) {
1274  os << "NAN";
1275  } else if (val.isInfinity()) {
1276  if (val.isNegative())
1277  os << "-";
1278  os << "INFINITY";
1279  }
1280  };
1281 
1282  // Print floating point attributes.
1283  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1284  if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1285  fAttr.getType())) {
1286  return emitError(
1287  loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1288  }
1289  printFloat(fAttr.getValue());
1290  return success();
1291  }
1292  if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1293  if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1294  dense.getElementType())) {
1295  return emitError(
1296  loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1297  }
1298  os << '{';
1299  interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
1300  os << '}';
1301  return success();
1302  }
1303 
1304  // Print integer attributes.
1305  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
1306  if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
1307  printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
1308  return success();
1309  }
1310  if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
1311  printInt(iAttr.getValue(), false);
1312  return success();
1313  }
1314  }
1315  if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1316  if (auto iType = dyn_cast<IntegerType>(
1317  cast<TensorType>(dense.getType()).getElementType())) {
1318  os << '{';
1319  interleaveComma(dense, os, [&](const APInt &val) {
1320  printInt(val, shouldMapToUnsigned(iType.getSignedness()));
1321  });
1322  os << '}';
1323  return success();
1324  }
1325  if (auto iType = dyn_cast<IndexType>(
1326  cast<TensorType>(dense.getType()).getElementType())) {
1327  os << '{';
1328  interleaveComma(dense, os,
1329  [&](const APInt &val) { printInt(val, false); });
1330  os << '}';
1331  return success();
1332  }
1333  }
1334 
1335  // Print opaque attributes.
1336  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
1337  os << oAttr.getValue();
1338  return success();
1339  }
1340 
1341  // Print symbolic reference attributes.
1342  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
1343  if (sAttr.getNestedReferences().size() > 1)
1344  return emitError(loc, "attribute has more than 1 nested reference");
1345  os << sAttr.getRootReference().getValue();
1346  return success();
1347  }
1348 
1349  // Print type attributes.
1350  if (auto type = dyn_cast<TypeAttr>(attr))
1351  return emitType(loc, type.getValue());
1352 
1353  return emitError(loc, "cannot emit attribute: ") << attr;
1354 }
1355 
1356 LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
1357  assert(emittedExpressionPrecedence.empty() &&
1358  "Expected precedence stack to be empty");
1359  Operation *rootOp = expressionOp.getRootOp();
1360 
1361  emittedExpression = expressionOp;
1362  FailureOr<int> precedence = getOperatorPrecedence(rootOp);
1363  if (failed(precedence))
1364  return failure();
1365  pushExpressionPrecedence(precedence.value());
1366 
1367  if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false)))
1368  return failure();
1369 
1370  popExpressionPrecedence();
1371  assert(emittedExpressionPrecedence.empty() &&
1372  "Expected precedence stack to be empty");
1373  emittedExpression = nullptr;
1374 
1375  return success();
1376 }
1377 
1378 LogicalResult CppEmitter::emitOperand(Value value) {
1379  if (isPartOfCurrentExpression(value)) {
1380  Operation *def = value.getDefiningOp();
1381  assert(def && "Expected operand to be defined by an operation");
1382  FailureOr<int> precedence = getOperatorPrecedence(def);
1383  if (failed(precedence))
1384  return failure();
1385 
1386  // Sub-expressions with equal or lower precedence need to be parenthesized,
1387  // as they might be evaluated in the wrong order depending on the shape of
1388  // the expression tree.
1389  bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence();
1390  if (encloseInParenthesis)
1391  os << "(";
1392  pushExpressionPrecedence(precedence.value());
1393 
1394  if (failed(emitOperation(*def, /*trailingSemicolon=*/false)))
1395  return failure();
1396 
1397  if (encloseInParenthesis)
1398  os << ")";
1399 
1400  popExpressionPrecedence();
1401  return success();
1402  }
1403 
1404  auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
1405  if (expressionOp && shouldBeInlined(expressionOp))
1406  return emitExpression(expressionOp);
1407 
1408  os << getOrCreateName(value);
1409  return success();
1410 }
1411 
1412 LogicalResult CppEmitter::emitOperands(Operation &op) {
1413  return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) {
1414  // If an expression is being emitted, push lowest precedence as these
1415  // operands are either wrapped by parenthesis.
1416  if (getEmittedExpression())
1417  pushExpressionPrecedence(lowestPrecedence());
1418  if (failed(emitOperand(operand)))
1419  return failure();
1420  if (getEmittedExpression())
1421  popExpressionPrecedence();
1422  return success();
1423  });
1424 }
1425 
1426 LogicalResult
1427 CppEmitter::emitOperandsAndAttributes(Operation &op,
1428  ArrayRef<StringRef> exclude) {
1429  if (failed(emitOperands(op)))
1430  return failure();
1431  // Insert comma in between operands and non-filtered attributes if needed.
1432  if (op.getNumOperands() > 0) {
1433  for (NamedAttribute attr : op.getAttrs()) {
1434  if (!llvm::is_contained(exclude, attr.getName().strref())) {
1435  os << ", ";
1436  break;
1437  }
1438  }
1439  }
1440  // Emit attributes.
1441  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
1442  if (llvm::is_contained(exclude, attr.getName().strref()))
1443  return success();
1444  os << "/* " << attr.getName().getValue() << " */";
1445  if (failed(emitAttribute(op.getLoc(), attr.getValue())))
1446  return failure();
1447  return success();
1448  };
1449  return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
1450 }
1451 
1452 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
1453  if (!hasValueInScope(result)) {
1454  return result.getDefiningOp()->emitOpError(
1455  "result variable for the operation has not been declared");
1456  }
1457  os << getOrCreateName(result) << " = ";
1458  return success();
1459 }
1460 
1461 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
1462  bool trailingSemicolon) {
1463  if (hasDeferredEmission(result.getDefiningOp()))
1464  return success();
1465  if (hasValueInScope(result)) {
1466  return result.getDefiningOp()->emitError(
1467  "result variable for the operation already declared");
1468  }
1469  if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
1470  result.getType(),
1471  getOrCreateName(result))))
1472  return failure();
1473  if (trailingSemicolon)
1474  os << ";\n";
1475  return success();
1476 }
1477 
1478 LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
1479  if (op.getExternSpecifier())
1480  os << "extern ";
1481  else if (op.getStaticSpecifier())
1482  os << "static ";
1483  if (op.getConstSpecifier())
1484  os << "const ";
1485 
1486  if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
1487  op.getSymName()))) {
1488  return failure();
1489  }
1490 
1491  std::optional<Attribute> initialValue = op.getInitialValue();
1492  if (initialValue) {
1493  os << " = ";
1494  if (failed(emitAttribute(op->getLoc(), *initialValue)))
1495  return failure();
1496  }
1497 
1498  os << ";";
1499  return success();
1500 }
1501 
1502 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
1503  // If op is being emitted as part of an expression, bail out.
1504  if (getEmittedExpression())
1505  return success();
1506 
1507  switch (op.getNumResults()) {
1508  case 0:
1509  break;
1510  case 1: {
1511  OpResult result = op.getResult(0);
1512  if (shouldDeclareVariablesAtTop()) {
1513  if (failed(emitVariableAssignment(result)))
1514  return failure();
1515  } else {
1516  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
1517  return failure();
1518  os << " = ";
1519  }
1520  break;
1521  }
1522  default:
1523  if (!shouldDeclareVariablesAtTop()) {
1524  for (OpResult result : op.getResults()) {
1525  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
1526  return failure();
1527  }
1528  }
1529  os << "std::tie(";
1530  interleaveComma(op.getResults(), os,
1531  [&](Value result) { os << getOrCreateName(result); });
1532  os << ") = ";
1533  }
1534  return success();
1535 }
1536 
1537 LogicalResult CppEmitter::emitLabel(Block &block) {
1538  if (!hasBlockLabel(block))
1539  return block.getParentOp()->emitError("label for block not found");
1540  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
1541  // label instead of using `getOStream`.
1542  os.getOStream() << getOrCreateName(block) << ":\n";
1543  return success();
1544 }
1545 
1546 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1547  LogicalResult status =
1549  // Builtin ops.
1550  .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
1551  // CF ops.
1552  .Case<cf::BranchOp, cf::CondBranchOp>(
1553  [&](auto op) { return printOperation(*this, op); })
1554  // EmitC ops.
1555  .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
1556  emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
1557  emitc::BitwiseNotOp, emitc::BitwiseOrOp,
1558  emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
1559  emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1560  emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1561  emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1562  emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
1563  emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
1564  emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1565  emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
1566  emitc::VariableOp, emitc::VerbatimOp>(
1567  [&](auto op) { return printOperation(*this, op); })
1568  // Func ops.
1569  .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
1570  [&](auto op) { return printOperation(*this, op); })
1571  .Case<emitc::GetGlobalOp>([&](auto op) {
1572  cacheDeferredOpResult(op.getResult(), op.getName());
1573  return success();
1574  })
1575  .Case<emitc::LiteralOp>([&](auto op) {
1576  cacheDeferredOpResult(op.getResult(), op.getValue());
1577  return success();
1578  })
1579  .Case<emitc::MemberOp>([&](auto op) {
1580  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1581  return success();
1582  })
1583  .Case<emitc::MemberOfPtrOp>([&](auto op) {
1584  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1585  return success();
1586  })
1587  .Case<emitc::SubscriptOp>([&](auto op) {
1588  cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
1589  return success();
1590  })
1591  .Default([&](Operation *) {
1592  return op.emitOpError("unable to find printer for op");
1593  });
1594 
1595  if (failed(status))
1596  return failure();
1597 
1598  if (hasDeferredEmission(&op))
1599  return success();
1600 
1601  if (getEmittedExpression() ||
1602  (isa<emitc::ExpressionOp>(op) &&
1603  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1604  return success();
1605 
1606  // Never emit a semicolon for some operations, especially if endening with
1607  // `}`.
1608  trailingSemicolon &=
1609  !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp, emitc::IfOp,
1610  emitc::IncludeOp, emitc::SwitchOp, emitc::VerbatimOp>(op);
1611 
1612  os << (trailingSemicolon ? ";\n" : "\n");
1613 
1614  return success();
1615 }
1616 
1617 LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1618  StringRef name) {
1619  if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1620  if (failed(emitType(loc, arrType.getElementType())))
1621  return failure();
1622  os << " " << name;
1623  for (auto dim : arrType.getShape()) {
1624  os << "[" << dim << "]";
1625  }
1626  return success();
1627  }
1628  if (failed(emitType(loc, type)))
1629  return failure();
1630  os << " " << name;
1631  return success();
1632 }
1633 
1634 LogicalResult CppEmitter::emitType(Location loc, Type type) {
1635  if (auto iType = dyn_cast<IntegerType>(type)) {
1636  switch (iType.getWidth()) {
1637  case 1:
1638  return (os << "bool"), success();
1639  case 8:
1640  case 16:
1641  case 32:
1642  case 64:
1643  if (shouldMapToUnsigned(iType.getSignedness()))
1644  return (os << "uint" << iType.getWidth() << "_t"), success();
1645  else
1646  return (os << "int" << iType.getWidth() << "_t"), success();
1647  default:
1648  return emitError(loc, "cannot emit integer type ") << type;
1649  }
1650  }
1651  if (auto fType = dyn_cast<FloatType>(type)) {
1652  switch (fType.getWidth()) {
1653  case 16: {
1654  if (llvm::isa<Float16Type>(type))
1655  return (os << "_Float16"), success();
1656  else if (llvm::isa<BFloat16Type>(type))
1657  return (os << "__bf16"), success();
1658  else
1659  return emitError(loc, "cannot emit float type ") << type;
1660  }
1661  case 32:
1662  return (os << "float"), success();
1663  case 64:
1664  return (os << "double"), success();
1665  default:
1666  return emitError(loc, "cannot emit float type ") << type;
1667  }
1668  }
1669  if (auto iType = dyn_cast<IndexType>(type))
1670  return (os << "size_t"), success();
1671  if (auto sType = dyn_cast<emitc::SizeTType>(type))
1672  return (os << "size_t"), success();
1673  if (auto sType = dyn_cast<emitc::SignedSizeTType>(type))
1674  return (os << "ssize_t"), success();
1675  if (auto pType = dyn_cast<emitc::PtrDiffTType>(type))
1676  return (os << "ptrdiff_t"), success();
1677  if (auto tType = dyn_cast<TensorType>(type)) {
1678  if (!tType.hasRank())
1679  return emitError(loc, "cannot emit unranked tensor type");
1680  if (!tType.hasStaticShape())
1681  return emitError(loc, "cannot emit tensor type with non static shape");
1682  os << "Tensor<";
1683  if (isa<ArrayType>(tType.getElementType()))
1684  return emitError(loc, "cannot emit tensor of array type ") << type;
1685  if (failed(emitType(loc, tType.getElementType())))
1686  return failure();
1687  auto shape = tType.getShape();
1688  for (auto dimSize : shape) {
1689  os << ", ";
1690  os << dimSize;
1691  }
1692  os << ">";
1693  return success();
1694  }
1695  if (auto tType = dyn_cast<TupleType>(type))
1696  return emitTupleType(loc, tType.getTypes());
1697  if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1698  os << oType.getValue();
1699  return success();
1700  }
1701  if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
1702  if (failed(emitType(loc, aType.getElementType())))
1703  return failure();
1704  for (auto dim : aType.getShape())
1705  os << "[" << dim << "]";
1706  return success();
1707  }
1708  if (auto lType = dyn_cast<emitc::LValueType>(type))
1709  return emitType(loc, lType.getValueType());
1710  if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1711  if (isa<ArrayType>(pType.getPointee()))
1712  return emitError(loc, "cannot emit pointer to array type ") << type;
1713  if (failed(emitType(loc, pType.getPointee())))
1714  return failure();
1715  os << "*";
1716  return success();
1717  }
1718  return emitError(loc, "cannot emit type ") << type;
1719 }
1720 
1721 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1722  switch (types.size()) {
1723  case 0:
1724  os << "void";
1725  return success();
1726  case 1:
1727  return emitType(loc, types.front());
1728  default:
1729  return emitTupleType(loc, types);
1730  }
1731 }
1732 
1733 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1734  if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
1735  return emitError(loc, "cannot emit tuple of array type");
1736  }
1737  os << "std::tuple<";
1738  if (failed(interleaveCommaWithError(
1739  types, os, [&](Type type) { return emitType(loc, type); })))
1740  return failure();
1741  os << ">";
1742  return success();
1743 }
1744 
1745 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1746  bool declareVariablesAtTop) {
1747  CppEmitter emitter(os, declareVariablesAtTop);
1748  return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1749 }
static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, StringRef callee)
static bool shouldBeInlined(ExpressionOp expressionOp)
Determine whether expression expressionOp should be emitted inline, i.e.
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
static FailureOr< int > getOperatorPrecedence(Operation *operation)
Return the precedence of an operator as an integer, higher values imply higher precedence.
static LogicalResult printFunctionArgs(CppEmitter &emitter, Operation *functionOp, ArrayRef< Type > arguments)
static LogicalResult printFunctionBody(CppEmitter &emitter, Operation *functionOp, Region::BlockListType &blocks)
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
static LogicalResult emitSwitchCase(CppEmitter &emitter, raw_indented_ostream &os, Region &region)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator)
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
static LogicalResult printUnaryOperation(CppEmitter &emitter, Operation *operation, StringRef unaryOperator)
static bool hasDeferredEmission(Operation *op)
Determine whether expression op should be emitted in a deferred way.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
Block * getSuccessor(unsigned i)
Definition: Block.cpp:261
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class provides iteration over the held operations of blocks directly within a region.
Definition: Region.h:134
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
llvm::iplist< Block > BlockListType
Definition: Region.h:44
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
OpIterator op_end()
Definition: Region.h:171
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
raw_ostream subclass that simplifies indention a sequence of code.
raw_indented_ostream & unindent()
Decreases the indent and returning this raw_indented_ostream.
raw_indented_ostream & indent()
Increases the indent and returning this raw_indented_ostream.
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false)
Translates the given operation to C++ code.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This iterator enumerates the elements in "forward" order.
Definition: Visitors.h:65