MLIR  21.0.0git
TranslateToCpp.cpp
Go to the documentation of this file.
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/ScopedHashTable.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <stack>
28 #include <utility>
29 
30 #define DEBUG_TYPE "translate-to-cpp"
31 
32 using namespace mlir;
33 using namespace mlir::emitc;
34 using llvm::formatv;
35 
36 /// Convenience functions to produce interleaved output with functions returning
37 /// a LogicalResult. This is different than those in STLExtras as functions used
38 /// on each element doesn't return a string.
39 template <typename ForwardIterator, typename UnaryFunctor,
40  typename NullaryFunctor>
41 inline LogicalResult
43  UnaryFunctor eachFn, NullaryFunctor betweenFn) {
44  if (begin == end)
45  return success();
46  if (failed(eachFn(*begin)))
47  return failure();
48  ++begin;
49  for (; begin != end; ++begin) {
50  betweenFn();
51  if (failed(eachFn(*begin)))
52  return failure();
53  }
54  return success();
55 }
56 
57 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
58 inline LogicalResult interleaveWithError(const Container &c,
59  UnaryFunctor eachFn,
60  NullaryFunctor betweenFn) {
61  return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
62 }
63 
64 template <typename Container, typename UnaryFunctor>
65 inline LogicalResult interleaveCommaWithError(const Container &c,
66  raw_ostream &os,
67  UnaryFunctor eachFn) {
68  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
69 }
70 
71 /// Return the precedence of a operator as an integer, higher values
72 /// imply higher precedence.
73 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
75  .Case<emitc::AddOp>([&](auto op) { return 12; })
76  .Case<emitc::ApplyOp>([&](auto op) { return 15; })
77  .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
78  .Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; })
79  .Case<emitc::BitwiseNotOp>([&](auto op) { return 15; })
80  .Case<emitc::BitwiseOrOp>([&](auto op) { return 5; })
81  .Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; })
82  .Case<emitc::BitwiseXorOp>([&](auto op) { return 6; })
83  .Case<emitc::CallOp>([&](auto op) { return 16; })
84  .Case<emitc::CallOpaqueOp>([&](auto op) { return 16; })
85  .Case<emitc::CastOp>([&](auto op) { return 15; })
86  .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
87  switch (op.getPredicate()) {
88  case emitc::CmpPredicate::eq:
89  case emitc::CmpPredicate::ne:
90  return 8;
91  case emitc::CmpPredicate::lt:
92  case emitc::CmpPredicate::le:
93  case emitc::CmpPredicate::gt:
94  case emitc::CmpPredicate::ge:
95  return 9;
96  case emitc::CmpPredicate::three_way:
97  return 10;
98  }
99  return op->emitError("unsupported cmp predicate");
100  })
101  .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
102  .Case<emitc::DivOp>([&](auto op) { return 13; })
103  .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
104  .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
105  .Case<emitc::LogicalOrOp>([&](auto op) { return 3; })
106  .Case<emitc::MulOp>([&](auto op) { return 13; })
107  .Case<emitc::RemOp>([&](auto op) { return 13; })
108  .Case<emitc::SubOp>([&](auto op) { return 12; })
109  .Case<emitc::UnaryMinusOp>([&](auto op) { return 15; })
110  .Case<emitc::UnaryPlusOp>([&](auto op) { return 15; })
111  .Default([](auto op) { return op->emitError("unsupported operation"); });
112 }
113 
114 namespace {
115 /// Emitter that uses dialect specific emitters to emit C++ code.
116 struct CppEmitter {
117  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
118  StringRef fileId);
119 
120  /// Emits attribute or returns failure.
121  LogicalResult emitAttribute(Location loc, Attribute attr);
122 
123  /// Emits operation 'op' with/without training semicolon or returns failure.
124  ///
125  /// For operations that should never be followed by a semicolon, like ForOp,
126  /// the `trailingSemicolon` argument is ignored and a semicolon is not
127  /// emitted.
128  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
129 
130  /// Emits type 'type' or returns failure.
131  LogicalResult emitType(Location loc, Type type);
132 
133  /// Emits array of types as a std::tuple of the emitted types.
134  /// - emits void for an empty array;
135  /// - emits the type of the only element for arrays of size one;
136  /// - emits a std::tuple otherwise;
137  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
138 
139  /// Emits array of types as a std::tuple of the emitted types independently of
140  /// the array size.
141  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
142 
143  /// Emits an assignment for a variable which has been declared previously.
144  LogicalResult emitVariableAssignment(OpResult result);
145 
146  /// Emits a variable declaration for a result of an operation.
147  LogicalResult emitVariableDeclaration(OpResult result,
148  bool trailingSemicolon);
149 
150  /// Emits a declaration of a variable with the given type and name.
151  LogicalResult emitVariableDeclaration(Location loc, Type type,
152  StringRef name);
153 
154  /// Emits the variable declaration and assignment prefix for 'op'.
155  /// - emits separate variable followed by std::tie for multi-valued operation;
156  /// - emits single type followed by variable for single result;
157  /// - emits nothing if no value produced by op;
158  /// Emits final '=' operator where a type is produced. Returns failure if
159  /// any result type could not be converted.
160  LogicalResult emitAssignPrefix(Operation &op);
161 
162  /// Emits a global variable declaration or definition.
163  LogicalResult emitGlobalVariable(GlobalOp op);
164 
165  /// Emits a label for the block.
166  LogicalResult emitLabel(Block &block);
167 
168  /// Emits the operands and atttributes of the operation. All operands are
169  /// emitted first and then all attributes in alphabetical order.
170  LogicalResult emitOperandsAndAttributes(Operation &op,
171  ArrayRef<StringRef> exclude = {});
172 
173  /// Emits the operands of the operation. All operands are emitted in order.
174  LogicalResult emitOperands(Operation &op);
175 
176  /// Emits value as an operands of an operation
177  LogicalResult emitOperand(Value value);
178 
179  /// Emit an expression as a C expression.
180  LogicalResult emitExpression(ExpressionOp expressionOp);
181 
182  /// Insert the expression representing the operation into the value cache.
183  void cacheDeferredOpResult(Value value, StringRef str);
184 
185  /// Return the existing or a new name for a Value.
186  StringRef getOrCreateName(Value val);
187 
188  // Returns the textual representation of a subscript operation.
189  std::string getSubscriptName(emitc::SubscriptOp op);
190 
191  // Returns the textual representation of a member (of object) operation.
192  std::string createMemberAccess(emitc::MemberOp op);
193 
194  // Returns the textual representation of a member of pointer operation.
195  std::string createMemberAccess(emitc::MemberOfPtrOp op);
196 
197  /// Return the existing or a new label of a Block.
198  StringRef getOrCreateName(Block &block);
199 
200  /// Whether to map an mlir integer to a unsigned integer in C++.
201  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
202 
203  /// RAII helper function to manage entering/exiting C++ scopes.
204  struct Scope {
205  Scope(CppEmitter &emitter)
206  : valueMapperScope(emitter.valueMapper),
207  blockMapperScope(emitter.blockMapper), emitter(emitter) {
208  emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
209  emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
210  }
211  ~Scope() {
212  emitter.valueInScopeCount.pop();
213  emitter.labelInScopeCount.pop();
214  }
215 
216  private:
217  llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
218  llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
219  CppEmitter &emitter;
220  };
221 
222  /// Returns wether the Value is assigned to a C++ variable in the scope.
223  bool hasValueInScope(Value val);
224 
225  // Returns whether a label is assigned to the block.
226  bool hasBlockLabel(Block &block);
227 
228  /// Returns the output stream.
229  raw_indented_ostream &ostream() { return os; };
230 
231  /// Returns if all variables for op results and basic block arguments need to
232  /// be declared at the beginning of a function.
233  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
234 
235  /// Returns whether this file op should be emitted
236  bool shouldEmitFile(FileOp file) {
237  return !fileId.empty() && file.getId() == fileId;
238  }
239 
240  /// Get expression currently being emitted.
241  ExpressionOp getEmittedExpression() { return emittedExpression; }
242 
243  /// Determine whether given value is part of the expression potentially being
244  /// emitted.
245  bool isPartOfCurrentExpression(Value value) {
246  if (!emittedExpression)
247  return false;
248  Operation *def = value.getDefiningOp();
249  if (!def)
250  return false;
251  auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
252  return operandExpression == emittedExpression;
253  };
254 
255 private:
256  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
257  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
258 
259  /// Output stream to emit to.
261 
262  /// Boolean to enforce that all variables for op results and block
263  /// arguments are declared at the beginning of the function. This also
264  /// includes results from ops located in nested regions.
265  bool declareVariablesAtTop;
266 
267  /// Only emit file ops whos id matches this value.
268  std::string fileId;
269 
270  /// Map from value to name of C++ variable that contain the name.
271  ValueMapper valueMapper;
272 
273  /// Map from block to name of C++ label.
274  BlockMapper blockMapper;
275 
276  /// The number of values in the current scope. This is used to declare the
277  /// names of values in a scope.
278  std::stack<int64_t> valueInScopeCount;
279  std::stack<int64_t> labelInScopeCount;
280 
281  /// State of the current expression being emitted.
282  ExpressionOp emittedExpression;
283  SmallVector<int> emittedExpressionPrecedence;
284 
285  void pushExpressionPrecedence(int precedence) {
286  emittedExpressionPrecedence.push_back(precedence);
287  }
288  void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
289  static int lowestPrecedence() { return 0; }
290  int getExpressionPrecedence() {
291  if (emittedExpressionPrecedence.empty())
292  return lowestPrecedence();
293  return emittedExpressionPrecedence.back();
294  }
295 };
296 } // namespace
297 
298 /// Determine whether expression \p op should be emitted in a deferred way.
299 static bool hasDeferredEmission(Operation *op) {
300  return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
301  emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
302 }
303 
304 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
305 /// as part of its user. This function recommends inlining of any expressions
306 /// that can be inlined unless it is used by another expression, under the
307 /// assumption that any expression fusion/re-materialization was taken care of
308 /// by transformations run by the backend.
309 static bool shouldBeInlined(ExpressionOp expressionOp) {
310  // Do not inline if expression is marked as such.
311  if (expressionOp.getDoNotInline())
312  return false;
313 
314  // Do not inline expressions with side effects to prevent side-effect
315  // reordering.
316  if (expressionOp.hasSideEffects())
317  return false;
318 
319  // Do not inline expressions with multiple uses.
320  Value result = expressionOp.getResult();
321  if (!result.hasOneUse())
322  return false;
323 
324  Operation *user = *result.getUsers().begin();
325 
326  // Do not inline expressions used by operations with deferred emission, since
327  // their translation requires the materialization of variables.
328  if (hasDeferredEmission(user))
329  return false;
330 
331  // Do not inline expressions used by ops with the CExpression trait. If this
332  // was intended, the user could have been merged into the expression op.
333  return !user->hasTrait<OpTrait::emitc::CExpression>();
334 }
335 
336 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
337  Attribute value) {
338  OpResult result = operation->getResult(0);
339 
340  // Only emit an assignment as the variable was already declared when printing
341  // the FuncOp.
342  if (emitter.shouldDeclareVariablesAtTop()) {
343  // Skip the assignment if the emitc.constant has no value.
344  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
345  if (oAttr.getValue().empty())
346  return success();
347  }
348 
349  if (failed(emitter.emitVariableAssignment(result)))
350  return failure();
351  return emitter.emitAttribute(operation->getLoc(), value);
352  }
353 
354  // Emit a variable declaration for an emitc.constant op without value.
355  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
356  if (oAttr.getValue().empty())
357  // The semicolon gets printed by the emitOperation function.
358  return emitter.emitVariableDeclaration(result,
359  /*trailingSemicolon=*/false);
360  }
361 
362  // Emit a variable declaration.
363  if (failed(emitter.emitAssignPrefix(*operation)))
364  return failure();
365  return emitter.emitAttribute(operation->getLoc(), value);
366 }
367 
368 static LogicalResult printOperation(CppEmitter &emitter,
369  emitc::ConstantOp constantOp) {
370  Operation *operation = constantOp.getOperation();
371  Attribute value = constantOp.getValue();
372 
373  return printConstantOp(emitter, operation, value);
374 }
375 
376 static LogicalResult printOperation(CppEmitter &emitter,
377  emitc::VariableOp variableOp) {
378  Operation *operation = variableOp.getOperation();
379  Attribute value = variableOp.getValue();
380 
381  return printConstantOp(emitter, operation, value);
382 }
383 
384 static LogicalResult printOperation(CppEmitter &emitter,
385  emitc::GlobalOp globalOp) {
386 
387  return emitter.emitGlobalVariable(globalOp);
388 }
389 
390 static LogicalResult printOperation(CppEmitter &emitter,
391  emitc::AssignOp assignOp) {
392  OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
393 
394  if (failed(emitter.emitVariableAssignment(result)))
395  return failure();
396 
397  return emitter.emitOperand(assignOp.getValue());
398 }
399 
400 static LogicalResult printOperation(CppEmitter &emitter, emitc::LoadOp loadOp) {
401  if (failed(emitter.emitAssignPrefix(*loadOp)))
402  return failure();
403 
404  return emitter.emitOperand(loadOp.getOperand());
405 }
406 
407 static LogicalResult printBinaryOperation(CppEmitter &emitter,
408  Operation *operation,
409  StringRef binaryOperator) {
410  raw_ostream &os = emitter.ostream();
411 
412  if (failed(emitter.emitAssignPrefix(*operation)))
413  return failure();
414 
415  if (failed(emitter.emitOperand(operation->getOperand(0))))
416  return failure();
417 
418  os << " " << binaryOperator << " ";
419 
420  if (failed(emitter.emitOperand(operation->getOperand(1))))
421  return failure();
422 
423  return success();
424 }
425 
426 static LogicalResult printUnaryOperation(CppEmitter &emitter,
427  Operation *operation,
428  StringRef unaryOperator) {
429  raw_ostream &os = emitter.ostream();
430 
431  if (failed(emitter.emitAssignPrefix(*operation)))
432  return failure();
433 
434  os << unaryOperator;
435 
436  if (failed(emitter.emitOperand(operation->getOperand(0))))
437  return failure();
438 
439  return success();
440 }
441 
442 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
443  Operation *operation = addOp.getOperation();
444 
445  return printBinaryOperation(emitter, operation, "+");
446 }
447 
448 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
449  Operation *operation = divOp.getOperation();
450 
451  return printBinaryOperation(emitter, operation, "/");
452 }
453 
454 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
455  Operation *operation = mulOp.getOperation();
456 
457  return printBinaryOperation(emitter, operation, "*");
458 }
459 
460 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
461  Operation *operation = remOp.getOperation();
462 
463  return printBinaryOperation(emitter, operation, "%");
464 }
465 
466 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
467  Operation *operation = subOp.getOperation();
468 
469  return printBinaryOperation(emitter, operation, "-");
470 }
471 
472 static LogicalResult emitSwitchCase(CppEmitter &emitter,
473  raw_indented_ostream &os, Region &region) {
474  for (Region::OpIterator iteratorOp = region.op_begin(), end = region.op_end();
475  std::next(iteratorOp) != end; ++iteratorOp) {
476  if (failed(emitter.emitOperation(*iteratorOp, /*trailingSemicolon=*/true)))
477  return failure();
478  }
479  os << "break;\n";
480  return success();
481 }
482 
483 static LogicalResult printOperation(CppEmitter &emitter,
484  emitc::SwitchOp switchOp) {
485  raw_indented_ostream &os = emitter.ostream();
486 
487  os << "switch (";
488  if (failed(emitter.emitOperand(switchOp.getArg())))
489  return failure();
490  os << ") {";
491 
492  for (auto pair : llvm::zip(switchOp.getCases(), switchOp.getCaseRegions())) {
493  os << "\ncase " << std::get<0>(pair) << ": {\n";
494  os.indent();
495 
496  if (failed(emitSwitchCase(emitter, os, std::get<1>(pair))))
497  return failure();
498 
499  os.unindent() << "}";
500  }
501 
502  os << "\ndefault: {\n";
503  os.indent();
504 
505  if (failed(emitSwitchCase(emitter, os, switchOp.getDefaultRegion())))
506  return failure();
507 
508  os.unindent() << "}\n}";
509  return success();
510 }
511 
512 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
513  Operation *operation = cmpOp.getOperation();
514 
515  StringRef binaryOperator;
516 
517  switch (cmpOp.getPredicate()) {
518  case emitc::CmpPredicate::eq:
519  binaryOperator = "==";
520  break;
521  case emitc::CmpPredicate::ne:
522  binaryOperator = "!=";
523  break;
524  case emitc::CmpPredicate::lt:
525  binaryOperator = "<";
526  break;
527  case emitc::CmpPredicate::le:
528  binaryOperator = "<=";
529  break;
530  case emitc::CmpPredicate::gt:
531  binaryOperator = ">";
532  break;
533  case emitc::CmpPredicate::ge:
534  binaryOperator = ">=";
535  break;
536  case emitc::CmpPredicate::three_way:
537  binaryOperator = "<=>";
538  break;
539  }
540 
541  return printBinaryOperation(emitter, operation, binaryOperator);
542 }
543 
544 static LogicalResult printOperation(CppEmitter &emitter,
545  emitc::ConditionalOp conditionalOp) {
546  raw_ostream &os = emitter.ostream();
547 
548  if (failed(emitter.emitAssignPrefix(*conditionalOp)))
549  return failure();
550 
551  if (failed(emitter.emitOperand(conditionalOp.getCondition())))
552  return failure();
553 
554  os << " ? ";
555 
556  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
557  return failure();
558 
559  os << " : ";
560 
561  if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
562  return failure();
563 
564  return success();
565 }
566 
567 static LogicalResult printOperation(CppEmitter &emitter,
568  emitc::VerbatimOp verbatimOp) {
569  raw_ostream &os = emitter.ostream();
570 
571  FailureOr<SmallVector<ReplacementItem>> items =
572  verbatimOp.parseFormatString();
573  if (failed(items))
574  return failure();
575 
576  auto fmtArg = verbatimOp.getFmtArgs().begin();
577 
578  for (ReplacementItem &item : *items) {
579  if (auto *str = std::get_if<StringRef>(&item)) {
580  os << *str;
581  } else {
582  if (failed(emitter.emitOperand(*fmtArg++)))
583  return failure();
584  }
585  }
586 
587  return success();
588 }
589 
590 static LogicalResult printOperation(CppEmitter &emitter,
591  cf::BranchOp branchOp) {
592  raw_ostream &os = emitter.ostream();
593  Block &successor = *branchOp.getSuccessor();
594 
595  for (auto pair :
596  llvm::zip(branchOp.getOperands(), successor.getArguments())) {
597  Value &operand = std::get<0>(pair);
598  BlockArgument &argument = std::get<1>(pair);
599  os << emitter.getOrCreateName(argument) << " = "
600  << emitter.getOrCreateName(operand) << ";\n";
601  }
602 
603  os << "goto ";
604  if (!(emitter.hasBlockLabel(successor)))
605  return branchOp.emitOpError("unable to find label for successor block");
606  os << emitter.getOrCreateName(successor);
607  return success();
608 }
609 
610 static LogicalResult printOperation(CppEmitter &emitter,
611  cf::CondBranchOp condBranchOp) {
612  raw_indented_ostream &os = emitter.ostream();
613  Block &trueSuccessor = *condBranchOp.getTrueDest();
614  Block &falseSuccessor = *condBranchOp.getFalseDest();
615 
616  os << "if (";
617  if (failed(emitter.emitOperand(condBranchOp.getCondition())))
618  return failure();
619  os << ") {\n";
620 
621  os.indent();
622 
623  // If condition is true.
624  for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
625  trueSuccessor.getArguments())) {
626  Value &operand = std::get<0>(pair);
627  BlockArgument &argument = std::get<1>(pair);
628  os << emitter.getOrCreateName(argument) << " = "
629  << emitter.getOrCreateName(operand) << ";\n";
630  }
631 
632  os << "goto ";
633  if (!(emitter.hasBlockLabel(trueSuccessor))) {
634  return condBranchOp.emitOpError("unable to find label for successor block");
635  }
636  os << emitter.getOrCreateName(trueSuccessor) << ";\n";
637  os.unindent() << "} else {\n";
638  os.indent();
639  // If condition is false.
640  for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
641  falseSuccessor.getArguments())) {
642  Value &operand = std::get<0>(pair);
643  BlockArgument &argument = std::get<1>(pair);
644  os << emitter.getOrCreateName(argument) << " = "
645  << emitter.getOrCreateName(operand) << ";\n";
646  }
647 
648  os << "goto ";
649  if (!(emitter.hasBlockLabel(falseSuccessor))) {
650  return condBranchOp.emitOpError()
651  << "unable to find label for successor block";
652  }
653  os << emitter.getOrCreateName(falseSuccessor) << ";\n";
654  os.unindent() << "}";
655  return success();
656 }
657 
658 static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
659  StringRef callee) {
660  if (failed(emitter.emitAssignPrefix(*callOp)))
661  return failure();
662 
663  raw_ostream &os = emitter.ostream();
664  os << callee << "(";
665  if (failed(emitter.emitOperands(*callOp)))
666  return failure();
667  os << ")";
668  return success();
669 }
670 
671 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
672  Operation *operation = callOp.getOperation();
673  StringRef callee = callOp.getCallee();
674 
675  return printCallOperation(emitter, operation, callee);
676 }
677 
678 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
679  Operation *operation = callOp.getOperation();
680  StringRef callee = callOp.getCallee();
681 
682  return printCallOperation(emitter, operation, callee);
683 }
684 
685 static LogicalResult printOperation(CppEmitter &emitter,
686  emitc::CallOpaqueOp callOpaqueOp) {
687  raw_ostream &os = emitter.ostream();
688  Operation &op = *callOpaqueOp.getOperation();
689 
690  if (failed(emitter.emitAssignPrefix(op)))
691  return failure();
692  os << callOpaqueOp.getCallee();
693 
694  auto emitArgs = [&](Attribute attr) -> LogicalResult {
695  if (auto t = dyn_cast<IntegerAttr>(attr)) {
696  // Index attributes are treated specially as operand index.
697  if (t.getType().isIndex()) {
698  int64_t idx = t.getInt();
699  Value operand = op.getOperand(idx);
700  if (!emitter.hasValueInScope(operand))
701  return op.emitOpError("operand ")
702  << idx << "'s value not defined in scope";
703  os << emitter.getOrCreateName(operand);
704  return success();
705  }
706  }
707  if (failed(emitter.emitAttribute(op.getLoc(), attr)))
708  return failure();
709 
710  return success();
711  };
712 
713  if (callOpaqueOp.getTemplateArgs()) {
714  os << "<";
715  if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
716  emitArgs)))
717  return failure();
718  os << ">";
719  }
720 
721  os << "(";
722 
723  LogicalResult emittedArgs =
724  callOpaqueOp.getArgs()
725  ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
726  : emitter.emitOperands(op);
727  if (failed(emittedArgs))
728  return failure();
729  os << ")";
730  return success();
731 }
732 
733 static LogicalResult printOperation(CppEmitter &emitter,
734  emitc::ApplyOp applyOp) {
735  raw_ostream &os = emitter.ostream();
736  Operation &op = *applyOp.getOperation();
737 
738  if (failed(emitter.emitAssignPrefix(op)))
739  return failure();
740  os << applyOp.getApplicableOperator();
741  os << emitter.getOrCreateName(applyOp.getOperand());
742 
743  return success();
744 }
745 
746 static LogicalResult printOperation(CppEmitter &emitter,
747  emitc::BitwiseAndOp bitwiseAndOp) {
748  Operation *operation = bitwiseAndOp.getOperation();
749  return printBinaryOperation(emitter, operation, "&");
750 }
751 
752 static LogicalResult
753 printOperation(CppEmitter &emitter,
754  emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) {
755  Operation *operation = bitwiseLeftShiftOp.getOperation();
756  return printBinaryOperation(emitter, operation, "<<");
757 }
758 
759 static LogicalResult printOperation(CppEmitter &emitter,
760  emitc::BitwiseNotOp bitwiseNotOp) {
761  Operation *operation = bitwiseNotOp.getOperation();
762  return printUnaryOperation(emitter, operation, "~");
763 }
764 
765 static LogicalResult printOperation(CppEmitter &emitter,
766  emitc::BitwiseOrOp bitwiseOrOp) {
767  Operation *operation = bitwiseOrOp.getOperation();
768  return printBinaryOperation(emitter, operation, "|");
769 }
770 
771 static LogicalResult
772 printOperation(CppEmitter &emitter,
773  emitc::BitwiseRightShiftOp bitwiseRightShiftOp) {
774  Operation *operation = bitwiseRightShiftOp.getOperation();
775  return printBinaryOperation(emitter, operation, ">>");
776 }
777 
778 static LogicalResult printOperation(CppEmitter &emitter,
779  emitc::BitwiseXorOp bitwiseXorOp) {
780  Operation *operation = bitwiseXorOp.getOperation();
781  return printBinaryOperation(emitter, operation, "^");
782 }
783 
784 static LogicalResult printOperation(CppEmitter &emitter,
785  emitc::UnaryPlusOp unaryPlusOp) {
786  Operation *operation = unaryPlusOp.getOperation();
787  return printUnaryOperation(emitter, operation, "+");
788 }
789 
790 static LogicalResult printOperation(CppEmitter &emitter,
791  emitc::UnaryMinusOp unaryMinusOp) {
792  Operation *operation = unaryMinusOp.getOperation();
793  return printUnaryOperation(emitter, operation, "-");
794 }
795 
796 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
797  raw_ostream &os = emitter.ostream();
798  Operation &op = *castOp.getOperation();
799 
800  if (failed(emitter.emitAssignPrefix(op)))
801  return failure();
802  os << "(";
803  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
804  return failure();
805  os << ") ";
806  return emitter.emitOperand(castOp.getOperand());
807 }
808 
809 static LogicalResult printOperation(CppEmitter &emitter,
810  emitc::ExpressionOp expressionOp) {
811  if (shouldBeInlined(expressionOp))
812  return success();
813 
814  Operation &op = *expressionOp.getOperation();
815 
816  if (failed(emitter.emitAssignPrefix(op)))
817  return failure();
818 
819  return emitter.emitExpression(expressionOp);
820 }
821 
822 static LogicalResult printOperation(CppEmitter &emitter,
823  emitc::IncludeOp includeOp) {
824  raw_ostream &os = emitter.ostream();
825 
826  os << "#include ";
827  if (includeOp.getIsStandardInclude())
828  os << "<" << includeOp.getInclude() << ">";
829  else
830  os << "\"" << includeOp.getInclude() << "\"";
831 
832  return success();
833 }
834 
835 static LogicalResult printOperation(CppEmitter &emitter,
836  emitc::LogicalAndOp logicalAndOp) {
837  Operation *operation = logicalAndOp.getOperation();
838  return printBinaryOperation(emitter, operation, "&&");
839 }
840 
841 static LogicalResult printOperation(CppEmitter &emitter,
842  emitc::LogicalNotOp logicalNotOp) {
843  Operation *operation = logicalNotOp.getOperation();
844  return printUnaryOperation(emitter, operation, "!");
845 }
846 
847 static LogicalResult printOperation(CppEmitter &emitter,
848  emitc::LogicalOrOp logicalOrOp) {
849  Operation *operation = logicalOrOp.getOperation();
850  return printBinaryOperation(emitter, operation, "||");
851 }
852 
853 static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
854 
855  raw_indented_ostream &os = emitter.ostream();
856 
857  // Utility function to determine whether a value is an expression that will be
858  // inlined, and as such should be wrapped in parentheses in order to guarantee
859  // its precedence and associativity.
860  auto requiresParentheses = [&](Value value) {
861  auto expressionOp =
862  dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
863  if (!expressionOp)
864  return false;
865  return shouldBeInlined(expressionOp);
866  };
867 
868  os << "for (";
869  if (failed(
870  emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
871  return failure();
872  os << " ";
873  os << emitter.getOrCreateName(forOp.getInductionVar());
874  os << " = ";
875  if (failed(emitter.emitOperand(forOp.getLowerBound())))
876  return failure();
877  os << "; ";
878  os << emitter.getOrCreateName(forOp.getInductionVar());
879  os << " < ";
880  Value upperBound = forOp.getUpperBound();
881  bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
882  if (upperBoundRequiresParentheses)
883  os << "(";
884  if (failed(emitter.emitOperand(upperBound)))
885  return failure();
886  if (upperBoundRequiresParentheses)
887  os << ")";
888  os << "; ";
889  os << emitter.getOrCreateName(forOp.getInductionVar());
890  os << " += ";
891  if (failed(emitter.emitOperand(forOp.getStep())))
892  return failure();
893  os << ") {\n";
894  os.indent();
895 
896  Region &forRegion = forOp.getRegion();
897  auto regionOps = forRegion.getOps();
898 
899  // We skip the trailing yield op.
900  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
901  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
902  return failure();
903  }
904 
905  os.unindent() << "}";
906 
907  return success();
908 }
909 
910 static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
911  raw_indented_ostream &os = emitter.ostream();
912 
913  // Helper function to emit all ops except the last one, expected to be
914  // emitc::yield.
915  auto emitAllExceptLast = [&emitter](Region &region) {
916  Region::OpIterator it = region.op_begin(), end = region.op_end();
917  for (; std::next(it) != end; ++it) {
918  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
919  return failure();
920  }
921  assert(isa<emitc::YieldOp>(*it) &&
922  "Expected last operation in the region to be emitc::yield");
923  return success();
924  };
925 
926  os << "if (";
927  if (failed(emitter.emitOperand(ifOp.getCondition())))
928  return failure();
929  os << ") {\n";
930  os.indent();
931  if (failed(emitAllExceptLast(ifOp.getThenRegion())))
932  return failure();
933  os.unindent() << "}";
934 
935  Region &elseRegion = ifOp.getElseRegion();
936  if (!elseRegion.empty()) {
937  os << " else {\n";
938  os.indent();
939  if (failed(emitAllExceptLast(elseRegion)))
940  return failure();
941  os.unindent() << "}";
942  }
943 
944  return success();
945 }
946 
947 static LogicalResult printOperation(CppEmitter &emitter,
948  func::ReturnOp returnOp) {
949  raw_ostream &os = emitter.ostream();
950  os << "return";
951  switch (returnOp.getNumOperands()) {
952  case 0:
953  return success();
954  case 1:
955  os << " ";
956  if (failed(emitter.emitOperand(returnOp.getOperand(0))))
957  return failure();
958  return success();
959  default:
960  os << " std::make_tuple(";
961  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
962  return failure();
963  os << ")";
964  return success();
965  }
966 }
967 
968 static LogicalResult printOperation(CppEmitter &emitter,
969  emitc::ReturnOp returnOp) {
970  raw_ostream &os = emitter.ostream();
971  os << "return";
972  if (returnOp.getNumOperands() == 0)
973  return success();
974 
975  os << " ";
976  if (failed(emitter.emitOperand(returnOp.getOperand())))
977  return failure();
978  return success();
979 }
980 
981 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
982  CppEmitter::Scope scope(emitter);
983 
984  for (Operation &op : moduleOp) {
985  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
986  return failure();
987  }
988  return success();
989 }
990 
991 static LogicalResult printOperation(CppEmitter &emitter, FileOp file) {
992  if (!emitter.shouldEmitFile(file))
993  return success();
994 
995  CppEmitter::Scope scope(emitter);
996 
997  for (Operation &op : file) {
998  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
999  return failure();
1000  }
1001  return success();
1002 }
1003 
1004 static LogicalResult printFunctionArgs(CppEmitter &emitter,
1005  Operation *functionOp,
1006  ArrayRef<Type> arguments) {
1007  raw_indented_ostream &os = emitter.ostream();
1008 
1009  return (
1010  interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult {
1011  return emitter.emitType(functionOp->getLoc(), arg);
1012  }));
1013 }
1014 
1015 static LogicalResult printFunctionArgs(CppEmitter &emitter,
1016  Operation *functionOp,
1017  Region::BlockArgListType arguments) {
1018  raw_indented_ostream &os = emitter.ostream();
1019 
1020  return (interleaveCommaWithError(
1021  arguments, os, [&](BlockArgument arg) -> LogicalResult {
1022  return emitter.emitVariableDeclaration(
1023  functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
1024  }));
1025 }
1026 
1027 static LogicalResult printFunctionBody(CppEmitter &emitter,
1028  Operation *functionOp,
1029  Region::BlockListType &blocks) {
1030  raw_indented_ostream &os = emitter.ostream();
1031  os.indent();
1032 
1033  if (emitter.shouldDeclareVariablesAtTop()) {
1034  // Declare all variables that hold op results including those from nested
1035  // regions.
1036  WalkResult result =
1037  functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
1038  if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
1039  (isa<emitc::ExpressionOp>(op) &&
1040  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1041  return WalkResult::skip();
1042  for (OpResult result : op->getResults()) {
1043  if (failed(emitter.emitVariableDeclaration(
1044  result, /*trailingSemicolon=*/true))) {
1045  return WalkResult(
1046  op->emitError("unable to declare result variable for op"));
1047  }
1048  }
1049  return WalkResult::advance();
1050  });
1051  if (result.wasInterrupted())
1052  return failure();
1053  }
1054 
1055  // Create label names for basic blocks.
1056  for (Block &block : blocks) {
1057  emitter.getOrCreateName(block);
1058  }
1059 
1060  // Declare variables for basic block arguments.
1061  for (Block &block : llvm::drop_begin(blocks)) {
1062  for (BlockArgument &arg : block.getArguments()) {
1063  if (emitter.hasValueInScope(arg))
1064  return functionOp->emitOpError(" block argument #")
1065  << arg.getArgNumber() << " is out of scope";
1066  if (isa<ArrayType, LValueType>(arg.getType()))
1067  return functionOp->emitOpError("cannot emit block argument #")
1068  << arg.getArgNumber() << " with type " << arg.getType();
1069  if (failed(
1070  emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
1071  return failure();
1072  }
1073  os << " " << emitter.getOrCreateName(arg) << ";\n";
1074  }
1075  }
1076 
1077  for (Block &block : blocks) {
1078  // Only print a label if the block has predecessors.
1079  if (!block.hasNoPredecessors()) {
1080  if (failed(emitter.emitLabel(block)))
1081  return failure();
1082  }
1083  for (Operation &op : block.getOperations()) {
1084  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
1085  return failure();
1086  }
1087  }
1088 
1089  os.unindent();
1090 
1091  return success();
1092 }
1093 
1094 static LogicalResult printOperation(CppEmitter &emitter,
1095  func::FuncOp functionOp) {
1096  // We need to declare variables at top if the function has multiple blocks.
1097  if (!emitter.shouldDeclareVariablesAtTop() &&
1098  functionOp.getBlocks().size() > 1) {
1099  return functionOp.emitOpError(
1100  "with multiple blocks needs variables declared at top");
1101  }
1102 
1103  if (llvm::any_of(functionOp.getArgumentTypes(), llvm::IsaPred<LValueType>)) {
1104  return functionOp.emitOpError()
1105  << "cannot emit lvalue type as argument type";
1106  }
1107 
1108  if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
1109  return functionOp.emitOpError() << "cannot emit array type as result type";
1110  }
1111 
1112  CppEmitter::Scope scope(emitter);
1113  raw_indented_ostream &os = emitter.ostream();
1114  if (failed(emitter.emitTypes(functionOp.getLoc(),
1115  functionOp.getFunctionType().getResults())))
1116  return failure();
1117  os << " " << functionOp.getName();
1118 
1119  os << "(";
1120  Operation *operation = functionOp.getOperation();
1121  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1122  return failure();
1123  os << ") {\n";
1124  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1125  return failure();
1126  os << "}\n";
1127 
1128  return success();
1129 }
1130 
1131 static LogicalResult printOperation(CppEmitter &emitter,
1132  emitc::FuncOp functionOp) {
1133  // We need to declare variables at top if the function has multiple blocks.
1134  if (!emitter.shouldDeclareVariablesAtTop() &&
1135  functionOp.getBlocks().size() > 1) {
1136  return functionOp.emitOpError(
1137  "with multiple blocks needs variables declared at top");
1138  }
1139 
1140  CppEmitter::Scope scope(emitter);
1141  raw_indented_ostream &os = emitter.ostream();
1142  if (functionOp.getSpecifiers()) {
1143  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1144  os << cast<StringAttr>(specifier).str() << " ";
1145  }
1146  }
1147 
1148  if (failed(emitter.emitTypes(functionOp.getLoc(),
1149  functionOp.getFunctionType().getResults())))
1150  return failure();
1151  os << " " << functionOp.getName();
1152 
1153  os << "(";
1154  Operation *operation = functionOp.getOperation();
1155  if (functionOp.isExternal()) {
1156  if (failed(printFunctionArgs(emitter, operation,
1157  functionOp.getArgumentTypes())))
1158  return failure();
1159  os << ");";
1160  return success();
1161  }
1162  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1163  return failure();
1164  os << ") {\n";
1165  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1166  return failure();
1167  os << "}\n";
1168 
1169  return success();
1170 }
1171 
1172 static LogicalResult printOperation(CppEmitter &emitter,
1173  DeclareFuncOp declareFuncOp) {
1174  CppEmitter::Scope scope(emitter);
1175  raw_indented_ostream &os = emitter.ostream();
1176 
1177  auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
1178  declareFuncOp, declareFuncOp.getSymNameAttr());
1179 
1180  if (!functionOp)
1181  return failure();
1182 
1183  if (functionOp.getSpecifiers()) {
1184  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1185  os << cast<StringAttr>(specifier).str() << " ";
1186  }
1187  }
1188 
1189  if (failed(emitter.emitTypes(functionOp.getLoc(),
1190  functionOp.getFunctionType().getResults())))
1191  return failure();
1192  os << " " << functionOp.getName();
1193 
1194  os << "(";
1195  Operation *operation = functionOp.getOperation();
1196  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1197  return failure();
1198  os << ");";
1199 
1200  return success();
1201 }
1202 
1203 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
1204  StringRef fileId)
1205  : os(os), declareVariablesAtTop(declareVariablesAtTop),
1206  fileId(fileId.str()) {
1207  valueInScopeCount.push(0);
1208  labelInScopeCount.push(0);
1209 }
1210 
1211 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1212  std::string out;
1213  llvm::raw_string_ostream ss(out);
1214  ss << getOrCreateName(op.getValue());
1215  for (auto index : op.getIndices()) {
1216  ss << "[" << getOrCreateName(index) << "]";
1217  }
1218  return out;
1219 }
1220 
1221 std::string CppEmitter::createMemberAccess(emitc::MemberOp op) {
1222  std::string out;
1223  llvm::raw_string_ostream ss(out);
1224  ss << getOrCreateName(op.getOperand());
1225  ss << "." << op.getMember();
1226  return out;
1227 }
1228 
1229 std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) {
1230  std::string out;
1231  llvm::raw_string_ostream ss(out);
1232  ss << getOrCreateName(op.getOperand());
1233  ss << "->" << op.getMember();
1234  return out;
1235 }
1236 
1237 void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
1238  if (!valueMapper.count(value))
1239  valueMapper.insert(value, str.str());
1240 }
1241 
1242 /// Return the existing or a new name for a Value.
1243 StringRef CppEmitter::getOrCreateName(Value val) {
1244  if (!valueMapper.count(val)) {
1245  assert(!hasDeferredEmission(val.getDefiningOp()) &&
1246  "cacheDeferredOpResult should have been called on this value, "
1247  "update the emitOperation function.");
1248  valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1249  }
1250  return *valueMapper.begin(val);
1251 }
1252 
1253 /// Return the existing or a new label for a Block.
1254 StringRef CppEmitter::getOrCreateName(Block &block) {
1255  if (!blockMapper.count(&block))
1256  blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
1257  return *blockMapper.begin(&block);
1258 }
1259 
1260 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
1261  switch (val) {
1262  case IntegerType::Signless:
1263  return false;
1264  case IntegerType::Signed:
1265  return false;
1266  case IntegerType::Unsigned:
1267  return true;
1268  }
1269  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
1270 }
1271 
1272 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
1273 
1274 bool CppEmitter::hasBlockLabel(Block &block) {
1275  return blockMapper.count(&block);
1276 }
1277 
1278 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
1279  auto printInt = [&](const APInt &val, bool isUnsigned) {
1280  if (val.getBitWidth() == 1) {
1281  if (val.getBoolValue())
1282  os << "true";
1283  else
1284  os << "false";
1285  } else {
1286  SmallString<128> strValue;
1287  val.toString(strValue, 10, !isUnsigned, false);
1288  os << strValue;
1289  }
1290  };
1291 
1292  auto printFloat = [&](const APFloat &val) {
1293  if (val.isFinite()) {
1294  SmallString<128> strValue;
1295  // Use default values of toString except don't truncate zeros.
1296  val.toString(strValue, 0, 0, false);
1297  os << strValue;
1298  switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
1299  case llvm::APFloatBase::S_IEEEhalf:
1300  os << "f16";
1301  break;
1302  case llvm::APFloatBase::S_BFloat:
1303  os << "bf16";
1304  break;
1305  case llvm::APFloatBase::S_IEEEsingle:
1306  os << "f";
1307  break;
1308  case llvm::APFloatBase::S_IEEEdouble:
1309  break;
1310  default:
1311  llvm_unreachable("unsupported floating point type");
1312  };
1313  } else if (val.isNaN()) {
1314  os << "NAN";
1315  } else if (val.isInfinity()) {
1316  if (val.isNegative())
1317  os << "-";
1318  os << "INFINITY";
1319  }
1320  };
1321 
1322  // Print floating point attributes.
1323  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1324  if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1325  fAttr.getType())) {
1326  return emitError(
1327  loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1328  }
1329  printFloat(fAttr.getValue());
1330  return success();
1331  }
1332  if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1333  if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1334  dense.getElementType())) {
1335  return emitError(
1336  loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1337  }
1338  os << '{';
1339  interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
1340  os << '}';
1341  return success();
1342  }
1343 
1344  // Print integer attributes.
1345  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
1346  if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
1347  printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
1348  return success();
1349  }
1350  if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
1351  printInt(iAttr.getValue(), false);
1352  return success();
1353  }
1354  }
1355  if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1356  if (auto iType = dyn_cast<IntegerType>(
1357  cast<TensorType>(dense.getType()).getElementType())) {
1358  os << '{';
1359  interleaveComma(dense, os, [&](const APInt &val) {
1360  printInt(val, shouldMapToUnsigned(iType.getSignedness()));
1361  });
1362  os << '}';
1363  return success();
1364  }
1365  if (auto iType = dyn_cast<IndexType>(
1366  cast<TensorType>(dense.getType()).getElementType())) {
1367  os << '{';
1368  interleaveComma(dense, os,
1369  [&](const APInt &val) { printInt(val, false); });
1370  os << '}';
1371  return success();
1372  }
1373  }
1374 
1375  // Print opaque attributes.
1376  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
1377  os << oAttr.getValue();
1378  return success();
1379  }
1380 
1381  // Print symbolic reference attributes.
1382  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
1383  if (sAttr.getNestedReferences().size() > 1)
1384  return emitError(loc, "attribute has more than 1 nested reference");
1385  os << sAttr.getRootReference().getValue();
1386  return success();
1387  }
1388 
1389  // Print type attributes.
1390  if (auto type = dyn_cast<TypeAttr>(attr))
1391  return emitType(loc, type.getValue());
1392 
1393  return emitError(loc, "cannot emit attribute: ") << attr;
1394 }
1395 
1396 LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
1397  assert(emittedExpressionPrecedence.empty() &&
1398  "Expected precedence stack to be empty");
1399  Operation *rootOp = expressionOp.getRootOp();
1400 
1401  emittedExpression = expressionOp;
1402  FailureOr<int> precedence = getOperatorPrecedence(rootOp);
1403  if (failed(precedence))
1404  return failure();
1405  pushExpressionPrecedence(precedence.value());
1406 
1407  if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false)))
1408  return failure();
1409 
1410  popExpressionPrecedence();
1411  assert(emittedExpressionPrecedence.empty() &&
1412  "Expected precedence stack to be empty");
1413  emittedExpression = nullptr;
1414 
1415  return success();
1416 }
1417 
1418 LogicalResult CppEmitter::emitOperand(Value value) {
1419  if (isPartOfCurrentExpression(value)) {
1420  Operation *def = value.getDefiningOp();
1421  assert(def && "Expected operand to be defined by an operation");
1422  FailureOr<int> precedence = getOperatorPrecedence(def);
1423  if (failed(precedence))
1424  return failure();
1425 
1426  // Sub-expressions with equal or lower precedence need to be parenthesized,
1427  // as they might be evaluated in the wrong order depending on the shape of
1428  // the expression tree.
1429  bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence();
1430  if (encloseInParenthesis)
1431  os << "(";
1432  pushExpressionPrecedence(precedence.value());
1433 
1434  if (failed(emitOperation(*def, /*trailingSemicolon=*/false)))
1435  return failure();
1436 
1437  if (encloseInParenthesis)
1438  os << ")";
1439 
1440  popExpressionPrecedence();
1441  return success();
1442  }
1443 
1444  auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
1445  if (expressionOp && shouldBeInlined(expressionOp))
1446  return emitExpression(expressionOp);
1447 
1448  os << getOrCreateName(value);
1449  return success();
1450 }
1451 
1452 LogicalResult CppEmitter::emitOperands(Operation &op) {
1453  return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) {
1454  // If an expression is being emitted, push lowest precedence as these
1455  // operands are either wrapped by parenthesis.
1456  if (getEmittedExpression())
1457  pushExpressionPrecedence(lowestPrecedence());
1458  if (failed(emitOperand(operand)))
1459  return failure();
1460  if (getEmittedExpression())
1461  popExpressionPrecedence();
1462  return success();
1463  });
1464 }
1465 
1466 LogicalResult
1467 CppEmitter::emitOperandsAndAttributes(Operation &op,
1468  ArrayRef<StringRef> exclude) {
1469  if (failed(emitOperands(op)))
1470  return failure();
1471  // Insert comma in between operands and non-filtered attributes if needed.
1472  if (op.getNumOperands() > 0) {
1473  for (NamedAttribute attr : op.getAttrs()) {
1474  if (!llvm::is_contained(exclude, attr.getName().strref())) {
1475  os << ", ";
1476  break;
1477  }
1478  }
1479  }
1480  // Emit attributes.
1481  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
1482  if (llvm::is_contained(exclude, attr.getName().strref()))
1483  return success();
1484  os << "/* " << attr.getName().getValue() << " */";
1485  if (failed(emitAttribute(op.getLoc(), attr.getValue())))
1486  return failure();
1487  return success();
1488  };
1489  return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
1490 }
1491 
1492 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
1493  if (!hasValueInScope(result)) {
1494  return result.getDefiningOp()->emitOpError(
1495  "result variable for the operation has not been declared");
1496  }
1497  os << getOrCreateName(result) << " = ";
1498  return success();
1499 }
1500 
1501 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
1502  bool trailingSemicolon) {
1503  if (hasDeferredEmission(result.getDefiningOp()))
1504  return success();
1505  if (hasValueInScope(result)) {
1506  return result.getDefiningOp()->emitError(
1507  "result variable for the operation already declared");
1508  }
1509  if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
1510  result.getType(),
1511  getOrCreateName(result))))
1512  return failure();
1513  if (trailingSemicolon)
1514  os << ";\n";
1515  return success();
1516 }
1517 
1518 LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
1519  if (op.getExternSpecifier())
1520  os << "extern ";
1521  else if (op.getStaticSpecifier())
1522  os << "static ";
1523  if (op.getConstSpecifier())
1524  os << "const ";
1525 
1526  if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
1527  op.getSymName()))) {
1528  return failure();
1529  }
1530 
1531  std::optional<Attribute> initialValue = op.getInitialValue();
1532  if (initialValue) {
1533  os << " = ";
1534  if (failed(emitAttribute(op->getLoc(), *initialValue)))
1535  return failure();
1536  }
1537 
1538  os << ";";
1539  return success();
1540 }
1541 
1542 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
1543  // If op is being emitted as part of an expression, bail out.
1544  if (getEmittedExpression())
1545  return success();
1546 
1547  switch (op.getNumResults()) {
1548  case 0:
1549  break;
1550  case 1: {
1551  OpResult result = op.getResult(0);
1552  if (shouldDeclareVariablesAtTop()) {
1553  if (failed(emitVariableAssignment(result)))
1554  return failure();
1555  } else {
1556  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
1557  return failure();
1558  os << " = ";
1559  }
1560  break;
1561  }
1562  default:
1563  if (!shouldDeclareVariablesAtTop()) {
1564  for (OpResult result : op.getResults()) {
1565  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
1566  return failure();
1567  }
1568  }
1569  os << "std::tie(";
1570  interleaveComma(op.getResults(), os,
1571  [&](Value result) { os << getOrCreateName(result); });
1572  os << ") = ";
1573  }
1574  return success();
1575 }
1576 
1577 LogicalResult CppEmitter::emitLabel(Block &block) {
1578  if (!hasBlockLabel(block))
1579  return block.getParentOp()->emitError("label for block not found");
1580  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
1581  // label instead of using `getOStream`.
1582  os.getOStream() << getOrCreateName(block) << ":\n";
1583  return success();
1584 }
1585 
1586 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1587  LogicalResult status =
1589  // Builtin ops.
1590  .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
1591  // CF ops.
1592  .Case<cf::BranchOp, cf::CondBranchOp>(
1593  [&](auto op) { return printOperation(*this, op); })
1594  // EmitC ops.
1595  .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
1596  emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
1597  emitc::BitwiseNotOp, emitc::BitwiseOrOp,
1598  emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
1599  emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1600  emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1601  emitc::DivOp, emitc::ExpressionOp, emitc::FileOp, emitc::ForOp,
1602  emitc::FuncOp, emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp,
1603  emitc::LoadOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1604  emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1605  emitc::SubOp, emitc::SwitchOp, emitc::UnaryMinusOp,
1606  emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
1607 
1608  [&](auto op) { return printOperation(*this, op); })
1609  // Func ops.
1610  .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
1611  [&](auto op) { return printOperation(*this, op); })
1612  .Case<emitc::GetGlobalOp>([&](auto op) {
1613  cacheDeferredOpResult(op.getResult(), op.getName());
1614  return success();
1615  })
1616  .Case<emitc::LiteralOp>([&](auto op) {
1617  cacheDeferredOpResult(op.getResult(), op.getValue());
1618  return success();
1619  })
1620  .Case<emitc::MemberOp>([&](auto op) {
1621  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1622  return success();
1623  })
1624  .Case<emitc::MemberOfPtrOp>([&](auto op) {
1625  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1626  return success();
1627  })
1628  .Case<emitc::SubscriptOp>([&](auto op) {
1629  cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
1630  return success();
1631  })
1632  .Default([&](Operation *) {
1633  return op.emitOpError("unable to find printer for op");
1634  });
1635 
1636  if (failed(status))
1637  return failure();
1638 
1639  if (hasDeferredEmission(&op))
1640  return success();
1641 
1642  if (getEmittedExpression() ||
1643  (isa<emitc::ExpressionOp>(op) &&
1644  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1645  return success();
1646 
1647  // Never emit a semicolon for some operations, especially if endening with
1648  // `}`.
1649  trailingSemicolon &=
1650  !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::FileOp, emitc::ForOp,
1651  emitc::IfOp, emitc::IncludeOp, emitc::SwitchOp, emitc::VerbatimOp>(
1652  op);
1653 
1654  os << (trailingSemicolon ? ";\n" : "\n");
1655 
1656  return success();
1657 }
1658 
1659 LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1660  StringRef name) {
1661  if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1662  if (failed(emitType(loc, arrType.getElementType())))
1663  return failure();
1664  os << " " << name;
1665  for (auto dim : arrType.getShape()) {
1666  os << "[" << dim << "]";
1667  }
1668  return success();
1669  }
1670  if (failed(emitType(loc, type)))
1671  return failure();
1672  os << " " << name;
1673  return success();
1674 }
1675 
1676 LogicalResult CppEmitter::emitType(Location loc, Type type) {
1677  if (auto iType = dyn_cast<IntegerType>(type)) {
1678  switch (iType.getWidth()) {
1679  case 1:
1680  return (os << "bool"), success();
1681  case 8:
1682  case 16:
1683  case 32:
1684  case 64:
1685  if (shouldMapToUnsigned(iType.getSignedness()))
1686  return (os << "uint" << iType.getWidth() << "_t"), success();
1687  else
1688  return (os << "int" << iType.getWidth() << "_t"), success();
1689  default:
1690  return emitError(loc, "cannot emit integer type ") << type;
1691  }
1692  }
1693  if (auto fType = dyn_cast<FloatType>(type)) {
1694  switch (fType.getWidth()) {
1695  case 16: {
1696  if (llvm::isa<Float16Type>(type))
1697  return (os << "_Float16"), success();
1698  else if (llvm::isa<BFloat16Type>(type))
1699  return (os << "__bf16"), success();
1700  else
1701  return emitError(loc, "cannot emit float type ") << type;
1702  }
1703  case 32:
1704  return (os << "float"), success();
1705  case 64:
1706  return (os << "double"), success();
1707  default:
1708  return emitError(loc, "cannot emit float type ") << type;
1709  }
1710  }
1711  if (auto iType = dyn_cast<IndexType>(type))
1712  return (os << "size_t"), success();
1713  if (auto sType = dyn_cast<emitc::SizeTType>(type))
1714  return (os << "size_t"), success();
1715  if (auto sType = dyn_cast<emitc::SignedSizeTType>(type))
1716  return (os << "ssize_t"), success();
1717  if (auto pType = dyn_cast<emitc::PtrDiffTType>(type))
1718  return (os << "ptrdiff_t"), success();
1719  if (auto tType = dyn_cast<TensorType>(type)) {
1720  if (!tType.hasRank())
1721  return emitError(loc, "cannot emit unranked tensor type");
1722  if (!tType.hasStaticShape())
1723  return emitError(loc, "cannot emit tensor type with non static shape");
1724  os << "Tensor<";
1725  if (isa<ArrayType>(tType.getElementType()))
1726  return emitError(loc, "cannot emit tensor of array type ") << type;
1727  if (failed(emitType(loc, tType.getElementType())))
1728  return failure();
1729  auto shape = tType.getShape();
1730  for (auto dimSize : shape) {
1731  os << ", ";
1732  os << dimSize;
1733  }
1734  os << ">";
1735  return success();
1736  }
1737  if (auto tType = dyn_cast<TupleType>(type))
1738  return emitTupleType(loc, tType.getTypes());
1739  if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1740  os << oType.getValue();
1741  return success();
1742  }
1743  if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
1744  if (failed(emitType(loc, aType.getElementType())))
1745  return failure();
1746  for (auto dim : aType.getShape())
1747  os << "[" << dim << "]";
1748  return success();
1749  }
1750  if (auto lType = dyn_cast<emitc::LValueType>(type))
1751  return emitType(loc, lType.getValueType());
1752  if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1753  if (isa<ArrayType>(pType.getPointee()))
1754  return emitError(loc, "cannot emit pointer to array type ") << type;
1755  if (failed(emitType(loc, pType.getPointee())))
1756  return failure();
1757  os << "*";
1758  return success();
1759  }
1760  return emitError(loc, "cannot emit type ") << type;
1761 }
1762 
1763 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1764  switch (types.size()) {
1765  case 0:
1766  os << "void";
1767  return success();
1768  case 1:
1769  return emitType(loc, types.front());
1770  default:
1771  return emitTupleType(loc, types);
1772  }
1773 }
1774 
1775 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1776  if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
1777  return emitError(loc, "cannot emit tuple of array type");
1778  }
1779  os << "std::tuple<";
1780  if (failed(interleaveCommaWithError(
1781  types, os, [&](Type type) { return emitType(loc, type); })))
1782  return failure();
1783  os << ">";
1784  return success();
1785 }
1786 
1787 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1788  bool declareVariablesAtTop,
1789  StringRef fileId) {
1790  CppEmitter emitter(os, declareVariablesAtTop, fileId);
1791  return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1792 }
static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, StringRef callee)
static bool shouldBeInlined(ExpressionOp expressionOp)
Determine whether expression expressionOp should be emitted inline, i.e.
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
static FailureOr< int > getOperatorPrecedence(Operation *operation)
Return the precedence of an operator as an integer; higher values imply higher precedence.
static LogicalResult printFunctionArgs(CppEmitter &emitter, Operation *functionOp, ArrayRef< Type > arguments)
static LogicalResult printFunctionBody(CppEmitter &emitter, Operation *functionOp, Region::BlockListType &blocks)
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
static LogicalResult emitSwitchCase(CppEmitter &emitter, raw_indented_ostream &os, Region &region)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator)
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
static LogicalResult printUnaryOperation(CppEmitter &emitter, Operation *operation, StringRef unaryOperator)
static bool hasDeferredEmission(Operation *op)
Determine whether expression op should be emitted in a deferred way.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
Block * getSuccessor(unsigned i)
Definition: Block.cpp:261
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block, or region, depending on the callback provided.
Definition: Operation.h:798
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class provides iteration over the held operations of blocks directly within a region.
Definition: Region.h:134
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
llvm::iplist< Block > BlockListType
Definition: Region.h:44
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Region.h:80
OpIterator op_end()
Definition: Region.h:171
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & unindent()
Decreases the indent and returns this raw_indented_ostream.
raw_indented_ostream & indent()
Increases the indent and returns this raw_indented_ostream.
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false, StringRef fileId={})
Translates the given operation to C++ code.
std::variant< StringRef, Placeholder > ReplacementItem
Definition: EmitC.h:54
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This iterator enumerates the elements in "forward" order.
Definition: Visitors.h:65