MLIR  20.0.0git
TranslateToCpp.cpp
Go to the documentation of this file.
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/ScopedHashTable.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <stack>
28 #include <utility>
29 
30 #define DEBUG_TYPE "translate-to-cpp"
31 
32 using namespace mlir;
33 using namespace mlir::emitc;
34 using llvm::formatv;
35 
36 /// Convenience functions to produce interleaved output with functions returning
37 /// a LogicalResult. This is different than those in STLExtras as functions used
38 /// on each element doesn't return a string.
39 template <typename ForwardIterator, typename UnaryFunctor,
40  typename NullaryFunctor>
41 inline LogicalResult
43  UnaryFunctor eachFn, NullaryFunctor betweenFn) {
44  if (begin == end)
45  return success();
46  if (failed(eachFn(*begin)))
47  return failure();
48  ++begin;
49  for (; begin != end; ++begin) {
50  betweenFn();
51  if (failed(eachFn(*begin)))
52  return failure();
53  }
54  return success();
55 }
56 
57 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
58 inline LogicalResult interleaveWithError(const Container &c,
59  UnaryFunctor eachFn,
60  NullaryFunctor betweenFn) {
61  return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
62 }
63 
64 template <typename Container, typename UnaryFunctor>
65 inline LogicalResult interleaveCommaWithError(const Container &c,
66  raw_ostream &os,
67  UnaryFunctor eachFn) {
68  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
69 }
70 
71 /// Return the precedence of a operator as an integer, higher values
72 /// imply higher precedence.
73 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
75  .Case<emitc::AddOp>([&](auto op) { return 12; })
76  .Case<emitc::ApplyOp>([&](auto op) { return 15; })
77  .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
78  .Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; })
79  .Case<emitc::BitwiseNotOp>([&](auto op) { return 15; })
80  .Case<emitc::BitwiseOrOp>([&](auto op) { return 5; })
81  .Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; })
82  .Case<emitc::BitwiseXorOp>([&](auto op) { return 6; })
83  .Case<emitc::CallOp>([&](auto op) { return 16; })
84  .Case<emitc::CallOpaqueOp>([&](auto op) { return 16; })
85  .Case<emitc::CastOp>([&](auto op) { return 15; })
86  .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
87  switch (op.getPredicate()) {
88  case emitc::CmpPredicate::eq:
89  case emitc::CmpPredicate::ne:
90  return 8;
91  case emitc::CmpPredicate::lt:
92  case emitc::CmpPredicate::le:
93  case emitc::CmpPredicate::gt:
94  case emitc::CmpPredicate::ge:
95  return 9;
96  case emitc::CmpPredicate::three_way:
97  return 10;
98  }
99  return op->emitError("unsupported cmp predicate");
100  })
101  .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
102  .Case<emitc::DivOp>([&](auto op) { return 13; })
103  .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
104  .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
105  .Case<emitc::LogicalOrOp>([&](auto op) { return 3; })
106  .Case<emitc::MulOp>([&](auto op) { return 13; })
107  .Case<emitc::RemOp>([&](auto op) { return 13; })
108  .Case<emitc::SubOp>([&](auto op) { return 12; })
109  .Case<emitc::UnaryMinusOp>([&](auto op) { return 15; })
110  .Case<emitc::UnaryPlusOp>([&](auto op) { return 15; })
111  .Default([](auto op) { return op->emitError("unsupported operation"); });
112 }
113 
114 namespace {
115 /// Emitter that uses dialect specific emitters to emit C++ code.
116 struct CppEmitter {
117  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
118 
119  /// Emits attribute or returns failure.
120  LogicalResult emitAttribute(Location loc, Attribute attr);
121 
122  /// Emits operation 'op' with/without training semicolon or returns failure.
123  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
124 
125  /// Emits type 'type' or returns failure.
126  LogicalResult emitType(Location loc, Type type);
127 
128  /// Emits array of types as a std::tuple of the emitted types.
129  /// - emits void for an empty array;
130  /// - emits the type of the only element for arrays of size one;
131  /// - emits a std::tuple otherwise;
132  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
133 
134  /// Emits array of types as a std::tuple of the emitted types independently of
135  /// the array size.
136  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
137 
138  /// Emits an assignment for a variable which has been declared previously.
139  LogicalResult emitVariableAssignment(OpResult result);
140 
141  /// Emits a variable declaration for a result of an operation.
142  LogicalResult emitVariableDeclaration(OpResult result,
143  bool trailingSemicolon);
144 
145  /// Emits a declaration of a variable with the given type and name.
146  LogicalResult emitVariableDeclaration(Location loc, Type type,
147  StringRef name);
148 
149  /// Emits the variable declaration and assignment prefix for 'op'.
150  /// - emits separate variable followed by std::tie for multi-valued operation;
151  /// - emits single type followed by variable for single result;
152  /// - emits nothing if no value produced by op;
153  /// Emits final '=' operator where a type is produced. Returns failure if
154  /// any result type could not be converted.
155  LogicalResult emitAssignPrefix(Operation &op);
156 
157  /// Emits a global variable declaration or definition.
158  LogicalResult emitGlobalVariable(GlobalOp op);
159 
160  /// Emits a label for the block.
161  LogicalResult emitLabel(Block &block);
162 
163  /// Emits the operands and atttributes of the operation. All operands are
164  /// emitted first and then all attributes in alphabetical order.
165  LogicalResult emitOperandsAndAttributes(Operation &op,
166  ArrayRef<StringRef> exclude = {});
167 
168  /// Emits the operands of the operation. All operands are emitted in order.
169  LogicalResult emitOperands(Operation &op);
170 
171  /// Emits value as an operands of an operation
172  LogicalResult emitOperand(Value value);
173 
174  /// Emit an expression as a C expression.
175  LogicalResult emitExpression(ExpressionOp expressionOp);
176 
177  /// Insert the expression representing the operation into the value cache.
178  void cacheDeferredOpResult(Value value, StringRef str);
179 
180  /// Return the existing or a new name for a Value.
181  StringRef getOrCreateName(Value val);
182 
183  // Returns the textual representation of a subscript operation.
184  std::string getSubscriptName(emitc::SubscriptOp op);
185 
186  // Returns the textual representation of a member (of object) operation.
187  std::string createMemberAccess(emitc::MemberOp op);
188 
189  // Returns the textual representation of a member of pointer operation.
190  std::string createMemberAccess(emitc::MemberOfPtrOp op);
191 
192  /// Return the existing or a new label of a Block.
193  StringRef getOrCreateName(Block &block);
194 
195  /// Whether to map an mlir integer to a unsigned integer in C++.
196  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
197 
198  /// RAII helper function to manage entering/exiting C++ scopes.
199  struct Scope {
200  Scope(CppEmitter &emitter)
201  : valueMapperScope(emitter.valueMapper),
202  blockMapperScope(emitter.blockMapper), emitter(emitter) {
203  emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
204  emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
205  }
206  ~Scope() {
207  emitter.valueInScopeCount.pop();
208  emitter.labelInScopeCount.pop();
209  }
210 
211  private:
212  llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
213  llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
214  CppEmitter &emitter;
215  };
216 
217  /// Returns wether the Value is assigned to a C++ variable in the scope.
218  bool hasValueInScope(Value val);
219 
220  // Returns whether a label is assigned to the block.
221  bool hasBlockLabel(Block &block);
222 
223  /// Returns the output stream.
224  raw_indented_ostream &ostream() { return os; };
225 
226  /// Returns if all variables for op results and basic block arguments need to
227  /// be declared at the beginning of a function.
228  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
229 
230  /// Get expression currently being emitted.
231  ExpressionOp getEmittedExpression() { return emittedExpression; }
232 
233  /// Determine whether given value is part of the expression potentially being
234  /// emitted.
235  bool isPartOfCurrentExpression(Value value) {
236  if (!emittedExpression)
237  return false;
238  Operation *def = value.getDefiningOp();
239  if (!def)
240  return false;
241  auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
242  return operandExpression == emittedExpression;
243  };
244 
245 private:
246  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
247  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
248 
249  /// Output stream to emit to.
251 
252  /// Boolean to enforce that all variables for op results and block
253  /// arguments are declared at the beginning of the function. This also
254  /// includes results from ops located in nested regions.
255  bool declareVariablesAtTop;
256 
257  /// Map from value to name of C++ variable that contain the name.
258  ValueMapper valueMapper;
259 
260  /// Map from block to name of C++ label.
261  BlockMapper blockMapper;
262 
263  /// The number of values in the current scope. This is used to declare the
264  /// names of values in a scope.
265  std::stack<int64_t> valueInScopeCount;
266  std::stack<int64_t> labelInScopeCount;
267 
268  /// State of the current expression being emitted.
269  ExpressionOp emittedExpression;
270  SmallVector<int> emittedExpressionPrecedence;
271 
272  void pushExpressionPrecedence(int precedence) {
273  emittedExpressionPrecedence.push_back(precedence);
274  }
275  void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
276  static int lowestPrecedence() { return 0; }
277  int getExpressionPrecedence() {
278  if (emittedExpressionPrecedence.empty())
279  return lowestPrecedence();
280  return emittedExpressionPrecedence.back();
281  }
282 };
283 } // namespace
284 
285 /// Determine whether expression \p op should be emitted in a deferred way.
286 static bool hasDeferredEmission(Operation *op) {
287  return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
288  emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
289 }
290 
291 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
292 /// as part of its user. This function recommends inlining of any expressions
293 /// that can be inlined unless it is used by another expression, under the
294 /// assumption that any expression fusion/re-materialization was taken care of
295 /// by transformations run by the backend.
296 static bool shouldBeInlined(ExpressionOp expressionOp) {
297  // Do not inline if expression is marked as such.
298  if (expressionOp.getDoNotInline())
299  return false;
300 
301  // Do not inline expressions with side effects to prevent side-effect
302  // reordering.
303  if (expressionOp.hasSideEffects())
304  return false;
305 
306  // Do not inline expressions with multiple uses.
307  Value result = expressionOp.getResult();
308  if (!result.hasOneUse())
309  return false;
310 
311  Operation *user = *result.getUsers().begin();
312 
313  // Do not inline expressions used by operations with deferred emission, since
314  // their translation requires the materialization of variables.
315  if (hasDeferredEmission(user))
316  return false;
317 
318  // Do not inline expressions used by ops with the CExpression trait. If this
319  // was intended, the user could have been merged into the expression op.
320  return !user->hasTrait<OpTrait::emitc::CExpression>();
321 }
322 
323 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
324  Attribute value) {
325  OpResult result = operation->getResult(0);
326 
327  // Only emit an assignment as the variable was already declared when printing
328  // the FuncOp.
329  if (emitter.shouldDeclareVariablesAtTop()) {
330  // Skip the assignment if the emitc.constant has no value.
331  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
332  if (oAttr.getValue().empty())
333  return success();
334  }
335 
336  if (failed(emitter.emitVariableAssignment(result)))
337  return failure();
338  return emitter.emitAttribute(operation->getLoc(), value);
339  }
340 
341  // Emit a variable declaration for an emitc.constant op without value.
342  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
343  if (oAttr.getValue().empty())
344  // The semicolon gets printed by the emitOperation function.
345  return emitter.emitVariableDeclaration(result,
346  /*trailingSemicolon=*/false);
347  }
348 
349  // Emit a variable declaration.
350  if (failed(emitter.emitAssignPrefix(*operation)))
351  return failure();
352  return emitter.emitAttribute(operation->getLoc(), value);
353 }
354 
355 static LogicalResult printOperation(CppEmitter &emitter,
356  emitc::ConstantOp constantOp) {
357  Operation *operation = constantOp.getOperation();
358  Attribute value = constantOp.getValue();
359 
360  return printConstantOp(emitter, operation, value);
361 }
362 
363 static LogicalResult printOperation(CppEmitter &emitter,
364  emitc::VariableOp variableOp) {
365  Operation *operation = variableOp.getOperation();
366  Attribute value = variableOp.getValue();
367 
368  return printConstantOp(emitter, operation, value);
369 }
370 
371 static LogicalResult printOperation(CppEmitter &emitter,
372  emitc::GlobalOp globalOp) {
373 
374  return emitter.emitGlobalVariable(globalOp);
375 }
376 
377 static LogicalResult printOperation(CppEmitter &emitter,
378  emitc::AssignOp assignOp) {
379  OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
380 
381  if (failed(emitter.emitVariableAssignment(result)))
382  return failure();
383 
384  return emitter.emitOperand(assignOp.getValue());
385 }
386 
387 static LogicalResult printOperation(CppEmitter &emitter, emitc::LoadOp loadOp) {
388  if (failed(emitter.emitAssignPrefix(*loadOp)))
389  return failure();
390 
391  return emitter.emitOperand(loadOp.getOperand());
392 }
393 
394 static LogicalResult printBinaryOperation(CppEmitter &emitter,
395  Operation *operation,
396  StringRef binaryOperator) {
397  raw_ostream &os = emitter.ostream();
398 
399  if (failed(emitter.emitAssignPrefix(*operation)))
400  return failure();
401 
402  if (failed(emitter.emitOperand(operation->getOperand(0))))
403  return failure();
404 
405  os << " " << binaryOperator << " ";
406 
407  if (failed(emitter.emitOperand(operation->getOperand(1))))
408  return failure();
409 
410  return success();
411 }
412 
413 static LogicalResult printUnaryOperation(CppEmitter &emitter,
414  Operation *operation,
415  StringRef unaryOperator) {
416  raw_ostream &os = emitter.ostream();
417 
418  if (failed(emitter.emitAssignPrefix(*operation)))
419  return failure();
420 
421  os << unaryOperator;
422 
423  if (failed(emitter.emitOperand(operation->getOperand(0))))
424  return failure();
425 
426  return success();
427 }
428 
429 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
430  Operation *operation = addOp.getOperation();
431 
432  return printBinaryOperation(emitter, operation, "+");
433 }
434 
435 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
436  Operation *operation = divOp.getOperation();
437 
438  return printBinaryOperation(emitter, operation, "/");
439 }
440 
441 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
442  Operation *operation = mulOp.getOperation();
443 
444  return printBinaryOperation(emitter, operation, "*");
445 }
446 
447 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
448  Operation *operation = remOp.getOperation();
449 
450  return printBinaryOperation(emitter, operation, "%");
451 }
452 
453 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
454  Operation *operation = subOp.getOperation();
455 
456  return printBinaryOperation(emitter, operation, "-");
457 }
458 
459 static LogicalResult emitSwitchCase(CppEmitter &emitter,
460  raw_indented_ostream &os, Region &region) {
461  for (Region::OpIterator iteratorOp = region.op_begin(), end = region.op_end();
462  std::next(iteratorOp) != end; ++iteratorOp) {
463  if (failed(emitter.emitOperation(*iteratorOp, /*trailingSemicolon=*/true)))
464  return failure();
465  }
466  os << "break;\n";
467  return success();
468 }
469 
470 static LogicalResult printOperation(CppEmitter &emitter,
471  emitc::SwitchOp switchOp) {
472  raw_indented_ostream &os = emitter.ostream();
473 
474  os << "\nswitch (" << emitter.getOrCreateName(switchOp.getArg()) << ") {";
475 
476  for (auto pair : llvm::zip(switchOp.getCases(), switchOp.getCaseRegions())) {
477  os << "\ncase " << std::get<0>(pair) << ": {\n";
478  os.indent();
479 
480  if (failed(emitSwitchCase(emitter, os, std::get<1>(pair))))
481  return failure();
482 
483  os.unindent() << "}";
484  }
485 
486  os << "\ndefault: {\n";
487  os.indent();
488 
489  if (failed(emitSwitchCase(emitter, os, switchOp.getDefaultRegion())))
490  return failure();
491 
492  os.unindent() << "}\n}";
493  return success();
494 }
495 
496 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
497  Operation *operation = cmpOp.getOperation();
498 
499  StringRef binaryOperator;
500 
501  switch (cmpOp.getPredicate()) {
502  case emitc::CmpPredicate::eq:
503  binaryOperator = "==";
504  break;
505  case emitc::CmpPredicate::ne:
506  binaryOperator = "!=";
507  break;
508  case emitc::CmpPredicate::lt:
509  binaryOperator = "<";
510  break;
511  case emitc::CmpPredicate::le:
512  binaryOperator = "<=";
513  break;
514  case emitc::CmpPredicate::gt:
515  binaryOperator = ">";
516  break;
517  case emitc::CmpPredicate::ge:
518  binaryOperator = ">=";
519  break;
520  case emitc::CmpPredicate::three_way:
521  binaryOperator = "<=>";
522  break;
523  }
524 
525  return printBinaryOperation(emitter, operation, binaryOperator);
526 }
527 
528 static LogicalResult printOperation(CppEmitter &emitter,
529  emitc::ConditionalOp conditionalOp) {
530  raw_ostream &os = emitter.ostream();
531 
532  if (failed(emitter.emitAssignPrefix(*conditionalOp)))
533  return failure();
534 
535  if (failed(emitter.emitOperand(conditionalOp.getCondition())))
536  return failure();
537 
538  os << " ? ";
539 
540  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
541  return failure();
542 
543  os << " : ";
544 
545  if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
546  return failure();
547 
548  return success();
549 }
550 
551 static LogicalResult printOperation(CppEmitter &emitter,
552  emitc::VerbatimOp verbatimOp) {
553  raw_ostream &os = emitter.ostream();
554 
555  os << verbatimOp.getValue();
556 
557  return success();
558 }
559 
560 static LogicalResult printOperation(CppEmitter &emitter,
561  cf::BranchOp branchOp) {
562  raw_ostream &os = emitter.ostream();
563  Block &successor = *branchOp.getSuccessor();
564 
565  for (auto pair :
566  llvm::zip(branchOp.getOperands(), successor.getArguments())) {
567  Value &operand = std::get<0>(pair);
568  BlockArgument &argument = std::get<1>(pair);
569  os << emitter.getOrCreateName(argument) << " = "
570  << emitter.getOrCreateName(operand) << ";\n";
571  }
572 
573  os << "goto ";
574  if (!(emitter.hasBlockLabel(successor)))
575  return branchOp.emitOpError("unable to find label for successor block");
576  os << emitter.getOrCreateName(successor);
577  return success();
578 }
579 
580 static LogicalResult printOperation(CppEmitter &emitter,
581  cf::CondBranchOp condBranchOp) {
582  raw_indented_ostream &os = emitter.ostream();
583  Block &trueSuccessor = *condBranchOp.getTrueDest();
584  Block &falseSuccessor = *condBranchOp.getFalseDest();
585 
586  os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
587  << ") {\n";
588 
589  os.indent();
590 
591  // If condition is true.
592  for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
593  trueSuccessor.getArguments())) {
594  Value &operand = std::get<0>(pair);
595  BlockArgument &argument = std::get<1>(pair);
596  os << emitter.getOrCreateName(argument) << " = "
597  << emitter.getOrCreateName(operand) << ";\n";
598  }
599 
600  os << "goto ";
601  if (!(emitter.hasBlockLabel(trueSuccessor))) {
602  return condBranchOp.emitOpError("unable to find label for successor block");
603  }
604  os << emitter.getOrCreateName(trueSuccessor) << ";\n";
605  os.unindent() << "} else {\n";
606  os.indent();
607  // If condition is false.
608  for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
609  falseSuccessor.getArguments())) {
610  Value &operand = std::get<0>(pair);
611  BlockArgument &argument = std::get<1>(pair);
612  os << emitter.getOrCreateName(argument) << " = "
613  << emitter.getOrCreateName(operand) << ";\n";
614  }
615 
616  os << "goto ";
617  if (!(emitter.hasBlockLabel(falseSuccessor))) {
618  return condBranchOp.emitOpError()
619  << "unable to find label for successor block";
620  }
621  os << emitter.getOrCreateName(falseSuccessor) << ";\n";
622  os.unindent() << "}";
623  return success();
624 }
625 
626 static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
627  StringRef callee) {
628  if (failed(emitter.emitAssignPrefix(*callOp)))
629  return failure();
630 
631  raw_ostream &os = emitter.ostream();
632  os << callee << "(";
633  if (failed(emitter.emitOperands(*callOp)))
634  return failure();
635  os << ")";
636  return success();
637 }
638 
639 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
640  Operation *operation = callOp.getOperation();
641  StringRef callee = callOp.getCallee();
642 
643  return printCallOperation(emitter, operation, callee);
644 }
645 
646 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
647  Operation *operation = callOp.getOperation();
648  StringRef callee = callOp.getCallee();
649 
650  return printCallOperation(emitter, operation, callee);
651 }
652 
653 static LogicalResult printOperation(CppEmitter &emitter,
654  emitc::CallOpaqueOp callOpaqueOp) {
655  raw_ostream &os = emitter.ostream();
656  Operation &op = *callOpaqueOp.getOperation();
657 
658  if (failed(emitter.emitAssignPrefix(op)))
659  return failure();
660  os << callOpaqueOp.getCallee();
661 
662  auto emitArgs = [&](Attribute attr) -> LogicalResult {
663  if (auto t = dyn_cast<IntegerAttr>(attr)) {
664  // Index attributes are treated specially as operand index.
665  if (t.getType().isIndex()) {
666  int64_t idx = t.getInt();
667  Value operand = op.getOperand(idx);
668  if (!emitter.hasValueInScope(operand))
669  return op.emitOpError("operand ")
670  << idx << "'s value not defined in scope";
671  os << emitter.getOrCreateName(operand);
672  return success();
673  }
674  }
675  if (failed(emitter.emitAttribute(op.getLoc(), attr)))
676  return failure();
677 
678  return success();
679  };
680 
681  if (callOpaqueOp.getTemplateArgs()) {
682  os << "<";
683  if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
684  emitArgs)))
685  return failure();
686  os << ">";
687  }
688 
689  os << "(";
690 
691  LogicalResult emittedArgs =
692  callOpaqueOp.getArgs()
693  ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
694  : emitter.emitOperands(op);
695  if (failed(emittedArgs))
696  return failure();
697  os << ")";
698  return success();
699 }
700 
701 static LogicalResult printOperation(CppEmitter &emitter,
702  emitc::ApplyOp applyOp) {
703  raw_ostream &os = emitter.ostream();
704  Operation &op = *applyOp.getOperation();
705 
706  if (failed(emitter.emitAssignPrefix(op)))
707  return failure();
708  os << applyOp.getApplicableOperator();
709  os << emitter.getOrCreateName(applyOp.getOperand());
710 
711  return success();
712 }
713 
714 static LogicalResult printOperation(CppEmitter &emitter,
715  emitc::BitwiseAndOp bitwiseAndOp) {
716  Operation *operation = bitwiseAndOp.getOperation();
717  return printBinaryOperation(emitter, operation, "&");
718 }
719 
720 static LogicalResult
721 printOperation(CppEmitter &emitter,
722  emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) {
723  Operation *operation = bitwiseLeftShiftOp.getOperation();
724  return printBinaryOperation(emitter, operation, "<<");
725 }
726 
727 static LogicalResult printOperation(CppEmitter &emitter,
728  emitc::BitwiseNotOp bitwiseNotOp) {
729  Operation *operation = bitwiseNotOp.getOperation();
730  return printUnaryOperation(emitter, operation, "~");
731 }
732 
733 static LogicalResult printOperation(CppEmitter &emitter,
734  emitc::BitwiseOrOp bitwiseOrOp) {
735  Operation *operation = bitwiseOrOp.getOperation();
736  return printBinaryOperation(emitter, operation, "|");
737 }
738 
739 static LogicalResult
740 printOperation(CppEmitter &emitter,
741  emitc::BitwiseRightShiftOp bitwiseRightShiftOp) {
742  Operation *operation = bitwiseRightShiftOp.getOperation();
743  return printBinaryOperation(emitter, operation, ">>");
744 }
745 
746 static LogicalResult printOperation(CppEmitter &emitter,
747  emitc::BitwiseXorOp bitwiseXorOp) {
748  Operation *operation = bitwiseXorOp.getOperation();
749  return printBinaryOperation(emitter, operation, "^");
750 }
751 
752 static LogicalResult printOperation(CppEmitter &emitter,
753  emitc::UnaryPlusOp unaryPlusOp) {
754  Operation *operation = unaryPlusOp.getOperation();
755  return printUnaryOperation(emitter, operation, "+");
756 }
757 
758 static LogicalResult printOperation(CppEmitter &emitter,
759  emitc::UnaryMinusOp unaryMinusOp) {
760  Operation *operation = unaryMinusOp.getOperation();
761  return printUnaryOperation(emitter, operation, "-");
762 }
763 
764 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
765  raw_ostream &os = emitter.ostream();
766  Operation &op = *castOp.getOperation();
767 
768  if (failed(emitter.emitAssignPrefix(op)))
769  return failure();
770  os << "(";
771  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
772  return failure();
773  os << ") ";
774  return emitter.emitOperand(castOp.getOperand());
775 }
776 
777 static LogicalResult printOperation(CppEmitter &emitter,
778  emitc::ExpressionOp expressionOp) {
779  if (shouldBeInlined(expressionOp))
780  return success();
781 
782  Operation &op = *expressionOp.getOperation();
783 
784  if (failed(emitter.emitAssignPrefix(op)))
785  return failure();
786 
787  return emitter.emitExpression(expressionOp);
788 }
789 
790 static LogicalResult printOperation(CppEmitter &emitter,
791  emitc::IncludeOp includeOp) {
792  raw_ostream &os = emitter.ostream();
793 
794  os << "#include ";
795  if (includeOp.getIsStandardInclude())
796  os << "<" << includeOp.getInclude() << ">";
797  else
798  os << "\"" << includeOp.getInclude() << "\"";
799 
800  return success();
801 }
802 
803 static LogicalResult printOperation(CppEmitter &emitter,
804  emitc::LogicalAndOp logicalAndOp) {
805  Operation *operation = logicalAndOp.getOperation();
806  return printBinaryOperation(emitter, operation, "&&");
807 }
808 
809 static LogicalResult printOperation(CppEmitter &emitter,
810  emitc::LogicalNotOp logicalNotOp) {
811  Operation *operation = logicalNotOp.getOperation();
812  return printUnaryOperation(emitter, operation, "!");
813 }
814 
815 static LogicalResult printOperation(CppEmitter &emitter,
816  emitc::LogicalOrOp logicalOrOp) {
817  Operation *operation = logicalOrOp.getOperation();
818  return printBinaryOperation(emitter, operation, "||");
819 }
820 
821 static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
822 
823  raw_indented_ostream &os = emitter.ostream();
824 
825  // Utility function to determine whether a value is an expression that will be
826  // inlined, and as such should be wrapped in parentheses in order to guarantee
827  // its precedence and associativity.
828  auto requiresParentheses = [&](Value value) {
829  auto expressionOp =
830  dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
831  if (!expressionOp)
832  return false;
833  return shouldBeInlined(expressionOp);
834  };
835 
836  os << "for (";
837  if (failed(
838  emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
839  return failure();
840  os << " ";
841  os << emitter.getOrCreateName(forOp.getInductionVar());
842  os << " = ";
843  if (failed(emitter.emitOperand(forOp.getLowerBound())))
844  return failure();
845  os << "; ";
846  os << emitter.getOrCreateName(forOp.getInductionVar());
847  os << " < ";
848  Value upperBound = forOp.getUpperBound();
849  bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
850  if (upperBoundRequiresParentheses)
851  os << "(";
852  if (failed(emitter.emitOperand(upperBound)))
853  return failure();
854  if (upperBoundRequiresParentheses)
855  os << ")";
856  os << "; ";
857  os << emitter.getOrCreateName(forOp.getInductionVar());
858  os << " += ";
859  if (failed(emitter.emitOperand(forOp.getStep())))
860  return failure();
861  os << ") {\n";
862  os.indent();
863 
864  Region &forRegion = forOp.getRegion();
865  auto regionOps = forRegion.getOps();
866 
867  // We skip the trailing yield op.
868  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
869  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
870  return failure();
871  }
872 
873  os.unindent() << "}";
874 
875  return success();
876 }
877 
878 static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
879  raw_indented_ostream &os = emitter.ostream();
880 
881  // Helper function to emit all ops except the last one, expected to be
882  // emitc::yield.
883  auto emitAllExceptLast = [&emitter](Region &region) {
884  Region::OpIterator it = region.op_begin(), end = region.op_end();
885  for (; std::next(it) != end; ++it) {
886  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
887  return failure();
888  }
889  assert(isa<emitc::YieldOp>(*it) &&
890  "Expected last operation in the region to be emitc::yield");
891  return success();
892  };
893 
894  os << "if (";
895  if (failed(emitter.emitOperand(ifOp.getCondition())))
896  return failure();
897  os << ") {\n";
898  os.indent();
899  if (failed(emitAllExceptLast(ifOp.getThenRegion())))
900  return failure();
901  os.unindent() << "}";
902 
903  Region &elseRegion = ifOp.getElseRegion();
904  if (!elseRegion.empty()) {
905  os << " else {\n";
906  os.indent();
907  if (failed(emitAllExceptLast(elseRegion)))
908  return failure();
909  os.unindent() << "}";
910  }
911 
912  return success();
913 }
914 
915 static LogicalResult printOperation(CppEmitter &emitter,
916  func::ReturnOp returnOp) {
917  raw_ostream &os = emitter.ostream();
918  os << "return";
919  switch (returnOp.getNumOperands()) {
920  case 0:
921  return success();
922  case 1:
923  os << " ";
924  if (failed(emitter.emitOperand(returnOp.getOperand(0))))
925  return failure();
926  return success();
927  default:
928  os << " std::make_tuple(";
929  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
930  return failure();
931  os << ")";
932  return success();
933  }
934 }
935 
936 static LogicalResult printOperation(CppEmitter &emitter,
937  emitc::ReturnOp returnOp) {
938  raw_ostream &os = emitter.ostream();
939  os << "return";
940  if (returnOp.getNumOperands() == 0)
941  return success();
942 
943  os << " ";
944  if (failed(emitter.emitOperand(returnOp.getOperand())))
945  return failure();
946  return success();
947 }
948 
949 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
950  CppEmitter::Scope scope(emitter);
951 
952  for (Operation &op : moduleOp) {
953  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
954  return failure();
955  }
956  return success();
957 }
958 
959 static LogicalResult printFunctionArgs(CppEmitter &emitter,
960  Operation *functionOp,
961  ArrayRef<Type> arguments) {
962  raw_indented_ostream &os = emitter.ostream();
963 
964  return (
965  interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult {
966  return emitter.emitType(functionOp->getLoc(), arg);
967  }));
968 }
969 
970 static LogicalResult printFunctionArgs(CppEmitter &emitter,
971  Operation *functionOp,
972  Region::BlockArgListType arguments) {
973  raw_indented_ostream &os = emitter.ostream();
974 
975  return (interleaveCommaWithError(
976  arguments, os, [&](BlockArgument arg) -> LogicalResult {
977  return emitter.emitVariableDeclaration(
978  functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
979  }));
980 }
981 
982 static LogicalResult printFunctionBody(CppEmitter &emitter,
983  Operation *functionOp,
984  Region::BlockListType &blocks) {
985  raw_indented_ostream &os = emitter.ostream();
986  os.indent();
987 
988  if (emitter.shouldDeclareVariablesAtTop()) {
989  // Declare all variables that hold op results including those from nested
990  // regions.
991  WalkResult result =
992  functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
993  if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
994  (isa<emitc::ExpressionOp>(op) &&
995  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
996  return WalkResult::skip();
997  for (OpResult result : op->getResults()) {
998  if (failed(emitter.emitVariableDeclaration(
999  result, /*trailingSemicolon=*/true))) {
1000  return WalkResult(
1001  op->emitError("unable to declare result variable for op"));
1002  }
1003  }
1004  return WalkResult::advance();
1005  });
1006  if (result.wasInterrupted())
1007  return failure();
1008  }
1009 
1010  // Create label names for basic blocks.
1011  for (Block &block : blocks) {
1012  emitter.getOrCreateName(block);
1013  }
1014 
1015  // Declare variables for basic block arguments.
1016  for (Block &block : llvm::drop_begin(blocks)) {
1017  for (BlockArgument &arg : block.getArguments()) {
1018  if (emitter.hasValueInScope(arg))
1019  return functionOp->emitOpError(" block argument #")
1020  << arg.getArgNumber() << " is out of scope";
1021  if (isa<ArrayType, LValueType>(arg.getType()))
1022  return functionOp->emitOpError("cannot emit block argument #")
1023  << arg.getArgNumber() << " with type " << arg.getType();
1024  if (failed(
1025  emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
1026  return failure();
1027  }
1028  os << " " << emitter.getOrCreateName(arg) << ";\n";
1029  }
1030  }
1031 
1032  for (Block &block : blocks) {
1033  // Only print a label if the block has predecessors.
1034  if (!block.hasNoPredecessors()) {
1035  if (failed(emitter.emitLabel(block)))
1036  return failure();
1037  }
1038  for (Operation &op : block.getOperations()) {
1039  // When generating code for an emitc.if or cf.cond_br op no semicolon
1040  // needs to be printed after the closing brace.
1041  // When generating code for an emitc.for and emitc.verbatim op, printing a
1042  // trailing semicolon is handled within the printOperation function.
1043  bool trailingSemicolon =
1044  !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
1045  emitc::IfOp, emitc::SwitchOp, emitc::VerbatimOp>(op);
1046 
1047  if (failed(emitter.emitOperation(
1048  op, /*trailingSemicolon=*/trailingSemicolon)))
1049  return failure();
1050  }
1051  }
1052 
1053  os.unindent();
1054 
1055  return success();
1056 }
1057 
1058 static LogicalResult printOperation(CppEmitter &emitter,
1059  func::FuncOp functionOp) {
1060  // We need to declare variables at top if the function has multiple blocks.
1061  if (!emitter.shouldDeclareVariablesAtTop() &&
1062  functionOp.getBlocks().size() > 1) {
1063  return functionOp.emitOpError(
1064  "with multiple blocks needs variables declared at top");
1065  }
1066 
1067  if (llvm::any_of(functionOp.getArgumentTypes(), llvm::IsaPred<LValueType>)) {
1068  return functionOp.emitOpError()
1069  << "cannot emit lvalue type as argument type";
1070  }
1071 
1072  if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
1073  return functionOp.emitOpError() << "cannot emit array type as result type";
1074  }
1075 
1076  CppEmitter::Scope scope(emitter);
1077  raw_indented_ostream &os = emitter.ostream();
1078  if (failed(emitter.emitTypes(functionOp.getLoc(),
1079  functionOp.getFunctionType().getResults())))
1080  return failure();
1081  os << " " << functionOp.getName();
1082 
1083  os << "(";
1084  Operation *operation = functionOp.getOperation();
1085  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1086  return failure();
1087  os << ") {\n";
1088  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1089  return failure();
1090  os << "}\n";
1091 
1092  return success();
1093 }
1094 
1095 static LogicalResult printOperation(CppEmitter &emitter,
1096  emitc::FuncOp functionOp) {
1097  // We need to declare variables at top if the function has multiple blocks.
1098  if (!emitter.shouldDeclareVariablesAtTop() &&
1099  functionOp.getBlocks().size() > 1) {
1100  return functionOp.emitOpError(
1101  "with multiple blocks needs variables declared at top");
1102  }
1103 
1104  CppEmitter::Scope scope(emitter);
1105  raw_indented_ostream &os = emitter.ostream();
1106  if (functionOp.getSpecifiers()) {
1107  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1108  os << cast<StringAttr>(specifier).str() << " ";
1109  }
1110  }
1111 
1112  if (failed(emitter.emitTypes(functionOp.getLoc(),
1113  functionOp.getFunctionType().getResults())))
1114  return failure();
1115  os << " " << functionOp.getName();
1116 
1117  os << "(";
1118  Operation *operation = functionOp.getOperation();
1119  if (functionOp.isExternal()) {
1120  if (failed(printFunctionArgs(emitter, operation,
1121  functionOp.getArgumentTypes())))
1122  return failure();
1123  os << ");";
1124  return success();
1125  }
1126  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1127  return failure();
1128  os << ") {\n";
1129  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1130  return failure();
1131  os << "}\n";
1132 
1133  return success();
1134 }
1135 
1136 static LogicalResult printOperation(CppEmitter &emitter,
1137  DeclareFuncOp declareFuncOp) {
1138  CppEmitter::Scope scope(emitter);
1139  raw_indented_ostream &os = emitter.ostream();
1140 
1141  auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
1142  declareFuncOp, declareFuncOp.getSymNameAttr());
1143 
1144  if (!functionOp)
1145  return failure();
1146 
1147  if (functionOp.getSpecifiers()) {
1148  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1149  os << cast<StringAttr>(specifier).str() << " ";
1150  }
1151  }
1152 
1153  if (failed(emitter.emitTypes(functionOp.getLoc(),
1154  functionOp.getFunctionType().getResults())))
1155  return failure();
1156  os << " " << functionOp.getName();
1157 
1158  os << "(";
1159  Operation *operation = functionOp.getOperation();
1160  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1161  return failure();
1162  os << ");";
1163 
1164  return success();
1165 }
1166 
1167 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
1168  : os(os), declareVariablesAtTop(declareVariablesAtTop) {
1169  valueInScopeCount.push(0);
1170  labelInScopeCount.push(0);
1171 }
1172 
1173 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1174  std::string out;
1175  llvm::raw_string_ostream ss(out);
1176  ss << getOrCreateName(op.getValue());
1177  for (auto index : op.getIndices()) {
1178  ss << "[" << getOrCreateName(index) << "]";
1179  }
1180  return out;
1181 }
1182 
1183 std::string CppEmitter::createMemberAccess(emitc::MemberOp op) {
1184  std::string out;
1185  llvm::raw_string_ostream ss(out);
1186  ss << getOrCreateName(op.getOperand());
1187  ss << "." << op.getMember();
1188  return out;
1189 }
1190 
1191 std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) {
1192  std::string out;
1193  llvm::raw_string_ostream ss(out);
1194  ss << getOrCreateName(op.getOperand());
1195  ss << "->" << op.getMember();
1196  return out;
1197 }
1198 
1199 void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
1200  if (!valueMapper.count(value))
1201  valueMapper.insert(value, str.str());
1202 }
1203 
1204 /// Return the existing or a new name for a Value.
1205 StringRef CppEmitter::getOrCreateName(Value val) {
1206  if (!valueMapper.count(val)) {
1207  assert(!hasDeferredEmission(val.getDefiningOp()) &&
1208  "cacheDeferredOpResult should have been called on this value, "
1209  "update the emitOperation function.");
1210  valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1211  }
1212  return *valueMapper.begin(val);
1213 }
1214 
1215 /// Return the existing or a new label for a Block.
1216 StringRef CppEmitter::getOrCreateName(Block &block) {
1217  if (!blockMapper.count(&block))
1218  blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
1219  return *blockMapper.begin(&block);
1220 }
1221 
1222 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
1223  switch (val) {
1224  case IntegerType::Signless:
1225  return false;
1226  case IntegerType::Signed:
1227  return false;
1228  case IntegerType::Unsigned:
1229  return true;
1230  }
1231  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
1232 }
1233 
1234 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
1235 
1236 bool CppEmitter::hasBlockLabel(Block &block) {
1237  return blockMapper.count(&block);
1238 }
1239 
1240 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
1241  auto printInt = [&](const APInt &val, bool isUnsigned) {
1242  if (val.getBitWidth() == 1) {
1243  if (val.getBoolValue())
1244  os << "true";
1245  else
1246  os << "false";
1247  } else {
1248  SmallString<128> strValue;
1249  val.toString(strValue, 10, !isUnsigned, false);
1250  os << strValue;
1251  }
1252  };
1253 
1254  auto printFloat = [&](const APFloat &val) {
1255  if (val.isFinite()) {
1256  SmallString<128> strValue;
1257  // Use default values of toString except don't truncate zeros.
1258  val.toString(strValue, 0, 0, false);
1259  os << strValue;
1260  switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
1261  case llvm::APFloatBase::S_IEEEhalf:
1262  os << "f16";
1263  break;
1264  case llvm::APFloatBase::S_BFloat:
1265  os << "bf16";
1266  break;
1267  case llvm::APFloatBase::S_IEEEsingle:
1268  os << "f";
1269  break;
1270  case llvm::APFloatBase::S_IEEEdouble:
1271  break;
1272  default:
1273  llvm_unreachable("unsupported floating point type");
1274  };
1275  } else if (val.isNaN()) {
1276  os << "NAN";
1277  } else if (val.isInfinity()) {
1278  if (val.isNegative())
1279  os << "-";
1280  os << "INFINITY";
1281  }
1282  };
1283 
1284  // Print floating point attributes.
1285  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1286  if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1287  fAttr.getType())) {
1288  return emitError(
1289  loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1290  }
1291  printFloat(fAttr.getValue());
1292  return success();
1293  }
1294  if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1295  if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
1296  dense.getElementType())) {
1297  return emitError(
1298  loc, "expected floating point attribute to be f16, bf16, f32 or f64");
1299  }
1300  os << '{';
1301  interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
1302  os << '}';
1303  return success();
1304  }
1305 
1306  // Print integer attributes.
1307  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
1308  if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
1309  printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
1310  return success();
1311  }
1312  if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
1313  printInt(iAttr.getValue(), false);
1314  return success();
1315  }
1316  }
1317  if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1318  if (auto iType = dyn_cast<IntegerType>(
1319  cast<TensorType>(dense.getType()).getElementType())) {
1320  os << '{';
1321  interleaveComma(dense, os, [&](const APInt &val) {
1322  printInt(val, shouldMapToUnsigned(iType.getSignedness()));
1323  });
1324  os << '}';
1325  return success();
1326  }
1327  if (auto iType = dyn_cast<IndexType>(
1328  cast<TensorType>(dense.getType()).getElementType())) {
1329  os << '{';
1330  interleaveComma(dense, os,
1331  [&](const APInt &val) { printInt(val, false); });
1332  os << '}';
1333  return success();
1334  }
1335  }
1336 
1337  // Print opaque attributes.
1338  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
1339  os << oAttr.getValue();
1340  return success();
1341  }
1342 
1343  // Print symbolic reference attributes.
1344  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
1345  if (sAttr.getNestedReferences().size() > 1)
1346  return emitError(loc, "attribute has more than 1 nested reference");
1347  os << sAttr.getRootReference().getValue();
1348  return success();
1349  }
1350 
1351  // Print type attributes.
1352  if (auto type = dyn_cast<TypeAttr>(attr))
1353  return emitType(loc, type.getValue());
1354 
1355  return emitError(loc, "cannot emit attribute: ") << attr;
1356 }
1357 
1358 LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
1359  assert(emittedExpressionPrecedence.empty() &&
1360  "Expected precedence stack to be empty");
1361  Operation *rootOp = expressionOp.getRootOp();
1362 
1363  emittedExpression = expressionOp;
1364  FailureOr<int> precedence = getOperatorPrecedence(rootOp);
1365  if (failed(precedence))
1366  return failure();
1367  pushExpressionPrecedence(precedence.value());
1368 
1369  if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false)))
1370  return failure();
1371 
1372  popExpressionPrecedence();
1373  assert(emittedExpressionPrecedence.empty() &&
1374  "Expected precedence stack to be empty");
1375  emittedExpression = nullptr;
1376 
1377  return success();
1378 }
1379 
1380 LogicalResult CppEmitter::emitOperand(Value value) {
1381  if (isPartOfCurrentExpression(value)) {
1382  Operation *def = value.getDefiningOp();
1383  assert(def && "Expected operand to be defined by an operation");
1384  FailureOr<int> precedence = getOperatorPrecedence(def);
1385  if (failed(precedence))
1386  return failure();
1387 
1388  // Sub-expressions with equal or lower precedence need to be parenthesized,
1389  // as they might be evaluated in the wrong order depending on the shape of
1390  // the expression tree.
1391  bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence();
1392  if (encloseInParenthesis) {
1393  os << "(";
1394  pushExpressionPrecedence(lowestPrecedence());
1395  } else
1396  pushExpressionPrecedence(precedence.value());
1397 
1398  if (failed(emitOperation(*def, /*trailingSemicolon=*/false)))
1399  return failure();
1400 
1401  if (encloseInParenthesis)
1402  os << ")";
1403 
1404  popExpressionPrecedence();
1405  return success();
1406  }
1407 
1408  auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
1409  if (expressionOp && shouldBeInlined(expressionOp))
1410  return emitExpression(expressionOp);
1411 
1412  os << getOrCreateName(value);
1413  return success();
1414 }
1415 
1416 LogicalResult CppEmitter::emitOperands(Operation &op) {
1417  return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) {
1418  // If an expression is being emitted, push lowest precedence as these
1419  // operands are either wrapped by parenthesis.
1420  if (getEmittedExpression())
1421  pushExpressionPrecedence(lowestPrecedence());
1422  if (failed(emitOperand(operand)))
1423  return failure();
1424  if (getEmittedExpression())
1425  popExpressionPrecedence();
1426  return success();
1427  });
1428 }
1429 
1430 LogicalResult
1431 CppEmitter::emitOperandsAndAttributes(Operation &op,
1432  ArrayRef<StringRef> exclude) {
1433  if (failed(emitOperands(op)))
1434  return failure();
1435  // Insert comma in between operands and non-filtered attributes if needed.
1436  if (op.getNumOperands() > 0) {
1437  for (NamedAttribute attr : op.getAttrs()) {
1438  if (!llvm::is_contained(exclude, attr.getName().strref())) {
1439  os << ", ";
1440  break;
1441  }
1442  }
1443  }
1444  // Emit attributes.
1445  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
1446  if (llvm::is_contained(exclude, attr.getName().strref()))
1447  return success();
1448  os << "/* " << attr.getName().getValue() << " */";
1449  if (failed(emitAttribute(op.getLoc(), attr.getValue())))
1450  return failure();
1451  return success();
1452  };
1453  return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
1454 }
1455 
1456 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
1457  if (!hasValueInScope(result)) {
1458  return result.getDefiningOp()->emitOpError(
1459  "result variable for the operation has not been declared");
1460  }
1461  os << getOrCreateName(result) << " = ";
1462  return success();
1463 }
1464 
1465 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
1466  bool trailingSemicolon) {
1467  if (hasDeferredEmission(result.getDefiningOp()))
1468  return success();
1469  if (hasValueInScope(result)) {
1470  return result.getDefiningOp()->emitError(
1471  "result variable for the operation already declared");
1472  }
1473  if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
1474  result.getType(),
1475  getOrCreateName(result))))
1476  return failure();
1477  if (trailingSemicolon)
1478  os << ";\n";
1479  return success();
1480 }
1481 
1482 LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
1483  if (op.getExternSpecifier())
1484  os << "extern ";
1485  else if (op.getStaticSpecifier())
1486  os << "static ";
1487  if (op.getConstSpecifier())
1488  os << "const ";
1489 
1490  if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
1491  op.getSymName()))) {
1492  return failure();
1493  }
1494 
1495  std::optional<Attribute> initialValue = op.getInitialValue();
1496  if (initialValue) {
1497  os << " = ";
1498  if (failed(emitAttribute(op->getLoc(), *initialValue)))
1499  return failure();
1500  }
1501 
1502  os << ";";
1503  return success();
1504 }
1505 
1506 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
1507  // If op is being emitted as part of an expression, bail out.
1508  if (getEmittedExpression())
1509  return success();
1510 
1511  switch (op.getNumResults()) {
1512  case 0:
1513  break;
1514  case 1: {
1515  OpResult result = op.getResult(0);
1516  if (shouldDeclareVariablesAtTop()) {
1517  if (failed(emitVariableAssignment(result)))
1518  return failure();
1519  } else {
1520  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
1521  return failure();
1522  os << " = ";
1523  }
1524  break;
1525  }
1526  default:
1527  if (!shouldDeclareVariablesAtTop()) {
1528  for (OpResult result : op.getResults()) {
1529  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
1530  return failure();
1531  }
1532  }
1533  os << "std::tie(";
1534  interleaveComma(op.getResults(), os,
1535  [&](Value result) { os << getOrCreateName(result); });
1536  os << ") = ";
1537  }
1538  return success();
1539 }
1540 
1541 LogicalResult CppEmitter::emitLabel(Block &block) {
1542  if (!hasBlockLabel(block))
1543  return block.getParentOp()->emitError("label for block not found");
1544  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
1545  // label instead of using `getOStream`.
1546  os.getOStream() << getOrCreateName(block) << ":\n";
1547  return success();
1548 }
1549 
1550 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1551  LogicalResult status =
1553  // Builtin ops.
1554  .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
1555  // CF ops.
1556  .Case<cf::BranchOp, cf::CondBranchOp>(
1557  [&](auto op) { return printOperation(*this, op); })
1558  // EmitC ops.
1559  .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
1560  emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
1561  emitc::BitwiseNotOp, emitc::BitwiseOrOp,
1562  emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
1563  emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1564  emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1565  emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1566  emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
1567  emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
1568  emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1569  emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
1570  emitc::VariableOp, emitc::VerbatimOp>(
1571  [&](auto op) { return printOperation(*this, op); })
1572  // Func ops.
1573  .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
1574  [&](auto op) { return printOperation(*this, op); })
1575  .Case<emitc::GetGlobalOp>([&](auto op) {
1576  cacheDeferredOpResult(op.getResult(), op.getName());
1577  return success();
1578  })
1579  .Case<emitc::LiteralOp>([&](auto op) {
1580  cacheDeferredOpResult(op.getResult(), op.getValue());
1581  return success();
1582  })
1583  .Case<emitc::MemberOp>([&](auto op) {
1584  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1585  return success();
1586  })
1587  .Case<emitc::MemberOfPtrOp>([&](auto op) {
1588  cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
1589  return success();
1590  })
1591  .Case<emitc::SubscriptOp>([&](auto op) {
1592  cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
1593  return success();
1594  })
1595  .Default([&](Operation *) {
1596  return op.emitOpError("unable to find printer for op");
1597  });
1598 
1599  if (failed(status))
1600  return failure();
1601 
1602  if (hasDeferredEmission(&op))
1603  return success();
1604 
1605  if (getEmittedExpression() ||
1606  (isa<emitc::ExpressionOp>(op) &&
1607  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1608  return success();
1609 
1610  os << (trailingSemicolon ? ";\n" : "\n");
1611 
1612  return success();
1613 }
1614 
1615 LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1616  StringRef name) {
1617  if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1618  if (failed(emitType(loc, arrType.getElementType())))
1619  return failure();
1620  os << " " << name;
1621  for (auto dim : arrType.getShape()) {
1622  os << "[" << dim << "]";
1623  }
1624  return success();
1625  }
1626  if (failed(emitType(loc, type)))
1627  return failure();
1628  os << " " << name;
1629  return success();
1630 }
1631 
1632 LogicalResult CppEmitter::emitType(Location loc, Type type) {
1633  if (auto iType = dyn_cast<IntegerType>(type)) {
1634  switch (iType.getWidth()) {
1635  case 1:
1636  return (os << "bool"), success();
1637  case 8:
1638  case 16:
1639  case 32:
1640  case 64:
1641  if (shouldMapToUnsigned(iType.getSignedness()))
1642  return (os << "uint" << iType.getWidth() << "_t"), success();
1643  else
1644  return (os << "int" << iType.getWidth() << "_t"), success();
1645  default:
1646  return emitError(loc, "cannot emit integer type ") << type;
1647  }
1648  }
1649  if (auto fType = dyn_cast<FloatType>(type)) {
1650  switch (fType.getWidth()) {
1651  case 16: {
1652  if (llvm::isa<Float16Type>(type))
1653  return (os << "_Float16"), success();
1654  else if (llvm::isa<BFloat16Type>(type))
1655  return (os << "__bf16"), success();
1656  else
1657  return emitError(loc, "cannot emit float type ") << type;
1658  }
1659  case 32:
1660  return (os << "float"), success();
1661  case 64:
1662  return (os << "double"), success();
1663  default:
1664  return emitError(loc, "cannot emit float type ") << type;
1665  }
1666  }
1667  if (auto iType = dyn_cast<IndexType>(type))
1668  return (os << "size_t"), success();
1669  if (auto sType = dyn_cast<emitc::SizeTType>(type))
1670  return (os << "size_t"), success();
1671  if (auto sType = dyn_cast<emitc::SignedSizeTType>(type))
1672  return (os << "ssize_t"), success();
1673  if (auto pType = dyn_cast<emitc::PtrDiffTType>(type))
1674  return (os << "ptrdiff_t"), success();
1675  if (auto tType = dyn_cast<TensorType>(type)) {
1676  if (!tType.hasRank())
1677  return emitError(loc, "cannot emit unranked tensor type");
1678  if (!tType.hasStaticShape())
1679  return emitError(loc, "cannot emit tensor type with non static shape");
1680  os << "Tensor<";
1681  if (isa<ArrayType>(tType.getElementType()))
1682  return emitError(loc, "cannot emit tensor of array type ") << type;
1683  if (failed(emitType(loc, tType.getElementType())))
1684  return failure();
1685  auto shape = tType.getShape();
1686  for (auto dimSize : shape) {
1687  os << ", ";
1688  os << dimSize;
1689  }
1690  os << ">";
1691  return success();
1692  }
1693  if (auto tType = dyn_cast<TupleType>(type))
1694  return emitTupleType(loc, tType.getTypes());
1695  if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1696  os << oType.getValue();
1697  return success();
1698  }
1699  if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
1700  if (failed(emitType(loc, aType.getElementType())))
1701  return failure();
1702  for (auto dim : aType.getShape())
1703  os << "[" << dim << "]";
1704  return success();
1705  }
1706  if (auto lType = dyn_cast<emitc::LValueType>(type))
1707  return emitType(loc, lType.getValueType());
1708  if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1709  if (isa<ArrayType>(pType.getPointee()))
1710  return emitError(loc, "cannot emit pointer to array type ") << type;
1711  if (failed(emitType(loc, pType.getPointee())))
1712  return failure();
1713  os << "*";
1714  return success();
1715  }
1716  return emitError(loc, "cannot emit type ") << type;
1717 }
1718 
1719 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1720  switch (types.size()) {
1721  case 0:
1722  os << "void";
1723  return success();
1724  case 1:
1725  return emitType(loc, types.front());
1726  default:
1727  return emitTupleType(loc, types);
1728  }
1729 }
1730 
1731 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1732  if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
1733  return emitError(loc, "cannot emit tuple of array type");
1734  }
1735  os << "std::tuple<";
1736  if (failed(interleaveCommaWithError(
1737  types, os, [&](Type type) { return emitType(loc, type); })))
1738  return failure();
1739  os << ">";
1740  return success();
1741 }
1742 
1743 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1744  bool declareVariablesAtTop) {
1745  CppEmitter emitter(os, declareVariablesAtTop);
1746  return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1747 }
static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, StringRef callee)
static bool shouldBeInlined(ExpressionOp expressionOp)
Determine whether expression expressionOp should be emitted inline, i.e.
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
static FailureOr< int > getOperatorPrecedence(Operation *operation)
Return the precedence of an operator as an integer; higher values imply higher precedence.
static LogicalResult printFunctionArgs(CppEmitter &emitter, Operation *functionOp, ArrayRef< Type > arguments)
static LogicalResult printFunctionBody(CppEmitter &emitter, Operation *functionOp, Region::BlockListType &blocks)
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
static LogicalResult emitSwitchCase(CppEmitter &emitter, raw_indented_ostream &os, Region &region)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator)
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
static LogicalResult printUnaryOperation(CppEmitter &emitter, Operation *operation, StringRef unaryOperator)
static bool hasDeferredEmission(Operation *op)
Determine whether expression op should be emitted in a deferred way.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgListType getArguments()
Definition: Block.h:87
Block * getSuccessor(unsigned i)
Definition: Block.cpp:261
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class provides iteration over the held operations of blocks directly within a region.
Definition: Region.h:134
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
llvm::iplist< Block > BlockListType
Definition: Region.h:44
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition: Region.h:170
bool empty()
Definition: Region.h:60
OpIterator op_end()
Definition: Region.h:171
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & unindent()
Decreases the indent and returns this raw_indented_ostream.
raw_indented_ostream & indent()
Increases the indent and returns this raw_indented_ostream.
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false)
Translates the given operation to C++ code.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This iterator enumerates the elements in "forward" order.
Definition: Visitors.h:65