MLIR  21.0.0git
TranslateToCpp.cpp
Go to the documentation of this file.
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/ScopedHashTable.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <stack>
28 #include <utility>
29 
30 #define DEBUG_TYPE "translate-to-cpp"
31 
32 using namespace mlir;
33 using namespace mlir::emitc;
34 using llvm::formatv;
35 
36 /// Convenience functions to produce interleaved output with functions returning
37 /// a LogicalResult. This is different than those in STLExtras as functions used
38 /// on each element doesn't return a string.
39 template <typename ForwardIterator, typename UnaryFunctor,
40  typename NullaryFunctor>
41 inline LogicalResult
43  UnaryFunctor eachFn, NullaryFunctor betweenFn) {
44  if (begin == end)
45  return success();
46  if (failed(eachFn(*begin)))
47  return failure();
48  ++begin;
49  for (; begin != end; ++begin) {
50  betweenFn();
51  if (failed(eachFn(*begin)))
52  return failure();
53  }
54  return success();
55 }
56 
57 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
58 inline LogicalResult interleaveWithError(const Container &c,
59  UnaryFunctor eachFn,
60  NullaryFunctor betweenFn) {
61  return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
62 }
63 
64 template <typename Container, typename UnaryFunctor>
65 inline LogicalResult interleaveCommaWithError(const Container &c,
66  raw_ostream &os,
67  UnaryFunctor eachFn) {
68  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
69 }
70 
71 /// Return the precedence of a operator as an integer, higher values
72 /// imply higher precedence.
73 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
75  .Case<emitc::AddOp>([&](auto op) { return 12; })
76  .Case<emitc::ApplyOp>([&](auto op) { return 15; })
77  .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
78  .Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; })
79  .Case<emitc::BitwiseNotOp>([&](auto op) { return 15; })
80  .Case<emitc::BitwiseOrOp>([&](auto op) { return 5; })
81  .Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; })
82  .Case<emitc::BitwiseXorOp>([&](auto op) { return 6; })
83  .Case<emitc::CallOp>([&](auto op) { return 16; })
84  .Case<emitc::CallOpaqueOp>([&](auto op) { return 16; })
85  .Case<emitc::CastOp>([&](auto op) { return 15; })
86  .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
87  switch (op.getPredicate()) {
88  case emitc::CmpPredicate::eq:
89  case emitc::CmpPredicate::ne:
90  return 8;
91  case emitc::CmpPredicate::lt:
92  case emitc::CmpPredicate::le:
93  case emitc::CmpPredicate::gt:
94  case emitc::CmpPredicate::ge:
95  return 9;
96  case emitc::CmpPredicate::three_way:
97  return 10;
98  }
99  return op->emitError("unsupported cmp predicate");
100  })
101  .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
102  .Case<emitc::DivOp>([&](auto op) { return 13; })
103  .Case<emitc::LoadOp>([&](auto op) { return 16; })
104  .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
105  .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
106  .Case<emitc::LogicalOrOp>([&](auto op) { return 3; })
107  .Case<emitc::MulOp>([&](auto op) { return 13; })
108  .Case<emitc::RemOp>([&](auto op) { return 13; })
109  .Case<emitc::SubOp>([&](auto op) { return 12; })
110  .Case<emitc::UnaryMinusOp>([&](auto op) { return 15; })
111  .Case<emitc::UnaryPlusOp>([&](auto op) { return 15; })
112  .Default([](auto op) { return op->emitError("unsupported operation"); });
113 }
114 
115 namespace {
116 /// Emitter that uses dialect specific emitters to emit C++ code.
117 struct CppEmitter {
118  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
119  StringRef fileId);
120 
121  /// Emits attribute or returns failure.
122  LogicalResult emitAttribute(Location loc, Attribute attr);
123 
124  /// Emits operation 'op' with/without training semicolon or returns failure.
125  ///
126  /// For operations that should never be followed by a semicolon, like ForOp,
127  /// the `trailingSemicolon` argument is ignored and a semicolon is not
128  /// emitted.
129  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
130 
131  /// Emits type 'type' or returns failure.
132  LogicalResult emitType(Location loc, Type type);
133 
134  /// Emits array of types as a std::tuple of the emitted types.
135  /// - emits void for an empty array;
136  /// - emits the type of the only element for arrays of size one;
137  /// - emits a std::tuple otherwise;
138  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
139 
140  /// Emits array of types as a std::tuple of the emitted types independently of
141  /// the array size.
142  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
143 
144  /// Emits an assignment for a variable which has been declared previously.
145  LogicalResult emitVariableAssignment(OpResult result);
146 
147  /// Emits a variable declaration for a result of an operation.
148  LogicalResult emitVariableDeclaration(OpResult result,
149  bool trailingSemicolon);
150 
151  /// Emits a declaration of a variable with the given type and name.
152  LogicalResult emitVariableDeclaration(Location loc, Type type,
153  StringRef name);
154 
155  /// Emits the variable declaration and assignment prefix for 'op'.
156  /// - emits separate variable followed by std::tie for multi-valued operation;
157  /// - emits single type followed by variable for single result;
158  /// - emits nothing if no value produced by op;
159  /// Emits final '=' operator where a type is produced. Returns failure if
160  /// any result type could not be converted.
161  LogicalResult emitAssignPrefix(Operation &op);
162 
163  /// Emits a global variable declaration or definition.
164  LogicalResult emitGlobalVariable(GlobalOp op);
165 
166  /// Emits a label for the block.
167  LogicalResult emitLabel(Block &block);
168 
169  /// Emits the operands and atttributes of the operation. All operands are
170  /// emitted first and then all attributes in alphabetical order.
171  LogicalResult emitOperandsAndAttributes(Operation &op,
172  ArrayRef<StringRef> exclude = {});
173 
174  /// Emits the operands of the operation. All operands are emitted in order.
175  LogicalResult emitOperands(Operation &op);
176 
177  /// Emits value as an operands of an operation
178  LogicalResult emitOperand(Value value);
179 
180  /// Emit an expression as a C expression.
181  LogicalResult emitExpression(ExpressionOp expressionOp);
182 
183  /// Insert the expression representing the operation into the value cache.
184  void cacheDeferredOpResult(Value value, StringRef str);
185 
186  /// Return the existing or a new name for a Value.
187  StringRef getOrCreateName(Value val);
188 
189  // Returns the textual representation of a subscript operation.
190  std::string getSubscriptName(emitc::SubscriptOp op);
191 
192  // Returns the textual representation of a member (of object) operation.
193  std::string createMemberAccess(emitc::MemberOp op);
194 
195  // Returns the textual representation of a member of pointer operation.
196  std::string createMemberAccess(emitc::MemberOfPtrOp op);
197 
198  /// Return the existing or a new label of a Block.
199  StringRef getOrCreateName(Block &block);
200 
201  /// Whether to map an mlir integer to a unsigned integer in C++.
202  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
203 
204  /// RAII helper function to manage entering/exiting C++ scopes.
205  struct Scope {
206  Scope(CppEmitter &emitter)
207  : valueMapperScope(emitter.valueMapper),
208  blockMapperScope(emitter.blockMapper), emitter(emitter) {
209  emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
210  emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
211  }
212  ~Scope() {
213  emitter.valueInScopeCount.pop();
214  emitter.labelInScopeCount.pop();
215  }
216 
217  private:
218  llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
219  llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
220  CppEmitter &emitter;
221  };
222 
223  /// Returns wether the Value is assigned to a C++ variable in the scope.
224  bool hasValueInScope(Value val);
225 
226  // Returns whether a label is assigned to the block.
227  bool hasBlockLabel(Block &block);
228 
229  /// Returns the output stream.
230  raw_indented_ostream &ostream() { return os; };
231 
232  /// Returns if all variables for op results and basic block arguments need to
233  /// be declared at the beginning of a function.
234  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
235 
236  /// Returns whether this file op should be emitted
237  bool shouldEmitFile(FileOp file) {
238  return !fileId.empty() && file.getId() == fileId;
239  }
240 
241  /// Get expression currently being emitted.
242  ExpressionOp getEmittedExpression() { return emittedExpression; }
243 
244  /// Determine whether given value is part of the expression potentially being
245  /// emitted.
246  bool isPartOfCurrentExpression(Value value) {
247  if (!emittedExpression)
248  return false;
249  Operation *def = value.getDefiningOp();
250  if (!def)
251  return false;
252  auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
253  return operandExpression == emittedExpression;
254  };
255 
256 private:
257  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
258  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
259 
260  /// Output stream to emit to.
262 
263  /// Boolean to enforce that all variables for op results and block
264  /// arguments are declared at the beginning of the function. This also
265  /// includes results from ops located in nested regions.
266  bool declareVariablesAtTop;
267 
268  /// Only emit file ops whos id matches this value.
269  std::string fileId;
270 
271  /// Map from value to name of C++ variable that contain the name.
272  ValueMapper valueMapper;
273 
274  /// Map from block to name of C++ label.
275  BlockMapper blockMapper;
276 
277  /// The number of values in the current scope. This is used to declare the
278  /// names of values in a scope.
279  std::stack<int64_t> valueInScopeCount;
280  std::stack<int64_t> labelInScopeCount;
281 
282  /// State of the current expression being emitted.
283  ExpressionOp emittedExpression;
284  SmallVector<int> emittedExpressionPrecedence;
285 
286  void pushExpressionPrecedence(int precedence) {
287  emittedExpressionPrecedence.push_back(precedence);
288  }
289  void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
290  static int lowestPrecedence() { return 0; }
291  int getExpressionPrecedence() {
292  if (emittedExpressionPrecedence.empty())
293  return lowestPrecedence();
294  return emittedExpressionPrecedence.back();
295  }
296 };
297 } // namespace
298 
299 /// Determine whether expression \p op should be emitted in a deferred way.
300 static bool hasDeferredEmission(Operation *op) {
301  return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
302  emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
303 }
304 
305 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
306 /// as part of its user. This function recommends inlining of any expressions
307 /// that can be inlined unless it is used by another expression, under the
308 /// assumption that any expression fusion/re-materialization was taken care of
309 /// by transformations run by the backend.
310 static bool shouldBeInlined(ExpressionOp expressionOp) {
311  // Do not inline if expression is marked as such.
312  if (expressionOp.getDoNotInline())
313  return false;
314 
315  // Do not inline expressions with side effects to prevent side-effect
316  // reordering.
317  if (expressionOp.hasSideEffects())
318  return false;
319 
320  // Do not inline expressions with multiple uses.
321  Value result = expressionOp.getResult();
322  if (!result.hasOneUse())
323  return false;
324 
325  Operation *user = *result.getUsers().begin();
326 
327  // Do not inline expressions used by operations with deferred emission, since
328  // their translation requires the materialization of variables.
329  if (hasDeferredEmission(user))
330  return false;
331 
332  // Do not inline expressions used by ops with the CExpression trait. If this
333  // was intended, the user could have been merged into the expression op.
334  return !user->hasTrait<OpTrait::emitc::CExpression>();
335 }
336 
337 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
338  Attribute value) {
339  OpResult result = operation->getResult(0);
340 
341  // Only emit an assignment as the variable was already declared when printing
342  // the FuncOp.
343  if (emitter.shouldDeclareVariablesAtTop()) {
344  // Skip the assignment if the emitc.constant has no value.
345  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
346  if (oAttr.getValue().empty())
347  return success();
348  }
349 
350  if (failed(emitter.emitVariableAssignment(result)))
351  return failure();
352  return emitter.emitAttribute(operation->getLoc(), value);
353  }
354 
355  // Emit a variable declaration for an emitc.constant op without value.
356  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
357  if (oAttr.getValue().empty())
358  // The semicolon gets printed by the emitOperation function.
359  return emitter.emitVariableDeclaration(result,
360  /*trailingSemicolon=*/false);
361  }
362 
363  // Emit a variable declaration.
364  if (failed(emitter.emitAssignPrefix(*operation)))
365  return failure();
366  return emitter.emitAttribute(operation->getLoc(), value);
367 }
368 
369 static LogicalResult printOperation(CppEmitter &emitter,
370  emitc::ConstantOp constantOp) {
371  Operation *operation = constantOp.getOperation();
372  Attribute value = constantOp.getValue();
373 
374  return printConstantOp(emitter, operation, value);
375 }
376 
377 static LogicalResult printOperation(CppEmitter &emitter,
378  emitc::VariableOp variableOp) {
379  Operation *operation = variableOp.getOperation();
380  Attribute value = variableOp.getValue();
381 
382  return printConstantOp(emitter, operation, value);
383 }
384 
385 static LogicalResult printOperation(CppEmitter &emitter,
386  emitc::GlobalOp globalOp) {
387 
388  return emitter.emitGlobalVariable(globalOp);
389 }
390 
391 static LogicalResult printOperation(CppEmitter &emitter,
392  emitc::AssignOp assignOp) {
393  OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
394 
395  if (failed(emitter.emitVariableAssignment(result)))
396  return failure();
397 
398  return emitter.emitOperand(assignOp.getValue());
399 }
400 
401 static LogicalResult printOperation(CppEmitter &emitter, emitc::LoadOp loadOp) {
402  if (failed(emitter.emitAssignPrefix(*loadOp)))
403  return failure();
404 
405  return emitter.emitOperand(loadOp.getOperand());
406 }
407 
408 static LogicalResult printBinaryOperation(CppEmitter &emitter,
409  Operation *operation,
410  StringRef binaryOperator) {
411  raw_ostream &os = emitter.ostream();
412 
413  if (failed(emitter.emitAssignPrefix(*operation)))
414  return failure();
415 
416  if (failed(emitter.emitOperand(operation->getOperand(0))))
417  return failure();
418 
419  os << " " << binaryOperator << " ";
420 
421  if (failed(emitter.emitOperand(operation->getOperand(1))))
422  return failure();
423 
424  return success();
425 }
426 
427 static LogicalResult printUnaryOperation(CppEmitter &emitter,
428  Operation *operation,
429  StringRef unaryOperator) {
430  raw_ostream &os = emitter.ostream();
431 
432  if (failed(emitter.emitAssignPrefix(*operation)))
433  return failure();
434 
435  os << unaryOperator;
436 
437  if (failed(emitter.emitOperand(operation->getOperand(0))))
438  return failure();
439 
440  return success();
441 }
442 
443 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
444  Operation *operation = addOp.getOperation();
445 
446  return printBinaryOperation(emitter, operation, "+");
447 }
448 
449 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
450  Operation *operation = divOp.getOperation();
451 
452  return printBinaryOperation(emitter, operation, "/");
453 }
454 
455 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
456  Operation *operation = mulOp.getOperation();
457 
458  return printBinaryOperation(emitter, operation, "*");
459 }
460 
461 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
462  Operation *operation = remOp.getOperation();
463 
464  return printBinaryOperation(emitter, operation, "%");
465 }
466 
467 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
468  Operation *operation = subOp.getOperation();
469 
470  return printBinaryOperation(emitter, operation, "-");
471 }
472 
473 static LogicalResult emitSwitchCase(CppEmitter &emitter,
474  raw_indented_ostream &os, Region &region) {
475  for (Region::OpIterator iteratorOp = region.op_begin(), end = region.op_end();
476  std::next(iteratorOp) != end; ++iteratorOp) {
477  if (failed(emitter.emitOperation(*iteratorOp, /*trailingSemicolon=*/true)))
478  return failure();
479  }
480  os << "break;\n";
481  return success();
482 }
483 
484 static LogicalResult printOperation(CppEmitter &emitter,
485  emitc::SwitchOp switchOp) {
486  raw_indented_ostream &os = emitter.ostream();
487 
488  os << "switch (";
489  if (failed(emitter.emitOperand(switchOp.getArg())))
490  return failure();
491  os << ") {";
492 
493  for (auto pair : llvm::zip(switchOp.getCases(), switchOp.getCaseRegions())) {
494  os << "\ncase " << std::get<0>(pair) << ": {\n";
495  os.indent();
496 
497  if (failed(emitSwitchCase(emitter, os, std::get<1>(pair))))
498  return failure();
499 
500  os.unindent() << "}";
501  }
502 
503  os << "\ndefault: {\n";
504  os.indent();
505 
506  if (failed(emitSwitchCase(emitter, os, switchOp.getDefaultRegion())))
507  return failure();
508 
509  os.unindent() << "}\n}";
510  return success();
511 }
512 
513 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
514  Operation *operation = cmpOp.getOperation();
515 
516  StringRef binaryOperator;
517 
518  switch (cmpOp.getPredicate()) {
519  case emitc::CmpPredicate::eq:
520  binaryOperator = "==";
521  break;
522  case emitc::CmpPredicate::ne:
523  binaryOperator = "!=";
524  break;
525  case emitc::CmpPredicate::lt:
526  binaryOperator = "<";
527  break;
528  case emitc::CmpPredicate::le:
529  binaryOperator = "<=";
530  break;
531  case emitc::CmpPredicate::gt:
532  binaryOperator = ">";
533  break;
534  case emitc::CmpPredicate::ge:
535  binaryOperator = ">=";
536  break;
537  case emitc::CmpPredicate::three_way:
538  binaryOperator = "<=>";
539  break;
540  }
541 
542  return printBinaryOperation(emitter, operation, binaryOperator);
543 }
544 
545 static LogicalResult printOperation(CppEmitter &emitter,
546  emitc::ConditionalOp conditionalOp) {
547  raw_ostream &os = emitter.ostream();
548 
549  if (failed(emitter.emitAssignPrefix(*conditionalOp)))
550  return failure();
551 
552  if (failed(emitter.emitOperand(conditionalOp.getCondition())))
553  return failure();
554 
555  os << " ? ";
556 
557  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
558  return failure();
559 
560  os << " : ";
561 
562  if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
563  return failure();
564 
565  return success();
566 }
567 
568 static LogicalResult printOperation(CppEmitter &emitter,
569  emitc::VerbatimOp verbatimOp) {
570  raw_ostream &os = emitter.ostream();
571 
572  FailureOr<SmallVector<ReplacementItem>> items =
573  verbatimOp.parseFormatString();
574  if (failed(items))
575  return failure();
576 
577  auto fmtArg = verbatimOp.getFmtArgs().begin();
578 
579  for (ReplacementItem &item : *items) {
580  if (auto *str = std::get_if<StringRef>(&item)) {
581  os << *str;
582  } else {
583  if (failed(emitter.emitOperand(*fmtArg++)))
584  return failure();
585  }
586  }
587 
588  return success();
589 }
590 
591 static LogicalResult printOperation(CppEmitter &emitter,
592  cf::BranchOp branchOp) {
593  raw_ostream &os = emitter.ostream();
594  Block &successor = *branchOp.getSuccessor();
595 
596  for (auto pair :
597  llvm::zip(branchOp.getOperands(), successor.getArguments())) {
598  Value &operand = std::get<0>(pair);
599  BlockArgument &argument = std::get<1>(pair);
600  os << emitter.getOrCreateName(argument) << " = "
601  << emitter.getOrCreateName(operand) << ";\n";
602  }
603 
604  os << "goto ";
605  if (!(emitter.hasBlockLabel(successor)))
606  return branchOp.emitOpError("unable to find label for successor block");
607  os << emitter.getOrCreateName(successor);
608  return success();
609 }
610 
611 static LogicalResult printOperation(CppEmitter &emitter,
612  cf::CondBranchOp condBranchOp) {
613  raw_indented_ostream &os = emitter.ostream();
614  Block &trueSuccessor = *condBranchOp.getTrueDest();
615  Block &falseSuccessor = *condBranchOp.getFalseDest();
616 
617  os << "if (";
618  if (failed(emitter.emitOperand(condBranchOp.getCondition())))
619  return failure();
620  os << ") {\n";
621 
622  os.indent();
623 
624  // If condition is true.
625  for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
626  trueSuccessor.getArguments())) {
627  Value &operand = std::get<0>(pair);
628  BlockArgument &argument = std::get<1>(pair);
629  os << emitter.getOrCreateName(argument) << " = "
630  << emitter.getOrCreateName(operand) << ";\n";
631  }
632 
633  os << "goto ";
634  if (!(emitter.hasBlockLabel(trueSuccessor))) {
635  return condBranchOp.emitOpError("unable to find label for successor block");
636  }
637  os << emitter.getOrCreateName(trueSuccessor) << ";\n";
638  os.unindent() << "} else {\n";
639  os.indent();
640  // If condition is false.
641  for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
642  falseSuccessor.getArguments())) {
643  Value &operand = std::get<0>(pair);
644  BlockArgument &argument = std::get<1>(pair);
645  os << emitter.getOrCreateName(argument) << " = "
646  << emitter.getOrCreateName(operand) << ";\n";
647  }
648 
649  os << "goto ";
650  if (!(emitter.hasBlockLabel(falseSuccessor))) {
651  return condBranchOp.emitOpError()
652  << "unable to find label for successor block";
653  }
654  os << emitter.getOrCreateName(falseSuccessor) << ";\n";
655  os.unindent() << "}";
656  return success();
657 }
658 
659 static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
660  StringRef callee) {
661  if (failed(emitter.emitAssignPrefix(*callOp)))
662  return failure();
663 
664  raw_ostream &os = emitter.ostream();
665  os << callee << "(";
666  if (failed(emitter.emitOperands(*callOp)))
667  return failure();
668  os << ")";
669  return success();
670 }
671 
672 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
673  Operation *operation = callOp.getOperation();
674  StringRef callee = callOp.getCallee();
675 
676  return printCallOperation(emitter, operation, callee);
677 }
678 
679 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
680  Operation *operation = callOp.getOperation();
681  StringRef callee = callOp.getCallee();
682 
683  return printCallOperation(emitter, operation, callee);
684 }
685 
686 static LogicalResult printOperation(CppEmitter &emitter,
687  emitc::CallOpaqueOp callOpaqueOp) {
688  raw_ostream &os = emitter.ostream();
689  Operation &op = *callOpaqueOp.getOperation();
690 
691  if (failed(emitter.emitAssignPrefix(op)))
692  return failure();
693  os << callOpaqueOp.getCallee();
694 
695  auto emitArgs = [&](Attribute attr) -> LogicalResult {
696  if (auto t = dyn_cast<IntegerAttr>(attr)) {
697  // Index attributes are treated specially as operand index.
698  if (t.getType().isIndex()) {
699  int64_t idx = t.getInt();
700  Value operand = op.getOperand(idx);
701  if (!emitter.hasValueInScope(operand))
702  return op.emitOpError("operand ")
703  << idx << "'s value not defined in scope";
704  os << emitter.getOrCreateName(operand);
705  return success();
706  }
707  }
708  if (failed(emitter.emitAttribute(op.getLoc(), attr)))
709  return failure();
710 
711  return success();
712  };
713 
714  if (callOpaqueOp.getTemplateArgs()) {
715  os << "<";
716  if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
717  emitArgs)))
718  return failure();
719  os << ">";
720  }
721 
722  os << "(";
723 
724  LogicalResult emittedArgs =
725  callOpaqueOp.getArgs()
726  ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
727  : emitter.emitOperands(op);
728  if (failed(emittedArgs))
729  return failure();
730  os << ")";
731  return success();
732 }
733 
734 static LogicalResult printOperation(CppEmitter &emitter,
735  emitc::ApplyOp applyOp) {
736  raw_ostream &os = emitter.ostream();
737  Operation &op = *applyOp.getOperation();
738 
739  if (failed(emitter.emitAssignPrefix(op)))
740  return failure();
741  os << applyOp.getApplicableOperator();
742  os << emitter.getOrCreateName(applyOp.getOperand());
743 
744  return success();
745 }
746 
747 static LogicalResult printOperation(CppEmitter &emitter,
748  emitc::BitwiseAndOp bitwiseAndOp) {
749  Operation *operation = bitwiseAndOp.getOperation();
750  return printBinaryOperation(emitter, operation, "&");
751 }
752 
753 static LogicalResult
754 printOperation(CppEmitter &emitter,
755  emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) {
756  Operation *operation = bitwiseLeftShiftOp.getOperation();
757  return printBinaryOperation(emitter, operation, "<<");
758 }
759 
760 static LogicalResult printOperation(CppEmitter &emitter,
761  emitc::BitwiseNotOp bitwiseNotOp) {
762  Operation *operation = bitwiseNotOp.getOperation();
763  return printUnaryOperation(emitter, operation, "~");
764 }
765 
766 static LogicalResult printOperation(CppEmitter &emitter,
767  emitc::BitwiseOrOp bitwiseOrOp) {
768  Operation *operation = bitwiseOrOp.getOperation();
769  return printBinaryOperation(emitter, operation, "|");
770 }
771 
772 static LogicalResult
773 printOperation(CppEmitter &emitter,
774  emitc::BitwiseRightShiftOp bitwiseRightShiftOp) {
775  Operation *operation = bitwiseRightShiftOp.getOperation();
776  return printBinaryOperation(emitter, operation, ">>");
777 }
778 
779 static LogicalResult printOperation(CppEmitter &emitter,
780  emitc::BitwiseXorOp bitwiseXorOp) {
781  Operation *operation = bitwiseXorOp.getOperation();
782  return printBinaryOperation(emitter, operation, "^");
783 }
784 
785 static LogicalResult printOperation(CppEmitter &emitter,
786  emitc::UnaryPlusOp unaryPlusOp) {
787  Operation *operation = unaryPlusOp.getOperation();
788  return printUnaryOperation(emitter, operation, "+");
789 }
790 
791 static LogicalResult printOperation(CppEmitter &emitter,
792  emitc::UnaryMinusOp unaryMinusOp) {
793  Operation *operation = unaryMinusOp.getOperation();
794  return printUnaryOperation(emitter, operation, "-");
795 }
796 
797 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
798  raw_ostream &os = emitter.ostream();
799  Operation &op = *castOp.getOperation();
800 
801  if (failed(emitter.emitAssignPrefix(op)))
802  return failure();
803  os << "(";
804  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
805  return failure();
806  os << ") ";
807  return emitter.emitOperand(castOp.getOperand());
808 }
809 
810 static LogicalResult printOperation(CppEmitter &emitter,
811  emitc::ExpressionOp expressionOp) {
812  if (shouldBeInlined(expressionOp))
813  return success();
814 
815  Operation &op = *expressionOp.getOperation();
816 
817  if (failed(emitter.emitAssignPrefix(op)))
818  return failure();
819 
820  return emitter.emitExpression(expressionOp);
821 }
822 
823 static LogicalResult printOperation(CppEmitter &emitter,
824  emitc::IncludeOp includeOp) {
825  raw_ostream &os = emitter.ostream();
826 
827  os << "#include ";
828  if (includeOp.getIsStandardInclude())
829  os << "<" << includeOp.getInclude() << ">";
830  else
831  os << "\"" << includeOp.getInclude() << "\"";
832 
833  return success();
834 }
835 
836 static LogicalResult printOperation(CppEmitter &emitter,
837  emitc::LogicalAndOp logicalAndOp) {
838  Operation *operation = logicalAndOp.getOperation();
839  return printBinaryOperation(emitter, operation, "&&");
840 }
841 
842 static LogicalResult printOperation(CppEmitter &emitter,
843  emitc::LogicalNotOp logicalNotOp) {
844  Operation *operation = logicalNotOp.getOperation();
845  return printUnaryOperation(emitter, operation, "!");
846 }
847 
848 static LogicalResult printOperation(CppEmitter &emitter,
849  emitc::LogicalOrOp logicalOrOp) {
850  Operation *operation = logicalOrOp.getOperation();
851  return printBinaryOperation(emitter, operation, "||");
852 }
853 
854 static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
855 
856  raw_indented_ostream &os = emitter.ostream();
857 
858  // Utility function to determine whether a value is an expression that will be
859  // inlined, and as such should be wrapped in parentheses in order to guarantee
860  // its precedence and associativity.
861  auto requiresParentheses = [&](Value value) {
862  auto expressionOp =
863  dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
864  if (!expressionOp)
865  return false;
866  return shouldBeInlined(expressionOp);
867  };
868 
869  os << "for (";
870  if (failed(
871  emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
872  return failure();
873  os << " ";
874  os << emitter.getOrCreateName(forOp.getInductionVar());
875  os << " = ";
876  if (failed(emitter.emitOperand(forOp.getLowerBound())))
877  return failure();
878  os << "; ";
879  os << emitter.getOrCreateName(forOp.getInductionVar());
880  os << " < ";
881  Value upperBound = forOp.getUpperBound();
882  bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
883  if (upperBoundRequiresParentheses)
884  os << "(";
885  if (failed(emitter.emitOperand(upperBound)))
886  return failure();
887  if (upperBoundRequiresParentheses)
888  os << ")";
889  os << "; ";
890  os << emitter.getOrCreateName(forOp.getInductionVar());
891  os << " += ";
892  if (failed(emitter.emitOperand(forOp.getStep())))
893  return failure();
894  os << ") {\n";
895  os.indent();
896 
897  Region &forRegion = forOp.getRegion();
898  auto regionOps = forRegion.getOps();
899 
900  // We skip the trailing yield op.
901  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
902  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
903  return failure();
904  }
905 
906  os.unindent() << "}";
907 
908  return success();
909 }
910 
911 static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
912  raw_indented_ostream &os = emitter.ostream();
913 
914  // Helper function to emit all ops except the last one, expected to be
915  // emitc::yield.
916  auto emitAllExceptLast = [&emitter](Region &region) {
917  Region::OpIterator it = region.op_begin(), end = region.op_end();
918  for (; std::next(it) != end; ++it) {
919  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
920  return failure();
921  }
922  assert(isa<emitc::YieldOp>(*it) &&
923  "Expected last operation in the region to be emitc::yield");
924  return success();
925  };
926 
927  os << "if (";
928  if (failed(emitter.emitOperand(ifOp.getCondition())))
929  return failure();
930  os << ") {\n";
931  os.indent();
932  if (failed(emitAllExceptLast(ifOp.getThenRegion())))
933  return failure();
934  os.unindent() << "}";
935 
936  Region &elseRegion = ifOp.getElseRegion();
937  if (!elseRegion.empty()) {
938  os << " else {\n";
939  os.indent();
940  if (failed(emitAllExceptLast(elseRegion)))
941  return failure();
942  os.unindent() << "}";
943  }
944 
945  return success();
946 }
947 
948 static LogicalResult printOperation(CppEmitter &emitter,
949  func::ReturnOp returnOp) {
950  raw_ostream &os = emitter.ostream();
951  os << "return";
952  switch (returnOp.getNumOperands()) {
953  case 0:
954  return success();
955  case 1:
956  os << " ";
957  if (failed(emitter.emitOperand(returnOp.getOperand(0))))
958  return failure();
959  return success();
960  default:
961  os << " std::make_tuple(";
962  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
963  return failure();
964  os << ")";
965  return success();
966  }
967 }
968 
969 static LogicalResult printOperation(CppEmitter &emitter,
970  emitc::ReturnOp returnOp) {
971  raw_ostream &os = emitter.ostream();
972  os << "return";
973  if (returnOp.getNumOperands() == 0)
974  return success();
975 
976  os << " ";
977  if (failed(emitter.emitOperand(returnOp.getOperand())))
978  return failure();
979  return success();
980 }
981 
982 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
983  CppEmitter::Scope scope(emitter);
984 
985  for (Operation &op : moduleOp) {
986  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
987  return failure();
988  }
989  return success();
990 }
991 
992 static LogicalResult printOperation(CppEmitter &emitter, FileOp file) {
993  if (!emitter.shouldEmitFile(file))
994  return success();
995 
996  CppEmitter::Scope scope(emitter);
997 
998  for (Operation &op : file) {
999  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
1000  return failure();
1001  }
1002  return success();
1003 }
1004 
1005 static LogicalResult printFunctionArgs(CppEmitter &emitter,
1006  Operation *functionOp,
1007  ArrayRef<Type> arguments) {
1008  raw_indented_ostream &os = emitter.ostream();
1009 
1010  return (
1011  interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult {
1012  return emitter.emitType(functionOp->getLoc(), arg);
1013  }));
1014 }
1015 
1016 static LogicalResult printFunctionArgs(CppEmitter &emitter,
1017  Operation *functionOp,
1018  Region::BlockArgListType arguments) {
1019  raw_indented_ostream &os = emitter.ostream();
1020 
1021  return (interleaveCommaWithError(
1022  arguments, os, [&](BlockArgument arg) -> LogicalResult {
1023  return emitter.emitVariableDeclaration(
1024  functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
1025  }));
1026 }
1027 
1028 static LogicalResult printFunctionBody(CppEmitter &emitter,
1029  Operation *functionOp,
1030  Region::BlockListType &blocks) {
1031  raw_indented_ostream &os = emitter.ostream();
1032  os.indent();
1033 
1034  if (emitter.shouldDeclareVariablesAtTop()) {
1035  // Declare all variables that hold op results including those from nested
1036  // regions.
1037  WalkResult result =
1038  functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
1039  if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
1040  (isa<emitc::ExpressionOp>(op) &&
1041  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1042  return WalkResult::skip();
1043  for (OpResult result : op->getResults()) {
1044  if (failed(emitter.emitVariableDeclaration(
1045  result, /*trailingSemicolon=*/true))) {
1046  return WalkResult(
1047  op->emitError("unable to declare result variable for op"));
1048  }
1049  }
1050  return WalkResult::advance();
1051  });
1052  if (result.wasInterrupted())
1053  return failure();
1054  }
1055 
1056  // Create label names for basic blocks.
1057  for (Block &block : blocks) {
1058  emitter.getOrCreateName(block);
1059  }
1060 
1061  // Declare variables for basic block arguments.
1062  for (Block &block : llvm::drop_begin(blocks)) {
1063  for (BlockArgument &arg : block.getArguments()) {
1064  if (emitter.hasValueInScope(arg))
1065  return functionOp->emitOpError(" block argument #")
1066  << arg.getArgNumber() << " is out of scope";
1067  if (isa<ArrayType, LValueType>(arg.getType()))
1068  return functionOp->emitOpError("cannot emit block argument #")
1069  << arg.getArgNumber() << " with type " << arg.getType();
1070  if (failed(
1071  emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
1072  return failure();
1073  }
1074  os << " " << emitter.getOrCreateName(arg) << ";\n";
1075  }
1076  }
1077 
1078  for (Block &block : blocks) {
1079  // Only print a label if the block has predecessors.
1080  if (!block.hasNoPredecessors()) {
1081  if (failed(emitter.emitLabel(block)))
1082  return failure();
1083  }
1084  for (Operation &op : block.getOperations()) {
1085  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
1086  return failure();
1087  }
1088  }
1089 
1090  os.unindent();
1091 
1092  return success();
1093 }
1094 
1095 static LogicalResult printOperation(CppEmitter &emitter,
1096  func::FuncOp functionOp) {
1097  // We need to declare variables at top if the function has multiple blocks.
1098  if (!emitter.shouldDeclareVariablesAtTop() &&
1099  functionOp.getBlocks().size() > 1) {
1100  return functionOp.emitOpError(
1101  "with multiple blocks needs variables declared at top");
1102  }
1103 
1104  if (llvm::any_of(functionOp.getArgumentTypes(), llvm::IsaPred<LValueType>)) {
1105  return functionOp.emitOpError()
1106  << "cannot emit lvalue type as argument type";
1107  }
1108 
1109  if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
1110  return functionOp.emitOpError() << "cannot emit array type as result type";
1111  }
1112 
1113  CppEmitter::Scope scope(emitter);
1114  raw_indented_ostream &os = emitter.ostream();
1115  if (failed(emitter.emitTypes(functionOp.getLoc(),
1116  functionOp.getFunctionType().getResults())))
1117  return failure();
1118  os << " " << functionOp.getName();
1119 
1120  os << "(";
1121  Operation *operation = functionOp.getOperation();
1122  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1123  return failure();
1124  os << ") {\n";
1125  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1126  return failure();
1127  os << "}\n";
1128 
1129  return success();
1130 }
1131 
1132 static LogicalResult printOperation(CppEmitter &emitter,
1133  emitc::FuncOp functionOp) {
1134  // We need to declare variables at top if the function has multiple blocks.
1135  if (!emitter.shouldDeclareVariablesAtTop() &&
1136  functionOp.getBlocks().size() > 1) {
1137  return functionOp.emitOpError(
1138  "with multiple blocks needs variables declared at top");
1139  }
1140 
1141  CppEmitter::Scope scope(emitter);
1142  raw_indented_ostream &os = emitter.ostream();
1143  if (functionOp.getSpecifiers()) {
1144  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1145  os << cast<StringAttr>(specifier).str() << " ";
1146  }
1147  }
1148 
1149  if (failed(emitter.emitTypes(functionOp.getLoc(),
1150  functionOp.getFunctionType().getResults())))
1151  return failure();
1152  os << " " << functionOp.getName();
1153 
1154  os << "(";
1155  Operation *operation = functionOp.getOperation();
1156  if (functionOp.isExternal()) {
1157  if (failed(printFunctionArgs(emitter, operation,
1158  functionOp.getArgumentTypes())))
1159  return failure();
1160  os << ");";
1161  return success();
1162  }
1163  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1164  return failure();
1165  os << ") {\n";
1166  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1167  return failure();
1168  os << "}\n";
1169 
1170  return success();
1171 }
1172 
1173 static LogicalResult printOperation(CppEmitter &emitter,
1174  DeclareFuncOp declareFuncOp) {
1175  CppEmitter::Scope scope(emitter);
1176  raw_indented_ostream &os = emitter.ostream();
1177 
1178  auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
1179  declareFuncOp, declareFuncOp.getSymNameAttr());
1180 
1181  if (!functionOp)
1182  return failure();
1183 
1184  if (functionOp.getSpecifiers()) {
1185  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1186  os << cast<StringAttr>(specifier).str() << " ";
1187  }
1188  }
1189 
1190  if (failed(emitter.emitTypes(functionOp.getLoc(),
1191  functionOp.getFunctionType().getResults())))
1192  return failure();
1193  os << " " << functionOp.getName();
1194 
1195  os << "(";
1196  Operation *operation = functionOp.getOperation();
1197  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1198  return failure();
1199  os << ");";
1200 
1201  return success();
1202 }
1203 
1204 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
1205  StringRef fileId)
1206  : os(os), declareVariablesAtTop(declareVariablesAtTop),
1207  fileId(fileId.str()) {
1208  valueInScopeCount.push(0);
1209  labelInScopeCount.push(0);
1210 }
1211 
1212 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1213  std::string out;
1214  llvm::raw_string_ostream ss(out);
1215  ss << getOrCreateName(op.getValue());
1216  for (auto index : op.getIndices()) {
1217  ss << "[" << getOrCreateName(index) << "]";
1218  }
1219  return out;
1220 }
1221 
1222 std::string CppEmitter::createMemberAccess(emitc::MemberOp op) {
1223  std::string out;
1224  llvm::raw_string_ostream ss(out);
1225  ss << getOrCreateName(op.getOperand());
1226  ss << "." << op.getMember();
1227  return out;
1228 }
1229 
1230 std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) {
1231  std::string out;
1232  llvm::raw_string_ostream ss(out);
1233  ss << getOrCreateName(op.getOperand());
1234  ss << "->" << op.getMember();
1235  return out;
1236 }
1237 
1238 void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
1239  if (!valueMapper.count(value))
1240  valueMapper.insert(value, str.str());
1241 }
1242 
1243 /// Return the existing or a new name for a Value.
1244 StringRef CppEmitter::getOrCreateName(Value val) {
1245  if (!valueMapper.count(val)) {
1246  assert(!hasDeferredEmission(val.getDefiningOp()) &&
1247  "cacheDeferredOpResult should have been called on this value, "
1248  "update the emitOperation function.");
1249  valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1250  }
1251  return *valueMapper.begin(val);
1252 }
1253 
1254 /// Return the existing or a new label for a Block.
1255 StringRef CppEmitter::getOrCreateName(Block &block) {
1256  if (!blockMapper.count(&block))
1257  blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
1258  return *blockMapper.begin(&block);
1259 }
1260 
1261 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
1262  switch (val) {
1263  case IntegerType::Signless:
1264  return false;
1265  case IntegerType::Signed:
1266  return false;
1267  case IntegerType::Unsigned:
1268  return true;
1269  }
1270  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
1271 }
1272 
1273 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
1274 
1275 bool CppEmitter::hasBlockLabel(Block &block) {
1276  return blockMapper.count(&block);
1277 }
1278 
1279 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
1280  auto printInt = [&](const APInt &val, bool isUnsigned) {
1281  if (val.getBitWidth() == 1) {
1282  if (val.getBoolValue())
1283  os << "true";
1284  else
1285  os << "false";
1286  } else {
1287  SmallString<128> strValue;
1288  val.toString(strValue, 10, !isUnsigned, false);
1289  os << strValue;
1290  }
1291  };
1292 
1293  auto printFloat = [&](const APFloat &val) {
1294  if (val.isFinite()) {
1295  SmallString<128> strValue;
1296  // Use default values of toString except don't truncate zeros.
1297  val.toString(strValue, 0, 0, false);
1298  os << strValue;
1299  switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
1300  case llvm::APFloatBase::S_IEEEhalf:
1301  os << "f16";
1302  break;
1303  case llvm::APFloatBase::S_BFloat:
1304  os << "bf16";
1305  break;
1306  case llvm::APFloatBase::S_IEEEsingle:
1307  os << "f";
1308  break;
1309  case llvm::APFloatBase::S_IEEEdouble:
1310  break;
1311  default:
1312  llvm_unreachable("unsupported floating point type");
1313  };
1314  } else if (val.isNaN()) {
1315  os << "NAN";
1316  } else if (val.isInfinity()) {
1317  if (val.isNegative())
1318  os << "-";
1319  os << "INFINITY";
1320  }
1321  };
1322 
1323  // Print floating point attributes.
1324  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1325  if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1326  fAttr.getType())) {
1327  return emitError(
1328  loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1329  }
1330  printFloat(fAttr.getValue());
1331  return success();
1332  }
1333  if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1334  if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1335  dense.getElementType())) {
1336  return emitError(
1337  loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1338  }
1339  os << '{';
1340  interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
1341  os << '}';
1342  return success();
1343  }
1344 
1345  // Print integer attributes.
1346  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
1347  if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
1348  printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
1349  return success();
1350  }
1351  if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
1352  printInt(iAttr.getValue(), false);
1353  return success();
1354  }
1355  }
1356  if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1357  if (auto iType = dyn_cast<IntegerType>(
1358  cast<TensorType>(dense.getType()).getElementType())) {
1359  os << '{';
1360  interleaveComma(dense, os, [&](const APInt &val) {
1361  printInt(val, shouldMapToUnsigned(iType.getSignedness()));
1362  });
1363  os << '}';
1364  return success();
1365  }
1366  if (auto iType = dyn_cast<IndexType>(
1367  cast<TensorType>(dense.getType()).getElementType())) {
1368  os << '{';
1369  interleaveComma(dense, os,
1370  [&](const APInt &val) { printInt(val, false); });
1371  os << '}';
1372  return success();
1373  }
1374  }
1375 
1376  // Print opaque attributes.
1377  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
1378  os << oAttr.getValue();
1379  return success();
1380  }
1381 
1382  // Print symbolic reference attributes.
1383  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
1384  if (sAttr.getNestedReferences().size() > 1)
1385  return emitError(loc, "attribute has more than 1 nested reference");
1386  os << sAttr.getRootReference().getValue();
1387  return success();
1388  }
1389 
1390  // Print type attributes.
1391  if (auto type = dyn_cast<TypeAttr>(attr))
1392  return emitType(loc, type.getValue());
1393 
1394  return emitError(loc, "cannot emit attribute: ") << attr;
1395 }
1396 
1397 LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
1398  assert(emittedExpressionPrecedence.empty() &&
1399  "Expected precedence stack to be empty");
1400  Operation *rootOp = expressionOp.getRootOp();
1401 
1402  emittedExpression = expressionOp;
1403  FailureOr<int> precedence = getOperatorPrecedence(rootOp);
1404  if (failed(precedence))
1405  return failure();
1406  pushExpressionPrecedence(precedence.value());
1407 
1408  if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false)))
1409  return failure();
1410 
1411  popExpressionPrecedence();
1412  assert(emittedExpressionPrecedence.empty() &&
1413  "Expected precedence stack to be empty");
1414  emittedExpression = nullptr;
1415 
1416  return success();
1417 }
1418 
1419 LogicalResult CppEmitter::emitOperand(Value value) {
1420  if (isPartOfCurrentExpression(value)) {
1421  Operation *def = value.getDefiningOp();
1422  assert(def && "Expected operand to be defined by an operation");
1423  FailureOr<int> precedence = getOperatorPrecedence(def);
1424  if (failed(precedence))
1425  return failure();
1426 
1427  // Sub-expressions with equal or lower precedence need to be parenthesized,
1428  // as they might be evaluated in the wrong order depending on the shape of
1429  // the expression tree.
1430  bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence();
1431  if (encloseInParenthesis)
1432  os << "(";
1433  pushExpressionPrecedence(precedence.value());
1434 
1435  if (failed(emitOperation(*def, /*trailingSemicolon=*/false)))
1436  return failure();
1437 
1438  if (encloseInParenthesis)
1439  os << ")";
1440 
1441  popExpressionPrecedence();
1442  return success();
1443  }
1444 
1445  auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
1446  if (expressionOp && shouldBeInlined(expressionOp))
1447  return emitExpression(expressionOp);
1448 
1449  os << getOrCreateName(value);
1450  return success();
1451 }
1452 
1453 LogicalResult CppEmitter::emitOperands(Operation &op) {
1454  return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) {
1455  // If an expression is being emitted, push lowest precedence as these
1456  // operands are either wrapped by parenthesis.
1457  if (getEmittedExpression())
1458  pushExpressionPrecedence(lowestPrecedence());
1459  if (failed(emitOperand(operand)))
1460  return failure();
1461  if (getEmittedExpression())
1462  popExpressionPrecedence();
1463  return success();
1464  });
1465 }
1466 
1467 LogicalResult
1468 CppEmitter::emitOperandsAndAttributes(Operation &op,
1469  ArrayRef<StringRef> exclude) {
1470  if (failed(emitOperands(op)))
1471  return failure();
1472  // Insert comma in between operands and non-filtered attributes if needed.
1473  if (op.getNumOperands() > 0) {
1474  for (NamedAttribute attr : op.getAttrs()) {
1475  if (!llvm::is_contained(exclude, attr.getName().strref())) {
1476  os << ", ";
1477  break;
1478  }
1479  }
1480  }
1481  // Emit attributes.
1482  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
1483  if (llvm::is_contained(exclude, attr.getName().strref()))
1484  return success();
1485  os << "/* " << attr.getName().getValue() << " */";
1486  if (failed(emitAttribute(op.getLoc(), attr.getValue())))
1487  return failure();
1488  return success();
1489  };
1490  return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
1491 }
1492 
1493 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
1494  if (!hasValueInScope(result)) {
1495  return result.getDefiningOp()->emitOpError(
1496  "result variable for the operation has not been declared");
1497  }
1498  os << getOrCreateName(result) << " = ";
1499  return success();
1500 }
1501 
1502 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
1503  bool trailingSemicolon) {
1504  if (hasDeferredEmission(result.getDefiningOp()))
1505  return success();
1506  if (hasValueInScope(result)) {
1507  return result.getDefiningOp()->emitError(
1508  "result variable for the operation already declared");
1509  }
1510  if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
1511  result.getType(),
1512  getOrCreateName(result))))
1513  return failure();
1514  if (trailingSemicolon)
1515  os << ";\n";
1516  return success();
1517 }
1518 
1519 LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
1520  if (op.getExternSpecifier())
1521  os << "extern ";
1522  else if (op.getStaticSpecifier())
1523  os << "static ";
1524  if (op.getConstSpecifier())
1525  os << "const ";
1526 
1527  if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
1528  op.getSymName()))) {
1529  return failure();
1530  }
1531 
1532  std::optional<Attribute> initialValue = op.getInitialValue();
1533  if (initialValue) {
1534  os << " = ";
1535  if (failed(emitAttribute(op->getLoc(), *initialValue)))
1536  return failure();
1537  }
1538 
1539  os << ";";
1540  return success();
1541 }
1542 
1543 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
1544  // If op is being emitted as part of an expression, bail out.
1545  if (getEmittedExpression())
1546  return success();
1547 
1548  switch (op.getNumResults()) {
1549  case 0:
1550  break;
1551  case 1: {
1552  OpResult result = op.getResult(0);
1553  if (shouldDeclareVariablesAtTop()) {
1554  if (failed(emitVariableAssignment(result)))
1555  return failure();
1556  } else {
1557  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
1558  return failure();
1559  os << " = ";
1560  }
1561  break;
1562  }
1563  default:
1564  if (!shouldDeclareVariablesAtTop()) {
1565  for (OpResult result : op.getResults()) {
1566  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
1567  return failure();
1568  }
1569  }
1570  os << "std::tie(";
1571  interleaveComma(op.getResults(), os,
1572  [&](Value result) { os << getOrCreateName(result); });
1573  os << ") = ";
1574  }
1575  return success();
1576 }
1577 
1578 LogicalResult CppEmitter::emitLabel(Block &block) {
1579  if (!hasBlockLabel(block))
1580  return block.getParentOp()->emitError("label for block not found");
1581  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
1582  // label instead of using `getOStream`.
1583  os.getOStream() << getOrCreateName(block) << ":\n";
1584  return success();
1585 }
1586 
1587 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1588  LogicalResult status =
1590  // Builtin ops.
1591  .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
1592  // CF ops.
1593  .Case<cf::BranchOp, cf::CondBranchOp>(
1594  [&](auto op) { return printOperation(*this, op); })
1595  // EmitC ops.
1596  .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
1597  emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
1598  emitc::BitwiseNotOp, emitc::BitwiseOrOp,
1599  emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
1600  emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1601  emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1602  emitc::DivOp, emitc::ExpressionOp, emitc::FileOp, emitc::ForOp,
1603  emitc::FuncOp, emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp,
1604  emitc::LoadOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1605  emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1606  emitc::SubOp, emitc::SwitchOp, emitc::UnaryMinusOp,
1607  emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
1608 
1609  [&](auto op) { return printOperation(*this, op); })
1610  // Func ops.
1611  .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
1612  [&](auto op) { return printOperation(*this, op); })
1613  .Case<emitc::GetGlobalOp>([&](auto op) {
1614  cacheDeferredOpResult(op.getResult(), op.getName());
1615  return success();
1616  })
1617  .Case<emitc::LiteralOp>([&](auto op) {
1618  cacheDeferredOpResult(op.getResult(), op.getValue());
1619  return success();
1620  })
1621  .Case<emitc::MemberOp>([&](auto op) {
1622  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1623  return success();
1624  })
1625  .Case<emitc::MemberOfPtrOp>([&](auto op) {
1626  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1627  return success();
1628  })
1629  .Case<emitc::SubscriptOp>([&](auto op) {
1630  cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
1631  return success();
1632  })
1633  .Default([&](Operation *) {
1634  return op.emitOpError("unable to find printer for op");
1635  });
1636 
1637  if (failed(status))
1638  return failure();
1639 
1640  if (hasDeferredEmission(&op))
1641  return success();
1642 
1643  if (getEmittedExpression() ||
1644  (isa<emitc::ExpressionOp>(op) &&
1645  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1646  return success();
1647 
1648  // Never emit a semicolon for some operations, especially if endening with
1649  // `}`.
1650  trailingSemicolon &=
1651  !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::FileOp, emitc::ForOp,
1652  emitc::IfOp, emitc::IncludeOp, emitc::SwitchOp, emitc::VerbatimOp>(
1653  op);
1654 
1655  os << (trailingSemicolon ? ";\n" : "\n");
1656 
1657  return success();
1658 }
1659 
1660 LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1661  StringRef name) {
1662  if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1663  if (failed(emitType(loc, arrType.getElementType())))
1664  return failure();
1665  os << " " << name;
1666  for (auto dim : arrType.getShape()) {
1667  os << "[" << dim << "]";
1668  }
1669  return success();
1670  }
1671  if (failed(emitType(loc, type)))
1672  return failure();
1673  os << " " << name;
1674  return success();
1675 }
1676 
1677 LogicalResult CppEmitter::emitType(Location loc, Type type) {
1678  if (auto iType = dyn_cast<IntegerType>(type)) {
1679  switch (iType.getWidth()) {
1680  case 1:
1681  return (os << "bool"), success();
1682  case 8:
1683  case 16:
1684  case 32:
1685  case 64:
1686  if (shouldMapToUnsigned(iType.getSignedness()))
1687  return (os << "uint" << iType.getWidth() << "_t"), success();
1688  else
1689  return (os << "int" << iType.getWidth() << "_t"), success();
1690  default:
1691  return emitError(loc, "cannot emit integer type ") << type;
1692  }
1693  }
1694  if (auto fType = dyn_cast<FloatType>(type)) {
1695  switch (fType.getWidth()) {
1696  case 16: {
1697  if (llvm::isa<Float16Type>(type))
1698  return (os << "_Float16"), success();
1699  else if (llvm::isa<BFloat16Type>(type))
1700  return (os << "__bf16"), success();
1701  else
1702  return emitError(loc, "cannot emit float type ") << type;
1703  }
1704  case 32:
1705  return (os << "float"), success();
1706  case 64:
1707  return (os << "double"), success();
1708  default:
1709  return emitError(loc, "cannot emit float type ") << type;
1710  }
1711  }
1712  if (auto iType = dyn_cast<IndexType>(type))
1713  return (os << "size_t"), success();
1714  if (auto sType = dyn_cast<emitc::SizeTType>(type))
1715  return (os << "size_t"), success();
1716  if (auto sType = dyn_cast<emitc::SignedSizeTType>(type))
1717  return (os << "ssize_t"), success();
1718  if (auto pType = dyn_cast<emitc::PtrDiffTType>(type))
1719  return (os << "ptrdiff_t"), success();
1720  if (auto tType = dyn_cast<TensorType>(type)) {
1721  if (!tType.hasRank())
1722  return emitError(loc, "cannot emit unranked tensor type");
1723  if (!tType.hasStaticShape())
1724  return emitError(loc, "cannot emit tensor type with non static shape");
1725  os << "Tensor<";
1726  if (isa<ArrayType>(tType.getElementType()))
1727  return emitError(loc, "cannot emit tensor of array type ") << type;
1728  if (failed(emitType(loc, tType.getElementType())))
1729  return failure();
1730  auto shape = tType.getShape();
1731  for (auto dimSize : shape) {
1732  os << ", ";
1733  os << dimSize;
1734  }
1735  os << ">";
1736  return success();
1737  }
1738  if (auto tType = dyn_cast<TupleType>(type))
1739  return emitTupleType(loc, tType.getTypes());
1740  if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1741  os << oType.getValue();
1742  return success();
1743  }
1744  if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
1745  if (failed(emitType(loc, aType.getElementType())))
1746  return failure();
1747  for (auto dim : aType.getShape())
1748  os << "[" << dim << "]";
1749  return success();
1750  }
1751  if (auto lType = dyn_cast<emitc::LValueType>(type))
1752  return emitType(loc, lType.getValueType());
1753  if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1754  if (isa<ArrayType>(pType.getPointee()))
1755  return emitError(loc, "cannot emit pointer to array type ") << type;
1756  if (failed(emitType(loc, pType.getPointee())))
1757  return failure();
1758  os << "*";
1759  return success();
1760  }
1761  return emitError(loc, "cannot emit type ") << type;
1762 }
1763 
1764 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1765  switch (types.size()) {
1766  case 0:
1767  os << "void";
1768  return success();
1769  case 1:
1770  return emitType(loc, types.front());
1771  default:
1772  return emitTupleType(loc, types);
1773  }
1774 }
1775 
1776 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1777  if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
1778  return emitError(loc, "cannot emit tuple of array type");
1779  }
1780  os << "std::tuple<";
1781  if (failed(interleaveCommaWithError(
1782  types, os, [&](Type type) { return emitType(loc, type); })))
1783  return failure();
1784  os << ">";
1785  return success();
1786 }
1787 
1788 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1789  bool declareVariablesAtTop,
1790  StringRef fileId) {
1791  CppEmitter emitter(os, declareVariablesAtTop, fileId);
1792  return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1793 }
static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, StringRef callee)
static bool shouldBeInlined(ExpressionOp expressionOp)
Determine whether expression expressionOp should be emitted inline, i.e.
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
static FailureOr< int > getOperatorPrecedence(Operation *operation)
Return the precedence of an operator as an integer; higher values imply higher precedence.
static LogicalResult printFunctionArgs(CppEmitter &emitter, Operation *functionOp, ArrayRef< Type > arguments)
static LogicalResult printFunctionBody(CppEmitter &emitter, Operation *functionOp, Region::BlockListType &blocks)
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
static LogicalResult emitSwitchCase(CppEmitter &emitter, raw_indented_ostream &os, Region &region)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator)
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
static LogicalResult printUnaryOperation(CppEmitter &emitter, Operation *operation, StringRef unaryOperator)
static bool hasDeferredEmission(Operation *op)
Determine whether expression op should be emitted in a deferred way.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
Block * getSuccessor(unsigned i)
Definition: Block.cpp:261
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This is a value defined by a result of an operation.
Definition: Value.h:433
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:442
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class provides iteration over the held operations of blocks directly within a region.
Definition: Region.h:134
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
llvm::iplist< Block > BlockListType
Definition: Region.h:44
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Region.h:80
OpIterator op_end()
Definition: Region.h:171
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:204
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:191
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
raw_ostream subclass that simplifies indention a sequence of code.
raw_indented_ostream & unindent()
Decreases the indent and returning this raw_indented_ostream.
raw_indented_ostream & indent()
Increases the indent and returning this raw_indented_ostream.
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false, StringRef fileId={})
Translates the given operation to C++ code.
std::variant< StringRef, Placeholder > ReplacementItem
Definition: EmitC.h:54
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This iterator enumerates the elements in "forward" order.
Definition: Visitors.h:65