MLIR  19.0.0git
TranslateToCpp.cpp
Go to the documentation of this file.
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
18 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/ScopedHashTable.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <stack>
28 #include <utility>
29 
30 #define DEBUG_TYPE "translate-to-cpp"
31 
32 using namespace mlir;
33 using namespace mlir::emitc;
34 using llvm::formatv;
35 
36 /// Convenience functions to produce interleaved output with functions returning
37 /// a LogicalResult. This is different than those in STLExtras as functions used
38 /// on each element doesn't return a string.
39 template <typename ForwardIterator, typename UnaryFunctor,
40  typename NullaryFunctor>
41 inline LogicalResult
43  UnaryFunctor eachFn, NullaryFunctor betweenFn) {
44  if (begin == end)
45  return success();
46  if (failed(eachFn(*begin)))
47  return failure();
48  ++begin;
49  for (; begin != end; ++begin) {
50  betweenFn();
51  if (failed(eachFn(*begin)))
52  return failure();
53  }
54  return success();
55 }
56 
57 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
58 inline LogicalResult interleaveWithError(const Container &c,
59  UnaryFunctor eachFn,
60  NullaryFunctor betweenFn) {
61  return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
62 }
63 
64 template <typename Container, typename UnaryFunctor>
65 inline LogicalResult interleaveCommaWithError(const Container &c,
66  raw_ostream &os,
67  UnaryFunctor eachFn) {
68  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
69 }
70 
71 /// Return the precedence of a operator as an integer, higher values
72 /// imply higher precedence.
75  .Case<emitc::AddOp>([&](auto op) { return 12; })
76  .Case<emitc::ApplyOp>([&](auto op) { return 15; })
77  .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
78  .Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; })
79  .Case<emitc::BitwiseNotOp>([&](auto op) { return 15; })
80  .Case<emitc::BitwiseOrOp>([&](auto op) { return 5; })
81  .Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; })
82  .Case<emitc::BitwiseXorOp>([&](auto op) { return 6; })
83  .Case<emitc::CallOp>([&](auto op) { return 16; })
84  .Case<emitc::CallOpaqueOp>([&](auto op) { return 16; })
85  .Case<emitc::CastOp>([&](auto op) { return 15; })
86  .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
87  switch (op.getPredicate()) {
88  case emitc::CmpPredicate::eq:
89  case emitc::CmpPredicate::ne:
90  return 8;
91  case emitc::CmpPredicate::lt:
92  case emitc::CmpPredicate::le:
93  case emitc::CmpPredicate::gt:
94  case emitc::CmpPredicate::ge:
95  return 9;
96  case emitc::CmpPredicate::three_way:
97  return 10;
98  }
99  return op->emitError("unsupported cmp predicate");
100  })
101  .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
102  .Case<emitc::DivOp>([&](auto op) { return 13; })
103  .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
104  .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
105  .Case<emitc::LogicalOrOp>([&](auto op) { return 3; })
106  .Case<emitc::MulOp>([&](auto op) { return 13; })
107  .Case<emitc::RemOp>([&](auto op) { return 13; })
108  .Case<emitc::SubOp>([&](auto op) { return 12; })
109  .Case<emitc::UnaryMinusOp>([&](auto op) { return 15; })
110  .Case<emitc::UnaryPlusOp>([&](auto op) { return 15; })
111  .Default([](auto op) { return op->emitError("unsupported operation"); });
112 }
113 
114 namespace {
115 /// Emitter that uses dialect specific emitters to emit C++ code.
116 struct CppEmitter {
117  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
118 
119  /// Emits attribute or returns failure.
120  LogicalResult emitAttribute(Location loc, Attribute attr);
121 
122  /// Emits operation 'op' with/without training semicolon or returns failure.
123  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
124 
125  /// Emits type 'type' or returns failure.
126  LogicalResult emitType(Location loc, Type type);
127 
128  /// Emits array of types as a std::tuple of the emitted types.
129  /// - emits void for an empty array;
130  /// - emits the type of the only element for arrays of size one;
131  /// - emits a std::tuple otherwise;
132  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
133 
134  /// Emits array of types as a std::tuple of the emitted types independently of
135  /// the array size.
136  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
137 
138  /// Emits an assignment for a variable which has been declared previously.
139  LogicalResult emitVariableAssignment(OpResult result);
140 
141  /// Emits a variable declaration for a result of an operation.
142  LogicalResult emitVariableDeclaration(OpResult result,
143  bool trailingSemicolon);
144 
145  /// Emits a declaration of a variable with the given type and name.
146  LogicalResult emitVariableDeclaration(Location loc, Type type,
147  StringRef name);
148 
149  /// Emits the variable declaration and assignment prefix for 'op'.
150  /// - emits separate variable followed by std::tie for multi-valued operation;
151  /// - emits single type followed by variable for single result;
152  /// - emits nothing if no value produced by op;
153  /// Emits final '=' operator where a type is produced. Returns failure if
154  /// any result type could not be converted.
155  LogicalResult emitAssignPrefix(Operation &op);
156 
157  /// Emits a label for the block.
158  LogicalResult emitLabel(Block &block);
159 
160  /// Emits the operands and atttributes of the operation. All operands are
161  /// emitted first and then all attributes in alphabetical order.
162  LogicalResult emitOperandsAndAttributes(Operation &op,
163  ArrayRef<StringRef> exclude = {});
164 
165  /// Emits the operands of the operation. All operands are emitted in order.
166  LogicalResult emitOperands(Operation &op);
167 
168  /// Emits value as an operands of an operation
169  LogicalResult emitOperand(Value value);
170 
171  /// Emit an expression as a C expression.
172  LogicalResult emitExpression(ExpressionOp expressionOp);
173 
174  /// Return the existing or a new name for a Value.
175  StringRef getOrCreateName(Value val);
176 
177  // Returns the textual representation of a subscript operation.
178  std::string getSubscriptName(emitc::SubscriptOp op);
179 
180  /// Return the existing or a new label of a Block.
181  StringRef getOrCreateName(Block &block);
182 
183  /// Whether to map an mlir integer to a unsigned integer in C++.
184  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
185 
186  /// RAII helper function to manage entering/exiting C++ scopes.
187  struct Scope {
188  Scope(CppEmitter &emitter)
189  : valueMapperScope(emitter.valueMapper),
190  blockMapperScope(emitter.blockMapper), emitter(emitter) {
191  emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
192  emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
193  }
194  ~Scope() {
195  emitter.valueInScopeCount.pop();
196  emitter.labelInScopeCount.pop();
197  }
198 
199  private:
200  llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
201  llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
202  CppEmitter &emitter;
203  };
204 
205  /// Returns wether the Value is assigned to a C++ variable in the scope.
206  bool hasValueInScope(Value val);
207 
208  // Returns whether a label is assigned to the block.
209  bool hasBlockLabel(Block &block);
210 
211  /// Returns the output stream.
212  raw_indented_ostream &ostream() { return os; };
213 
214  /// Returns if all variables for op results and basic block arguments need to
215  /// be declared at the beginning of a function.
216  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
217 
218  /// Get expression currently being emitted.
219  ExpressionOp getEmittedExpression() { return emittedExpression; }
220 
221  /// Determine whether given value is part of the expression potentially being
222  /// emitted.
223  bool isPartOfCurrentExpression(Value value) {
224  if (!emittedExpression)
225  return false;
226  Operation *def = value.getDefiningOp();
227  if (!def)
228  return false;
229  auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
230  return operandExpression == emittedExpression;
231  };
232 
233 private:
234  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
235  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
236 
237  /// Output stream to emit to.
239 
240  /// Boolean to enforce that all variables for op results and block
241  /// arguments are declared at the beginning of the function. This also
242  /// includes results from ops located in nested regions.
243  bool declareVariablesAtTop;
244 
245  /// Map from value to name of C++ variable that contain the name.
246  ValueMapper valueMapper;
247 
248  /// Map from block to name of C++ label.
249  BlockMapper blockMapper;
250 
251  /// The number of values in the current scope. This is used to declare the
252  /// names of values in a scope.
253  std::stack<int64_t> valueInScopeCount;
254  std::stack<int64_t> labelInScopeCount;
255 
256  /// State of the current expression being emitted.
257  ExpressionOp emittedExpression;
258  SmallVector<int> emittedExpressionPrecedence;
259 
260  void pushExpressionPrecedence(int precedence) {
261  emittedExpressionPrecedence.push_back(precedence);
262  }
263  void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
264  static int lowestPrecedence() { return 0; }
265  int getExpressionPrecedence() {
266  if (emittedExpressionPrecedence.empty())
267  return lowestPrecedence();
268  return emittedExpressionPrecedence.back();
269  }
270 };
271 } // namespace
272 
273 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
274 /// as part of its user. This function recommends inlining of any expressions
275 /// that can be inlined unless it is used by another expression, under the
276 /// assumption that any expression fusion/re-materialization was taken care of
277 /// by transformations run by the backend.
278 static bool shouldBeInlined(ExpressionOp expressionOp) {
279  // Do not inline if expression is marked as such.
280  if (expressionOp.getDoNotInline())
281  return false;
282 
283  // Do not inline expressions with side effects to prevent side-effect
284  // reordering.
285  if (expressionOp.hasSideEffects())
286  return false;
287 
288  // Do not inline expressions with multiple uses.
289  Value result = expressionOp.getResult();
290  if (!result.hasOneUse())
291  return false;
292 
293  // Do not inline expressions used by other expressions, as any desired
294  // expression folding was taken care of by transformations.
295  Operation *user = *result.getUsers().begin();
296  return !user->getParentOfType<ExpressionOp>();
297 }
298 
299 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
300  Attribute value) {
301  OpResult result = operation->getResult(0);
302 
303  // Only emit an assignment as the variable was already declared when printing
304  // the FuncOp.
305  if (emitter.shouldDeclareVariablesAtTop()) {
306  // Skip the assignment if the emitc.constant has no value.
307  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
308  if (oAttr.getValue().empty())
309  return success();
310  }
311 
312  if (failed(emitter.emitVariableAssignment(result)))
313  return failure();
314  return emitter.emitAttribute(operation->getLoc(), value);
315  }
316 
317  // Emit a variable declaration for an emitc.constant op without value.
318  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
319  if (oAttr.getValue().empty())
320  // The semicolon gets printed by the emitOperation function.
321  return emitter.emitVariableDeclaration(result,
322  /*trailingSemicolon=*/false);
323  }
324 
325  // Emit a variable declaration.
326  if (failed(emitter.emitAssignPrefix(*operation)))
327  return failure();
328  return emitter.emitAttribute(operation->getLoc(), value);
329 }
330 
331 static LogicalResult printOperation(CppEmitter &emitter,
332  emitc::ConstantOp constantOp) {
333  Operation *operation = constantOp.getOperation();
334  Attribute value = constantOp.getValue();
335 
336  return printConstantOp(emitter, operation, value);
337 }
338 
339 static LogicalResult printOperation(CppEmitter &emitter,
340  emitc::VariableOp variableOp) {
341  Operation *operation = variableOp.getOperation();
342  Attribute value = variableOp.getValue();
343 
344  return printConstantOp(emitter, operation, value);
345 }
346 
347 static LogicalResult printOperation(CppEmitter &emitter,
348  emitc::AssignOp assignOp) {
349  OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
350 
351  if (failed(emitter.emitVariableAssignment(result)))
352  return failure();
353 
354  return emitter.emitOperand(assignOp.getValue());
355 }
356 
357 static LogicalResult printOperation(CppEmitter &emitter,
358  emitc::SubscriptOp subscriptOp) {
359  // Add name to cache so that `hasValueInScope` works.
360  emitter.getOrCreateName(subscriptOp.getResult());
361  return success();
362 }
363 
364 static LogicalResult printBinaryOperation(CppEmitter &emitter,
365  Operation *operation,
366  StringRef binaryOperator) {
367  raw_ostream &os = emitter.ostream();
368 
369  if (failed(emitter.emitAssignPrefix(*operation)))
370  return failure();
371 
372  if (failed(emitter.emitOperand(operation->getOperand(0))))
373  return failure();
374 
375  os << " " << binaryOperator << " ";
376 
377  if (failed(emitter.emitOperand(operation->getOperand(1))))
378  return failure();
379 
380  return success();
381 }
382 
383 static LogicalResult printUnaryOperation(CppEmitter &emitter,
384  Operation *operation,
385  StringRef unaryOperator) {
386  raw_ostream &os = emitter.ostream();
387 
388  if (failed(emitter.emitAssignPrefix(*operation)))
389  return failure();
390 
391  os << unaryOperator;
392 
393  if (failed(emitter.emitOperand(operation->getOperand(0))))
394  return failure();
395 
396  return success();
397 }
398 
399 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
400  Operation *operation = addOp.getOperation();
401 
402  return printBinaryOperation(emitter, operation, "+");
403 }
404 
405 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
406  Operation *operation = divOp.getOperation();
407 
408  return printBinaryOperation(emitter, operation, "/");
409 }
410 
411 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
412  Operation *operation = mulOp.getOperation();
413 
414  return printBinaryOperation(emitter, operation, "*");
415 }
416 
417 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
418  Operation *operation = remOp.getOperation();
419 
420  return printBinaryOperation(emitter, operation, "%");
421 }
422 
423 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
424  Operation *operation = subOp.getOperation();
425 
426  return printBinaryOperation(emitter, operation, "-");
427 }
428 
429 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
430  Operation *operation = cmpOp.getOperation();
431 
432  StringRef binaryOperator;
433 
434  switch (cmpOp.getPredicate()) {
435  case emitc::CmpPredicate::eq:
436  binaryOperator = "==";
437  break;
438  case emitc::CmpPredicate::ne:
439  binaryOperator = "!=";
440  break;
441  case emitc::CmpPredicate::lt:
442  binaryOperator = "<";
443  break;
444  case emitc::CmpPredicate::le:
445  binaryOperator = "<=";
446  break;
447  case emitc::CmpPredicate::gt:
448  binaryOperator = ">";
449  break;
450  case emitc::CmpPredicate::ge:
451  binaryOperator = ">=";
452  break;
453  case emitc::CmpPredicate::three_way:
454  binaryOperator = "<=>";
455  break;
456  }
457 
458  return printBinaryOperation(emitter, operation, binaryOperator);
459 }
460 
461 static LogicalResult printOperation(CppEmitter &emitter,
462  emitc::ConditionalOp conditionalOp) {
463  raw_ostream &os = emitter.ostream();
464 
465  if (failed(emitter.emitAssignPrefix(*conditionalOp)))
466  return failure();
467 
468  if (failed(emitter.emitOperand(conditionalOp.getCondition())))
469  return failure();
470 
471  os << " ? ";
472 
473  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
474  return failure();
475 
476  os << " : ";
477 
478  if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
479  return failure();
480 
481  return success();
482 }
483 
484 static LogicalResult printOperation(CppEmitter &emitter,
485  emitc::VerbatimOp verbatimOp) {
486  raw_ostream &os = emitter.ostream();
487 
488  os << verbatimOp.getValue();
489 
490  return success();
491 }
492 
493 static LogicalResult printOperation(CppEmitter &emitter,
494  cf::BranchOp branchOp) {
495  raw_ostream &os = emitter.ostream();
496  Block &successor = *branchOp.getSuccessor();
497 
498  for (auto pair :
499  llvm::zip(branchOp.getOperands(), successor.getArguments())) {
500  Value &operand = std::get<0>(pair);
501  BlockArgument &argument = std::get<1>(pair);
502  os << emitter.getOrCreateName(argument) << " = "
503  << emitter.getOrCreateName(operand) << ";\n";
504  }
505 
506  os << "goto ";
507  if (!(emitter.hasBlockLabel(successor)))
508  return branchOp.emitOpError("unable to find label for successor block");
509  os << emitter.getOrCreateName(successor);
510  return success();
511 }
512 
513 static LogicalResult printOperation(CppEmitter &emitter,
514  cf::CondBranchOp condBranchOp) {
515  raw_indented_ostream &os = emitter.ostream();
516  Block &trueSuccessor = *condBranchOp.getTrueDest();
517  Block &falseSuccessor = *condBranchOp.getFalseDest();
518 
519  os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
520  << ") {\n";
521 
522  os.indent();
523 
524  // If condition is true.
525  for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
526  trueSuccessor.getArguments())) {
527  Value &operand = std::get<0>(pair);
528  BlockArgument &argument = std::get<1>(pair);
529  os << emitter.getOrCreateName(argument) << " = "
530  << emitter.getOrCreateName(operand) << ";\n";
531  }
532 
533  os << "goto ";
534  if (!(emitter.hasBlockLabel(trueSuccessor))) {
535  return condBranchOp.emitOpError("unable to find label for successor block");
536  }
537  os << emitter.getOrCreateName(trueSuccessor) << ";\n";
538  os.unindent() << "} else {\n";
539  os.indent();
540  // If condition is false.
541  for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
542  falseSuccessor.getArguments())) {
543  Value &operand = std::get<0>(pair);
544  BlockArgument &argument = std::get<1>(pair);
545  os << emitter.getOrCreateName(argument) << " = "
546  << emitter.getOrCreateName(operand) << ";\n";
547  }
548 
549  os << "goto ";
550  if (!(emitter.hasBlockLabel(falseSuccessor))) {
551  return condBranchOp.emitOpError()
552  << "unable to find label for successor block";
553  }
554  os << emitter.getOrCreateName(falseSuccessor) << ";\n";
555  os.unindent() << "}";
556  return success();
557 }
558 
559 static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
560  StringRef callee) {
561  if (failed(emitter.emitAssignPrefix(*callOp)))
562  return failure();
563 
564  raw_ostream &os = emitter.ostream();
565  os << callee << "(";
566  if (failed(emitter.emitOperands(*callOp)))
567  return failure();
568  os << ")";
569  return success();
570 }
571 
572 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
573  Operation *operation = callOp.getOperation();
574  StringRef callee = callOp.getCallee();
575 
576  return printCallOperation(emitter, operation, callee);
577 }
578 
579 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
580  Operation *operation = callOp.getOperation();
581  StringRef callee = callOp.getCallee();
582 
583  return printCallOperation(emitter, operation, callee);
584 }
585 
586 static LogicalResult printOperation(CppEmitter &emitter,
587  emitc::CallOpaqueOp callOpaqueOp) {
588  raw_ostream &os = emitter.ostream();
589  Operation &op = *callOpaqueOp.getOperation();
590 
591  if (failed(emitter.emitAssignPrefix(op)))
592  return failure();
593  os << callOpaqueOp.getCallee();
594 
595  auto emitArgs = [&](Attribute attr) -> LogicalResult {
596  if (auto t = dyn_cast<IntegerAttr>(attr)) {
597  // Index attributes are treated specially as operand index.
598  if (t.getType().isIndex()) {
599  int64_t idx = t.getInt();
600  Value operand = op.getOperand(idx);
601  auto literalDef =
602  dyn_cast_if_present<LiteralOp>(operand.getDefiningOp());
603  if (!literalDef && !emitter.hasValueInScope(operand))
604  return op.emitOpError("operand ")
605  << idx << "'s value not defined in scope";
606  os << emitter.getOrCreateName(operand);
607  return success();
608  }
609  }
610  if (failed(emitter.emitAttribute(op.getLoc(), attr)))
611  return failure();
612 
613  return success();
614  };
615 
616  if (callOpaqueOp.getTemplateArgs()) {
617  os << "<";
618  if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
619  emitArgs)))
620  return failure();
621  os << ">";
622  }
623 
624  os << "(";
625 
626  LogicalResult emittedArgs =
627  callOpaqueOp.getArgs()
628  ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
629  : emitter.emitOperands(op);
630  if (failed(emittedArgs))
631  return failure();
632  os << ")";
633  return success();
634 }
635 
636 static LogicalResult printOperation(CppEmitter &emitter,
637  emitc::ApplyOp applyOp) {
638  raw_ostream &os = emitter.ostream();
639  Operation &op = *applyOp.getOperation();
640 
641  if (failed(emitter.emitAssignPrefix(op)))
642  return failure();
643  os << applyOp.getApplicableOperator();
644  os << emitter.getOrCreateName(applyOp.getOperand());
645 
646  return success();
647 }
648 
649 static LogicalResult printOperation(CppEmitter &emitter,
650  emitc::BitwiseAndOp bitwiseAndOp) {
651  Operation *operation = bitwiseAndOp.getOperation();
652  return printBinaryOperation(emitter, operation, "&");
653 }
654 
655 static LogicalResult
656 printOperation(CppEmitter &emitter,
657  emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) {
658  Operation *operation = bitwiseLeftShiftOp.getOperation();
659  return printBinaryOperation(emitter, operation, "<<");
660 }
661 
662 static LogicalResult printOperation(CppEmitter &emitter,
663  emitc::BitwiseNotOp bitwiseNotOp) {
664  Operation *operation = bitwiseNotOp.getOperation();
665  return printUnaryOperation(emitter, operation, "~");
666 }
667 
668 static LogicalResult printOperation(CppEmitter &emitter,
669  emitc::BitwiseOrOp bitwiseOrOp) {
670  Operation *operation = bitwiseOrOp.getOperation();
671  return printBinaryOperation(emitter, operation, "|");
672 }
673 
674 static LogicalResult
675 printOperation(CppEmitter &emitter,
676  emitc::BitwiseRightShiftOp bitwiseRightShiftOp) {
677  Operation *operation = bitwiseRightShiftOp.getOperation();
678  return printBinaryOperation(emitter, operation, ">>");
679 }
680 
681 static LogicalResult printOperation(CppEmitter &emitter,
682  emitc::BitwiseXorOp bitwiseXorOp) {
683  Operation *operation = bitwiseXorOp.getOperation();
684  return printBinaryOperation(emitter, operation, "^");
685 }
686 
687 static LogicalResult printOperation(CppEmitter &emitter,
688  emitc::UnaryPlusOp unaryPlusOp) {
689  Operation *operation = unaryPlusOp.getOperation();
690  return printUnaryOperation(emitter, operation, "+");
691 }
692 
693 static LogicalResult printOperation(CppEmitter &emitter,
694  emitc::UnaryMinusOp unaryMinusOp) {
695  Operation *operation = unaryMinusOp.getOperation();
696  return printUnaryOperation(emitter, operation, "-");
697 }
698 
699 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
700  raw_ostream &os = emitter.ostream();
701  Operation &op = *castOp.getOperation();
702 
703  if (failed(emitter.emitAssignPrefix(op)))
704  return failure();
705  os << "(";
706  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
707  return failure();
708  os << ") ";
709  return emitter.emitOperand(castOp.getOperand());
710 }
711 
712 static LogicalResult printOperation(CppEmitter &emitter,
713  emitc::ExpressionOp expressionOp) {
714  if (shouldBeInlined(expressionOp))
715  return success();
716 
717  Operation &op = *expressionOp.getOperation();
718 
719  if (failed(emitter.emitAssignPrefix(op)))
720  return failure();
721 
722  return emitter.emitExpression(expressionOp);
723 }
724 
725 static LogicalResult printOperation(CppEmitter &emitter,
726  emitc::IncludeOp includeOp) {
727  raw_ostream &os = emitter.ostream();
728 
729  os << "#include ";
730  if (includeOp.getIsStandardInclude())
731  os << "<" << includeOp.getInclude() << ">";
732  else
733  os << "\"" << includeOp.getInclude() << "\"";
734 
735  return success();
736 }
737 
738 static LogicalResult printOperation(CppEmitter &emitter,
739  emitc::LogicalAndOp logicalAndOp) {
740  Operation *operation = logicalAndOp.getOperation();
741  return printBinaryOperation(emitter, operation, "&&");
742 }
743 
744 static LogicalResult printOperation(CppEmitter &emitter,
745  emitc::LogicalNotOp logicalNotOp) {
746  Operation *operation = logicalNotOp.getOperation();
747  return printUnaryOperation(emitter, operation, "!");
748 }
749 
750 static LogicalResult printOperation(CppEmitter &emitter,
751  emitc::LogicalOrOp logicalOrOp) {
752  Operation *operation = logicalOrOp.getOperation();
753  return printBinaryOperation(emitter, operation, "||");
754 }
755 
756 static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
757 
758  raw_indented_ostream &os = emitter.ostream();
759 
760  // Utility function to determine whether a value is an expression that will be
761  // inlined, and as such should be wrapped in parentheses in order to guarantee
762  // its precedence and associativity.
763  auto requiresParentheses = [&](Value value) {
764  auto expressionOp =
765  dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
766  if (!expressionOp)
767  return false;
768  return shouldBeInlined(expressionOp);
769  };
770 
771  os << "for (";
772  if (failed(
773  emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
774  return failure();
775  os << " ";
776  os << emitter.getOrCreateName(forOp.getInductionVar());
777  os << " = ";
778  if (failed(emitter.emitOperand(forOp.getLowerBound())))
779  return failure();
780  os << "; ";
781  os << emitter.getOrCreateName(forOp.getInductionVar());
782  os << " < ";
783  Value upperBound = forOp.getUpperBound();
784  bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
785  if (upperBoundRequiresParentheses)
786  os << "(";
787  if (failed(emitter.emitOperand(upperBound)))
788  return failure();
789  if (upperBoundRequiresParentheses)
790  os << ")";
791  os << "; ";
792  os << emitter.getOrCreateName(forOp.getInductionVar());
793  os << " += ";
794  if (failed(emitter.emitOperand(forOp.getStep())))
795  return failure();
796  os << ") {\n";
797  os.indent();
798 
799  Region &forRegion = forOp.getRegion();
800  auto regionOps = forRegion.getOps();
801 
802  // We skip the trailing yield op.
803  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
804  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
805  return failure();
806  }
807 
808  os.unindent() << "}";
809 
810  return success();
811 }
812 
813 static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
814  raw_indented_ostream &os = emitter.ostream();
815 
816  // Helper function to emit all ops except the last one, expected to be
817  // emitc::yield.
818  auto emitAllExceptLast = [&emitter](Region &region) {
819  Region::OpIterator it = region.op_begin(), end = region.op_end();
820  for (; std::next(it) != end; ++it) {
821  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
822  return failure();
823  }
824  assert(isa<emitc::YieldOp>(*it) &&
825  "Expected last operation in the region to be emitc::yield");
826  return success();
827  };
828 
829  os << "if (";
830  if (failed(emitter.emitOperand(ifOp.getCondition())))
831  return failure();
832  os << ") {\n";
833  os.indent();
834  if (failed(emitAllExceptLast(ifOp.getThenRegion())))
835  return failure();
836  os.unindent() << "}";
837 
838  Region &elseRegion = ifOp.getElseRegion();
839  if (!elseRegion.empty()) {
840  os << " else {\n";
841  os.indent();
842  if (failed(emitAllExceptLast(elseRegion)))
843  return failure();
844  os.unindent() << "}";
845  }
846 
847  return success();
848 }
849 
850 static LogicalResult printOperation(CppEmitter &emitter,
851  func::ReturnOp returnOp) {
852  raw_ostream &os = emitter.ostream();
853  os << "return";
854  switch (returnOp.getNumOperands()) {
855  case 0:
856  return success();
857  case 1:
858  os << " ";
859  if (failed(emitter.emitOperand(returnOp.getOperand(0))))
860  return failure();
861  return success();
862  default:
863  os << " std::make_tuple(";
864  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
865  return failure();
866  os << ")";
867  return success();
868  }
869 }
870 
871 static LogicalResult printOperation(CppEmitter &emitter,
872  emitc::ReturnOp returnOp) {
873  raw_ostream &os = emitter.ostream();
874  os << "return";
875  if (returnOp.getNumOperands() == 0)
876  return success();
877 
878  os << " ";
879  if (failed(emitter.emitOperand(returnOp.getOperand())))
880  return failure();
881  return success();
882 }
883 
884 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
885  CppEmitter::Scope scope(emitter);
886 
887  for (Operation &op : moduleOp) {
888  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
889  return failure();
890  }
891  return success();
892 }
893 
894 static LogicalResult printFunctionArgs(CppEmitter &emitter,
895  Operation *functionOp,
896  ArrayRef<Type> arguments) {
897  raw_indented_ostream &os = emitter.ostream();
898 
899  return (
900  interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult {
901  return emitter.emitType(functionOp->getLoc(), arg);
902  }));
903 }
904 
905 static LogicalResult printFunctionArgs(CppEmitter &emitter,
906  Operation *functionOp,
907  Region::BlockArgListType arguments) {
908  raw_indented_ostream &os = emitter.ostream();
909 
910  return (interleaveCommaWithError(
911  arguments, os, [&](BlockArgument arg) -> LogicalResult {
912  return emitter.emitVariableDeclaration(
913  functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
914  }));
915 }
916 
917 static LogicalResult printFunctionBody(CppEmitter &emitter,
918  Operation *functionOp,
919  Region::BlockListType &blocks) {
920  raw_indented_ostream &os = emitter.ostream();
921  os.indent();
922 
923  if (emitter.shouldDeclareVariablesAtTop()) {
924  // Declare all variables that hold op results including those from nested
925  // regions.
926  WalkResult result =
927  functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
928  if (isa<emitc::LiteralOp>(op) ||
929  isa<emitc::ExpressionOp>(op->getParentOp()) ||
930  (isa<emitc::ExpressionOp>(op) &&
931  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
932  return WalkResult::skip();
933  for (OpResult result : op->getResults()) {
934  if (failed(emitter.emitVariableDeclaration(
935  result, /*trailingSemicolon=*/true))) {
936  return WalkResult(
937  op->emitError("unable to declare result variable for op"));
938  }
939  }
940  return WalkResult::advance();
941  });
942  if (result.wasInterrupted())
943  return failure();
944  }
945 
946  // Create label names for basic blocks.
947  for (Block &block : blocks) {
948  emitter.getOrCreateName(block);
949  }
950 
951  // Declare variables for basic block arguments.
952  for (Block &block : llvm::drop_begin(blocks)) {
953  for (BlockArgument &arg : block.getArguments()) {
954  if (emitter.hasValueInScope(arg))
955  return functionOp->emitOpError(" block argument #")
956  << arg.getArgNumber() << " is out of scope";
957  if (isa<ArrayType>(arg.getType()))
958  return functionOp->emitOpError("cannot emit block argument #")
959  << arg.getArgNumber() << " with array type";
960  if (failed(
961  emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
962  return failure();
963  }
964  os << " " << emitter.getOrCreateName(arg) << ";\n";
965  }
966  }
967 
968  for (Block &block : blocks) {
969  // Only print a label if the block has predecessors.
970  if (!block.hasNoPredecessors()) {
971  if (failed(emitter.emitLabel(block)))
972  return failure();
973  }
974  for (Operation &op : block.getOperations()) {
975  // When generating code for an emitc.if or cf.cond_br op no semicolon
976  // needs to be printed after the closing brace.
977  // When generating code for an emitc.for and emitc.verbatim op, printing a
978  // trailing semicolon is handled within the printOperation function.
979  bool trailingSemicolon =
980  !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
981  emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
982 
983  if (failed(emitter.emitOperation(
984  op, /*trailingSemicolon=*/trailingSemicolon)))
985  return failure();
986  }
987  }
988 
989  os.unindent();
990 
991  return success();
992 }
993 
994 static LogicalResult printOperation(CppEmitter &emitter,
995  func::FuncOp functionOp) {
996  // We need to declare variables at top if the function has multiple blocks.
997  if (!emitter.shouldDeclareVariablesAtTop() &&
998  functionOp.getBlocks().size() > 1) {
999  return functionOp.emitOpError(
1000  "with multiple blocks needs variables declared at top");
1001  }
1002 
1003  if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
1004  return functionOp.emitOpError() << "cannot emit array type as result type";
1005  }
1006 
1007  CppEmitter::Scope scope(emitter);
1008  raw_indented_ostream &os = emitter.ostream();
1009  if (failed(emitter.emitTypes(functionOp.getLoc(),
1010  functionOp.getFunctionType().getResults())))
1011  return failure();
1012  os << " " << functionOp.getName();
1013 
1014  os << "(";
1015  Operation *operation = functionOp.getOperation();
1016  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1017  return failure();
1018  os << ") {\n";
1019  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1020  return failure();
1021  os << "}\n";
1022 
1023  return success();
1024 }
1025 
1026 static LogicalResult printOperation(CppEmitter &emitter,
1027  emitc::FuncOp functionOp) {
1028  // We need to declare variables at top if the function has multiple blocks.
1029  if (!emitter.shouldDeclareVariablesAtTop() &&
1030  functionOp.getBlocks().size() > 1) {
1031  return functionOp.emitOpError(
1032  "with multiple blocks needs variables declared at top");
1033  }
1034 
1035  CppEmitter::Scope scope(emitter);
1036  raw_indented_ostream &os = emitter.ostream();
1037  if (functionOp.getSpecifiers()) {
1038  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1039  os << cast<StringAttr>(specifier).str() << " ";
1040  }
1041  }
1042 
1043  if (failed(emitter.emitTypes(functionOp.getLoc(),
1044  functionOp.getFunctionType().getResults())))
1045  return failure();
1046  os << " " << functionOp.getName();
1047 
1048  os << "(";
1049  Operation *operation = functionOp.getOperation();
1050  if (functionOp.isExternal()) {
1051  if (failed(printFunctionArgs(emitter, operation,
1052  functionOp.getArgumentTypes())))
1053  return failure();
1054  os << ");";
1055  return success();
1056  }
1057  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1058  return failure();
1059  os << ") {\n";
1060  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1061  return failure();
1062  os << "}\n";
1063 
1064  return success();
1065 }
1066 
1067 static LogicalResult printOperation(CppEmitter &emitter,
1068  DeclareFuncOp declareFuncOp) {
1069  CppEmitter::Scope scope(emitter);
1070  raw_indented_ostream &os = emitter.ostream();
1071 
1072  auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
1073  declareFuncOp, declareFuncOp.getSymNameAttr());
1074 
1075  if (!functionOp)
1076  return failure();
1077 
1078  if (functionOp.getSpecifiers()) {
1079  for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1080  os << cast<StringAttr>(specifier).str() << " ";
1081  }
1082  }
1083 
1084  if (failed(emitter.emitTypes(functionOp.getLoc(),
1085  functionOp.getFunctionType().getResults())))
1086  return failure();
1087  os << " " << functionOp.getName();
1088 
1089  os << "(";
1090  Operation *operation = functionOp.getOperation();
1091  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1092  return failure();
1093  os << ");";
1094 
1095  return success();
1096 }
1097 
1098 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
1099  : os(os), declareVariablesAtTop(declareVariablesAtTop) {
1100  valueInScopeCount.push(0);
1101  labelInScopeCount.push(0);
1102 }
1103 
1104 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1105  std::string out;
1106  llvm::raw_string_ostream ss(out);
1107  ss << getOrCreateName(op.getValue());
1108  for (auto index : op.getIndices()) {
1109  ss << "[" << getOrCreateName(index) << "]";
1110  }
1111  return out;
1112 }
1113 
1114 /// Return the existing or a new name for a Value.
1115 StringRef CppEmitter::getOrCreateName(Value val) {
1116  if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
1117  return literal.getValue();
1118  if (!valueMapper.count(val)) {
1119  if (auto subscript =
1120  dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
1121  valueMapper.insert(val, getSubscriptName(subscript));
1122  } else {
1123  valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1124  }
1125  }
1126  return *valueMapper.begin(val);
1127 }
1128 
1129 /// Return the existing or a new label for a Block.
1130 StringRef CppEmitter::getOrCreateName(Block &block) {
1131  if (!blockMapper.count(&block))
1132  blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
1133  return *blockMapper.begin(&block);
1134 }
1135 
1136 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
1137  switch (val) {
1138  case IntegerType::Signless:
1139  return false;
1140  case IntegerType::Signed:
1141  return false;
1142  case IntegerType::Unsigned:
1143  return true;
1144  }
1145  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
1146 }
1147 
1148 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
1149 
1150 bool CppEmitter::hasBlockLabel(Block &block) {
1151  return blockMapper.count(&block);
1152 }
1153 
1154 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
1155  auto printInt = [&](const APInt &val, bool isUnsigned) {
1156  if (val.getBitWidth() == 1) {
1157  if (val.getBoolValue())
1158  os << "true";
1159  else
1160  os << "false";
1161  } else {
1162  SmallString<128> strValue;
1163  val.toString(strValue, 10, !isUnsigned, false);
1164  os << strValue;
1165  }
1166  };
1167 
1168  auto printFloat = [&](const APFloat &val) {
1169  if (val.isFinite()) {
1170  SmallString<128> strValue;
1171  // Use default values of toString except don't truncate zeros.
1172  val.toString(strValue, 0, 0, false);
1173  os << strValue;
1174  switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
1175  case llvm::APFloatBase::S_IEEEsingle:
1176  os << "f";
1177  break;
1178  case llvm::APFloatBase::S_IEEEdouble:
1179  break;
1180  default:
1181  llvm_unreachable("unsupported floating point type");
1182  };
1183  } else if (val.isNaN()) {
1184  os << "NAN";
1185  } else if (val.isInfinity()) {
1186  if (val.isNegative())
1187  os << "-";
1188  os << "INFINITY";
1189  }
1190  };
1191 
1192  // Print floating point attributes.
1193  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1194  if (!isa<Float32Type, Float64Type>(fAttr.getType())) {
1195  return emitError(loc,
1196  "expected floating point attribute to be f32 or f64");
1197  }
1198  printFloat(fAttr.getValue());
1199  return success();
1200  }
1201  if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1202  if (!isa<Float32Type, Float64Type>(dense.getElementType())) {
1203  return emitError(loc,
1204  "expected floating point attribute to be f32 or f64");
1205  }
1206  os << '{';
1207  interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
1208  os << '}';
1209  return success();
1210  }
1211 
1212  // Print integer attributes.
1213  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
1214  if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
1215  printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
1216  return success();
1217  }
1218  if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
1219  printInt(iAttr.getValue(), false);
1220  return success();
1221  }
1222  }
1223  if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1224  if (auto iType = dyn_cast<IntegerType>(
1225  cast<TensorType>(dense.getType()).getElementType())) {
1226  os << '{';
1227  interleaveComma(dense, os, [&](const APInt &val) {
1228  printInt(val, shouldMapToUnsigned(iType.getSignedness()));
1229  });
1230  os << '}';
1231  return success();
1232  }
1233  if (auto iType = dyn_cast<IndexType>(
1234  cast<TensorType>(dense.getType()).getElementType())) {
1235  os << '{';
1236  interleaveComma(dense, os,
1237  [&](const APInt &val) { printInt(val, false); });
1238  os << '}';
1239  return success();
1240  }
1241  }
1242 
1243  // Print opaque attributes.
1244  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
1245  os << oAttr.getValue();
1246  return success();
1247  }
1248 
1249  // Print symbolic reference attributes.
1250  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
1251  if (sAttr.getNestedReferences().size() > 1)
1252  return emitError(loc, "attribute has more than 1 nested reference");
1253  os << sAttr.getRootReference().getValue();
1254  return success();
1255  }
1256 
1257  // Print type attributes.
1258  if (auto type = dyn_cast<TypeAttr>(attr))
1259  return emitType(loc, type.getValue());
1260 
1261  return emitError(loc, "cannot emit attribute: ") << attr;
1262 }
1263 
1264 LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
1265  assert(emittedExpressionPrecedence.empty() &&
1266  "Expected precedence stack to be empty");
1267  Operation *rootOp = expressionOp.getRootOp();
1268 
1269  emittedExpression = expressionOp;
1270  FailureOr<int> precedence = getOperatorPrecedence(rootOp);
1271  if (failed(precedence))
1272  return failure();
1273  pushExpressionPrecedence(precedence.value());
1274 
1275  if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false)))
1276  return failure();
1277 
1278  popExpressionPrecedence();
1279  assert(emittedExpressionPrecedence.empty() &&
1280  "Expected precedence stack to be empty");
1281  emittedExpression = nullptr;
1282 
1283  return success();
1284 }
1285 
1286 LogicalResult CppEmitter::emitOperand(Value value) {
1287  if (isPartOfCurrentExpression(value)) {
1288  Operation *def = value.getDefiningOp();
1289  assert(def && "Expected operand to be defined by an operation");
1290  FailureOr<int> precedence = getOperatorPrecedence(def);
1291  if (failed(precedence))
1292  return failure();
1293  bool encloseInParenthesis = precedence.value() < getExpressionPrecedence();
1294  if (encloseInParenthesis) {
1295  os << "(";
1296  pushExpressionPrecedence(lowestPrecedence());
1297  } else
1298  pushExpressionPrecedence(precedence.value());
1299 
1300  if (failed(emitOperation(*def, /*trailingSemicolon=*/false)))
1301  return failure();
1302 
1303  if (encloseInParenthesis)
1304  os << ")";
1305 
1306  popExpressionPrecedence();
1307  return success();
1308  }
1309 
1310  auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
1311  if (expressionOp && shouldBeInlined(expressionOp))
1312  return emitExpression(expressionOp);
1313 
1314  auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
1315  if (!literalOp && !hasValueInScope(value))
1316  return failure();
1317  os << getOrCreateName(value);
1318  return success();
1319 }
1320 
1321 LogicalResult CppEmitter::emitOperands(Operation &op) {
1322  return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) {
1323  // If an expression is being emitted, push lowest precedence as these
1324  // operands are either wrapped by parenthesis.
1325  if (getEmittedExpression())
1326  pushExpressionPrecedence(lowestPrecedence());
1327  if (failed(emitOperand(operand)))
1328  return failure();
1329  if (getEmittedExpression())
1330  popExpressionPrecedence();
1331  return success();
1332  });
1333 }
1334 
1336 CppEmitter::emitOperandsAndAttributes(Operation &op,
1337  ArrayRef<StringRef> exclude) {
1338  if (failed(emitOperands(op)))
1339  return failure();
1340  // Insert comma in between operands and non-filtered attributes if needed.
1341  if (op.getNumOperands() > 0) {
1342  for (NamedAttribute attr : op.getAttrs()) {
1343  if (!llvm::is_contained(exclude, attr.getName().strref())) {
1344  os << ", ";
1345  break;
1346  }
1347  }
1348  }
1349  // Emit attributes.
1350  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
1351  if (llvm::is_contained(exclude, attr.getName().strref()))
1352  return success();
1353  os << "/* " << attr.getName().getValue() << " */";
1354  if (failed(emitAttribute(op.getLoc(), attr.getValue())))
1355  return failure();
1356  return success();
1357  };
1358  return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
1359 }
1360 
1361 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
1362  if (!hasValueInScope(result)) {
1363  return result.getDefiningOp()->emitOpError(
1364  "result variable for the operation has not been declared");
1365  }
1366  os << getOrCreateName(result) << " = ";
1367  return success();
1368 }
1369 
1370 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
1371  bool trailingSemicolon) {
1372  if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
1373  return success();
1374  if (hasValueInScope(result)) {
1375  return result.getDefiningOp()->emitError(
1376  "result variable for the operation already declared");
1377  }
1378  if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
1379  result.getType(),
1380  getOrCreateName(result))))
1381  return failure();
1382  if (trailingSemicolon)
1383  os << ";\n";
1384  return success();
1385 }
1386 
1387 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
1388  // If op is being emitted as part of an expression, bail out.
1389  if (getEmittedExpression())
1390  return success();
1391 
1392  switch (op.getNumResults()) {
1393  case 0:
1394  break;
1395  case 1: {
1396  OpResult result = op.getResult(0);
1397  if (shouldDeclareVariablesAtTop()) {
1398  if (failed(emitVariableAssignment(result)))
1399  return failure();
1400  } else {
1401  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
1402  return failure();
1403  os << " = ";
1404  }
1405  break;
1406  }
1407  default:
1408  if (!shouldDeclareVariablesAtTop()) {
1409  for (OpResult result : op.getResults()) {
1410  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
1411  return failure();
1412  }
1413  }
1414  os << "std::tie(";
1415  interleaveComma(op.getResults(), os,
1416  [&](Value result) { os << getOrCreateName(result); });
1417  os << ") = ";
1418  }
1419  return success();
1420 }
1421 
1422 LogicalResult CppEmitter::emitLabel(Block &block) {
1423  if (!hasBlockLabel(block))
1424  return block.getParentOp()->emitError("label for block not found");
1425  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
1426  // label instead of using `getOStream`.
1427  os.getOStream() << getOrCreateName(block) << ":\n";
1428  return success();
1429 }
1430 
1431 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1432  LogicalResult status =
1434  // Builtin ops.
1435  .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
1436  // CF ops.
1437  .Case<cf::BranchOp, cf::CondBranchOp>(
1438  [&](auto op) { return printOperation(*this, op); })
1439  // EmitC ops.
1440  .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
1441  emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
1442  emitc::BitwiseNotOp, emitc::BitwiseOrOp,
1443  emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
1444  emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1445  emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1446  emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1447  emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
1448  emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
1449  emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
1450  emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1451  emitc::VerbatimOp>(
1452  [&](auto op) { return printOperation(*this, op); })
1453  // Func ops.
1454  .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
1455  [&](auto op) { return printOperation(*this, op); })
1456  .Case<emitc::LiteralOp>([&](auto op) { return success(); })
1457  .Default([&](Operation *) {
1458  return op.emitOpError("unable to find printer for op");
1459  });
1460 
1461  if (failed(status))
1462  return failure();
1463 
1464  if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
1465  return success();
1466 
1467  if (getEmittedExpression() ||
1468  (isa<emitc::ExpressionOp>(op) &&
1469  shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1470  return success();
1471 
1472  os << (trailingSemicolon ? ";\n" : "\n");
1473 
1474  return success();
1475 }
1476 
1477 LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1478  StringRef name) {
1479  if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1480  if (failed(emitType(loc, arrType.getElementType())))
1481  return failure();
1482  os << " " << name;
1483  for (auto dim : arrType.getShape()) {
1484  os << "[" << dim << "]";
1485  }
1486  return success();
1487  }
1488  if (failed(emitType(loc, type)))
1489  return failure();
1490  os << " " << name;
1491  return success();
1492 }
1493 
1494 LogicalResult CppEmitter::emitType(Location loc, Type type) {
1495  if (auto iType = dyn_cast<IntegerType>(type)) {
1496  switch (iType.getWidth()) {
1497  case 1:
1498  return (os << "bool"), success();
1499  case 8:
1500  case 16:
1501  case 32:
1502  case 64:
1503  if (shouldMapToUnsigned(iType.getSignedness()))
1504  return (os << "uint" << iType.getWidth() << "_t"), success();
1505  else
1506  return (os << "int" << iType.getWidth() << "_t"), success();
1507  default:
1508  return emitError(loc, "cannot emit integer type ") << type;
1509  }
1510  }
1511  if (auto fType = dyn_cast<FloatType>(type)) {
1512  switch (fType.getWidth()) {
1513  case 32:
1514  return (os << "float"), success();
1515  case 64:
1516  return (os << "double"), success();
1517  default:
1518  return emitError(loc, "cannot emit float type ") << type;
1519  }
1520  }
1521  if (auto iType = dyn_cast<IndexType>(type))
1522  return (os << "size_t"), success();
1523  if (auto tType = dyn_cast<TensorType>(type)) {
1524  if (!tType.hasRank())
1525  return emitError(loc, "cannot emit unranked tensor type");
1526  if (!tType.hasStaticShape())
1527  return emitError(loc, "cannot emit tensor type with non static shape");
1528  os << "Tensor<";
1529  if (isa<ArrayType>(tType.getElementType()))
1530  return emitError(loc, "cannot emit tensor of array type ") << type;
1531  if (failed(emitType(loc, tType.getElementType())))
1532  return failure();
1533  auto shape = tType.getShape();
1534  for (auto dimSize : shape) {
1535  os << ", ";
1536  os << dimSize;
1537  }
1538  os << ">";
1539  return success();
1540  }
1541  if (auto tType = dyn_cast<TupleType>(type))
1542  return emitTupleType(loc, tType.getTypes());
1543  if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1544  os << oType.getValue();
1545  return success();
1546  }
1547  if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
1548  if (failed(emitType(loc, aType.getElementType())))
1549  return failure();
1550  for (auto dim : aType.getShape())
1551  os << "[" << dim << "]";
1552  return success();
1553  }
1554  if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1555  if (isa<ArrayType>(pType.getPointee()))
1556  return emitError(loc, "cannot emit pointer to array type ") << type;
1557  if (failed(emitType(loc, pType.getPointee())))
1558  return failure();
1559  os << "*";
1560  return success();
1561  }
1562  return emitError(loc, "cannot emit type ") << type;
1563 }
1564 
1565 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1566  switch (types.size()) {
1567  case 0:
1568  os << "void";
1569  return success();
1570  case 1:
1571  return emitType(loc, types.front());
1572  default:
1573  return emitTupleType(loc, types);
1574  }
1575 }
1576 
1577 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1578  if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
1579  return emitError(loc, "cannot emit tuple of array type");
1580  }
1581  os << "std::tuple<";
1583  types, os, [&](Type type) { return emitType(loc, type); })))
1584  return failure();
1585  os << ">";
1586  return success();
1587 }
1588 
1590  bool declareVariablesAtTop) {
1591  CppEmitter emitter(os, declareVariablesAtTop);
1592  return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1593 }
static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, StringRef callee)
static bool shouldBeInlined(ExpressionOp expressionOp)
Determine whether expression expressionOp should be emitted inline, i.e.
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
static FailureOr< int > getOperatorPrecedence(Operation *operation)
Return the precedence of an operator as an integer; higher values imply higher precedence.
static LogicalResult printFunctionArgs(CppEmitter &emitter, Operation *functionOp, ArrayRef< Type > arguments)
static LogicalResult printFunctionBody(CppEmitter &emitter, Operation *functionOp, Region::BlockListType &blocks)
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator)
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
static LogicalResult printUnaryOperation(CppEmitter &emitter, Operation *operation, StringRef unaryOperator)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:84
Block * getSuccessor(unsigned i)
Definition: Block.cpp:258
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class provides iteration over the held operations of blocks directly within a region.
Definition: Region.h:134
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
llvm::iplist< Block > BlockListType
Definition: Region.h:44
bool empty()
Definition: Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & unindent()
Decreases the indent and returning this raw_indented_ostream.
raw_indented_ostream & indent()
Increases the indent and returning this raw_indented_ostream.
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false)
Translates the given operation to C++ code.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This iterator enumerates the elements in "forward" order.
Definition: Visitors.h:66
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26