MLIR  18.0.0git
TranslateToCpp.cpp
Go to the documentation of this file.
1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/Operation.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <utility>
26 
27 #define DEBUG_TYPE "translate-to-cpp"
28 
29 using namespace mlir;
30 using namespace mlir::emitc;
31 using llvm::formatv;
32 
33 /// Convenience functions to produce interleaved output with functions returning
34 /// a LogicalResult. This is different than those in STLExtras as functions used
35 /// on each element doesn't return a string.
36 template <typename ForwardIterator, typename UnaryFunctor,
37  typename NullaryFunctor>
38 inline LogicalResult
40  UnaryFunctor eachFn, NullaryFunctor betweenFn) {
41  if (begin == end)
42  return success();
43  if (failed(eachFn(*begin)))
44  return failure();
45  ++begin;
46  for (; begin != end; ++begin) {
47  betweenFn();
48  if (failed(eachFn(*begin)))
49  return failure();
50  }
51  return success();
52 }
53 
54 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
55 inline LogicalResult interleaveWithError(const Container &c,
56  UnaryFunctor eachFn,
57  NullaryFunctor betweenFn) {
58  return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
59 }
60 
61 template <typename Container, typename UnaryFunctor>
62 inline LogicalResult interleaveCommaWithError(const Container &c,
63  raw_ostream &os,
64  UnaryFunctor eachFn) {
65  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
66 }
67 
68 namespace {
69 /// Emitter that uses dialect specific emitters to emit C++ code.
70 struct CppEmitter {
71  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
72 
73  /// Emits attribute or returns failure.
74  LogicalResult emitAttribute(Location loc, Attribute attr);
75 
76  /// Emits operation 'op' with/without training semicolon or returns failure.
77  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
78 
79  /// Emits type 'type' or returns failure.
80  LogicalResult emitType(Location loc, Type type);
81 
82  /// Emits array of types as a std::tuple of the emitted types.
83  /// - emits void for an empty array;
84  /// - emits the type of the only element for arrays of size one;
85  /// - emits a std::tuple otherwise;
86  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
87 
88  /// Emits array of types as a std::tuple of the emitted types independently of
89  /// the array size.
90  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
91 
92  /// Emits an assignment for a variable which has been declared previously.
93  LogicalResult emitVariableAssignment(OpResult result);
94 
95  /// Emits a variable declaration for a result of an operation.
96  LogicalResult emitVariableDeclaration(OpResult result,
97  bool trailingSemicolon);
98 
99  /// Emits the variable declaration and assignment prefix for 'op'.
100  /// - emits separate variable followed by std::tie for multi-valued operation;
101  /// - emits single type followed by variable for single result;
102  /// - emits nothing if no value produced by op;
103  /// Emits final '=' operator where a type is produced. Returns failure if
104  /// any result type could not be converted.
105  LogicalResult emitAssignPrefix(Operation &op);
106 
107  /// Emits a label for the block.
108  LogicalResult emitLabel(Block &block);
109 
110  /// Emits the operands and atttributes of the operation. All operands are
111  /// emitted first and then all attributes in alphabetical order.
112  LogicalResult emitOperandsAndAttributes(Operation &op,
113  ArrayRef<StringRef> exclude = {});
114 
115  /// Emits the operands of the operation. All operands are emitted in order.
116  LogicalResult emitOperands(Operation &op);
117 
118  /// Return the existing or a new name for a Value.
119  StringRef getOrCreateName(Value val);
120 
121  /// Return the existing or a new label of a Block.
122  StringRef getOrCreateName(Block &block);
123 
124  /// Whether to map an mlir integer to a unsigned integer in C++.
125  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
126 
127  /// RAII helper function to manage entering/exiting C++ scopes.
128  struct Scope {
129  Scope(CppEmitter &emitter)
130  : valueMapperScope(emitter.valueMapper),
131  blockMapperScope(emitter.blockMapper), emitter(emitter) {
132  emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
133  emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
134  }
135  ~Scope() {
136  emitter.valueInScopeCount.pop();
137  emitter.labelInScopeCount.pop();
138  }
139 
140  private:
141  llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
142  llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
143  CppEmitter &emitter;
144  };
145 
146  /// Returns wether the Value is assigned to a C++ variable in the scope.
147  bool hasValueInScope(Value val);
148 
149  // Returns whether a label is assigned to the block.
150  bool hasBlockLabel(Block &block);
151 
152  /// Returns the output stream.
153  raw_indented_ostream &ostream() { return os; };
154 
155  /// Returns if all variables for op results and basic block arguments need to
156  /// be declared at the beginning of a function.
157  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
158 
159 private:
160  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
161  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
162 
163  /// Output stream to emit to.
165 
166  /// Boolean to enforce that all variables for op results and block
167  /// arguments are declared at the beginning of the function. This also
168  /// includes results from ops located in nested regions.
169  bool declareVariablesAtTop;
170 
171  /// Map from value to name of C++ variable that contain the name.
172  ValueMapper valueMapper;
173 
174  /// Map from block to name of C++ label.
175  BlockMapper blockMapper;
176 
177  /// The number of values in the current scope. This is used to declare the
178  /// names of values in a scope.
179  std::stack<int64_t> valueInScopeCount;
180  std::stack<int64_t> labelInScopeCount;
181 };
182 } // namespace
183 
184 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
185  Attribute value) {
186  OpResult result = operation->getResult(0);
187 
188  // Only emit an assignment as the variable was already declared when printing
189  // the FuncOp.
190  if (emitter.shouldDeclareVariablesAtTop()) {
191  // Skip the assignment if the emitc.constant has no value.
192  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
193  if (oAttr.getValue().empty())
194  return success();
195  }
196 
197  if (failed(emitter.emitVariableAssignment(result)))
198  return failure();
199  return emitter.emitAttribute(operation->getLoc(), value);
200  }
201 
202  // Emit a variable declaration for an emitc.constant op without value.
203  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
204  if (oAttr.getValue().empty())
205  // The semicolon gets printed by the emitOperation function.
206  return emitter.emitVariableDeclaration(result,
207  /*trailingSemicolon=*/false);
208  }
209 
210  // Emit a variable declaration.
211  if (failed(emitter.emitAssignPrefix(*operation)))
212  return failure();
213  return emitter.emitAttribute(operation->getLoc(), value);
214 }
215 
216 static LogicalResult printOperation(CppEmitter &emitter,
217  emitc::ConstantOp constantOp) {
218  Operation *operation = constantOp.getOperation();
219  Attribute value = constantOp.getValue();
220 
221  return printConstantOp(emitter, operation, value);
222 }
223 
224 static LogicalResult printOperation(CppEmitter &emitter,
225  emitc::VariableOp variableOp) {
226  Operation *operation = variableOp.getOperation();
227  Attribute value = variableOp.getValue();
228 
229  return printConstantOp(emitter, operation, value);
230 }
231 
232 static LogicalResult printOperation(CppEmitter &emitter,
233  arith::ConstantOp constantOp) {
234  Operation *operation = constantOp.getOperation();
235  Attribute value = constantOp.getValue();
236 
237  return printConstantOp(emitter, operation, value);
238 }
239 
240 static LogicalResult printOperation(CppEmitter &emitter,
241  func::ConstantOp constantOp) {
242  Operation *operation = constantOp.getOperation();
243  Attribute value = constantOp.getValueAttr();
244 
245  return printConstantOp(emitter, operation, value);
246 }
247 
248 static LogicalResult printOperation(CppEmitter &emitter,
249  emitc::AssignOp assignOp) {
250  auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
251  OpResult result = variableOp->getResult(0);
252 
253  if (failed(emitter.emitVariableAssignment(result)))
254  return failure();
255 
256  emitter.ostream() << emitter.getOrCreateName(assignOp.getValue());
257 
258  return success();
259 }
260 
261 static LogicalResult printBinaryOperation(CppEmitter &emitter,
262  Operation *operation,
263  StringRef binaryOperator) {
264  raw_ostream &os = emitter.ostream();
265 
266  if (failed(emitter.emitAssignPrefix(*operation)))
267  return failure();
268  os << emitter.getOrCreateName(operation->getOperand(0));
269  os << " " << binaryOperator;
270  os << " " << emitter.getOrCreateName(operation->getOperand(1));
271 
272  return success();
273 }
274 
275 static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
276  Operation *operation = addOp.getOperation();
277 
278  return printBinaryOperation(emitter, operation, "+");
279 }
280 
281 static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
282  Operation *operation = divOp.getOperation();
283 
284  return printBinaryOperation(emitter, operation, "/");
285 }
286 
287 static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
288  Operation *operation = mulOp.getOperation();
289 
290  return printBinaryOperation(emitter, operation, "*");
291 }
292 
293 static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
294  Operation *operation = remOp.getOperation();
295 
296  return printBinaryOperation(emitter, operation, "%");
297 }
298 
299 static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
300  Operation *operation = subOp.getOperation();
301 
302  return printBinaryOperation(emitter, operation, "-");
303 }
304 
305 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
306  Operation *operation = cmpOp.getOperation();
307 
308  StringRef binaryOperator;
309 
310  switch (cmpOp.getPredicate()) {
311  case emitc::CmpPredicate::eq:
312  binaryOperator = "==";
313  break;
314  case emitc::CmpPredicate::ne:
315  binaryOperator = "!=";
316  break;
317  case emitc::CmpPredicate::lt:
318  binaryOperator = "<";
319  break;
320  case emitc::CmpPredicate::le:
321  binaryOperator = "<=";
322  break;
323  case emitc::CmpPredicate::gt:
324  binaryOperator = ">";
325  break;
326  case emitc::CmpPredicate::ge:
327  binaryOperator = ">=";
328  break;
329  case emitc::CmpPredicate::three_way:
330  binaryOperator = "<=>";
331  break;
332  }
333 
334  return printBinaryOperation(emitter, operation, binaryOperator);
335 }
336 
337 static LogicalResult printOperation(CppEmitter &emitter,
338  cf::BranchOp branchOp) {
339  raw_ostream &os = emitter.ostream();
340  Block &successor = *branchOp.getSuccessor();
341 
342  for (auto pair :
343  llvm::zip(branchOp.getOperands(), successor.getArguments())) {
344  Value &operand = std::get<0>(pair);
345  BlockArgument &argument = std::get<1>(pair);
346  os << emitter.getOrCreateName(argument) << " = "
347  << emitter.getOrCreateName(operand) << ";\n";
348  }
349 
350  os << "goto ";
351  if (!(emitter.hasBlockLabel(successor)))
352  return branchOp.emitOpError("unable to find label for successor block");
353  os << emitter.getOrCreateName(successor);
354  return success();
355 }
356 
357 static LogicalResult printOperation(CppEmitter &emitter,
358  cf::CondBranchOp condBranchOp) {
359  raw_indented_ostream &os = emitter.ostream();
360  Block &trueSuccessor = *condBranchOp.getTrueDest();
361  Block &falseSuccessor = *condBranchOp.getFalseDest();
362 
363  os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
364  << ") {\n";
365 
366  os.indent();
367 
368  // If condition is true.
369  for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
370  trueSuccessor.getArguments())) {
371  Value &operand = std::get<0>(pair);
372  BlockArgument &argument = std::get<1>(pair);
373  os << emitter.getOrCreateName(argument) << " = "
374  << emitter.getOrCreateName(operand) << ";\n";
375  }
376 
377  os << "goto ";
378  if (!(emitter.hasBlockLabel(trueSuccessor))) {
379  return condBranchOp.emitOpError("unable to find label for successor block");
380  }
381  os << emitter.getOrCreateName(trueSuccessor) << ";\n";
382  os.unindent() << "} else {\n";
383  os.indent();
384  // If condition is false.
385  for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
386  falseSuccessor.getArguments())) {
387  Value &operand = std::get<0>(pair);
388  BlockArgument &argument = std::get<1>(pair);
389  os << emitter.getOrCreateName(argument) << " = "
390  << emitter.getOrCreateName(operand) << ";\n";
391  }
392 
393  os << "goto ";
394  if (!(emitter.hasBlockLabel(falseSuccessor))) {
395  return condBranchOp.emitOpError()
396  << "unable to find label for successor block";
397  }
398  os << emitter.getOrCreateName(falseSuccessor) << ";\n";
399  os.unindent() << "}";
400  return success();
401 }
402 
403 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
404  if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
405  return failure();
406 
407  raw_ostream &os = emitter.ostream();
408  os << callOp.getCallee() << "(";
409  if (failed(emitter.emitOperands(*callOp.getOperation())))
410  return failure();
411  os << ")";
412  return success();
413 }
414 
415 static LogicalResult printOperation(CppEmitter &emitter,
416  emitc::CallOpaqueOp callOpaqueOp) {
417  raw_ostream &os = emitter.ostream();
418  Operation &op = *callOpaqueOp.getOperation();
419 
420  if (failed(emitter.emitAssignPrefix(op)))
421  return failure();
422  os << callOpaqueOp.getCallee();
423 
424  auto emitArgs = [&](Attribute attr) -> LogicalResult {
425  if (auto t = dyn_cast<IntegerAttr>(attr)) {
426  // Index attributes are treated specially as operand index.
427  if (t.getType().isIndex()) {
428  int64_t idx = t.getInt();
429  Value operand = op.getOperand(idx);
430  auto literalDef =
431  dyn_cast_if_present<LiteralOp>(operand.getDefiningOp());
432  if (!literalDef && !emitter.hasValueInScope(operand))
433  return op.emitOpError("operand ")
434  << idx << "'s value not defined in scope";
435  os << emitter.getOrCreateName(operand);
436  return success();
437  }
438  }
439  if (failed(emitter.emitAttribute(op.getLoc(), attr)))
440  return failure();
441 
442  return success();
443  };
444 
445  if (callOpaqueOp.getTemplateArgs()) {
446  os << "<";
447  if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
448  emitArgs)))
449  return failure();
450  os << ">";
451  }
452 
453  os << "(";
454 
455  LogicalResult emittedArgs =
456  callOpaqueOp.getArgs()
457  ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
458  : emitter.emitOperands(op);
459  if (failed(emittedArgs))
460  return failure();
461  os << ")";
462  return success();
463 }
464 
465 static LogicalResult printOperation(CppEmitter &emitter,
466  emitc::ApplyOp applyOp) {
467  raw_ostream &os = emitter.ostream();
468  Operation &op = *applyOp.getOperation();
469 
470  if (failed(emitter.emitAssignPrefix(op)))
471  return failure();
472  os << applyOp.getApplicableOperator();
473  os << emitter.getOrCreateName(applyOp.getOperand());
474 
475  return success();
476 }
477 
478 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
479  raw_ostream &os = emitter.ostream();
480  Operation &op = *castOp.getOperation();
481 
482  if (failed(emitter.emitAssignPrefix(op)))
483  return failure();
484  os << "(";
485  if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
486  return failure();
487  os << ") ";
488  os << emitter.getOrCreateName(castOp.getOperand());
489 
490  return success();
491 }
492 
493 static LogicalResult printOperation(CppEmitter &emitter,
494  emitc::IncludeOp includeOp) {
495  raw_ostream &os = emitter.ostream();
496 
497  os << "#include ";
498  if (includeOp.getIsStandardInclude())
499  os << "<" << includeOp.getInclude() << ">";
500  else
501  os << "\"" << includeOp.getInclude() << "\"";
502 
503  return success();
504 }
505 
506 static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
507 
508  raw_indented_ostream &os = emitter.ostream();
509 
510  os << "for (";
511  if (failed(
512  emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
513  return failure();
514  os << " ";
515  os << emitter.getOrCreateName(forOp.getInductionVar());
516  os << " = ";
517  os << emitter.getOrCreateName(forOp.getLowerBound());
518  os << "; ";
519  os << emitter.getOrCreateName(forOp.getInductionVar());
520  os << " < ";
521  os << emitter.getOrCreateName(forOp.getUpperBound());
522  os << "; ";
523  os << emitter.getOrCreateName(forOp.getInductionVar());
524  os << " += ";
525  os << emitter.getOrCreateName(forOp.getStep());
526  os << ") {\n";
527  os.indent();
528 
529  Region &forRegion = forOp.getRegion();
530  auto regionOps = forRegion.getOps();
531 
532  // We skip the trailing yield op.
533  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
534  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
535  return failure();
536  }
537 
538  os.unindent() << "}";
539 
540  return success();
541 }
542 
543 static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
544  raw_indented_ostream &os = emitter.ostream();
545 
546  // Helper function to emit all ops except the last one, expected to be
547  // emitc::yield.
548  auto emitAllExceptLast = [&emitter](Region &region) {
549  Region::OpIterator it = region.op_begin(), end = region.op_end();
550  for (; std::next(it) != end; ++it) {
551  if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
552  return failure();
553  }
554  assert(isa<emitc::YieldOp>(*it) &&
555  "Expected last operation in the region to be emitc::yield");
556  return success();
557  };
558 
559  os << "if (";
560  if (failed(emitter.emitOperands(*ifOp.getOperation())))
561  return failure();
562  os << ") {\n";
563  os.indent();
564  if (failed(emitAllExceptLast(ifOp.getThenRegion())))
565  return failure();
566  os.unindent() << "}";
567 
568  Region &elseRegion = ifOp.getElseRegion();
569  if (!elseRegion.empty()) {
570  os << " else {\n";
571  os.indent();
572  if (failed(emitAllExceptLast(elseRegion)))
573  return failure();
574  os.unindent() << "}";
575  }
576 
577  return success();
578 }
579 
580 static LogicalResult printOperation(CppEmitter &emitter,
581  func::ReturnOp returnOp) {
582  raw_ostream &os = emitter.ostream();
583  os << "return";
584  switch (returnOp.getNumOperands()) {
585  case 0:
586  return success();
587  case 1:
588  os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
589  return success(emitter.hasValueInScope(returnOp.getOperand(0)));
590  default:
591  os << " std::make_tuple(";
592  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
593  return failure();
594  os << ")";
595  return success();
596  }
597 }
598 
599 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
600  CppEmitter::Scope scope(emitter);
601 
602  for (Operation &op : moduleOp) {
603  if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
604  return failure();
605  }
606  return success();
607 }
608 
609 static LogicalResult printOperation(CppEmitter &emitter,
610  func::FuncOp functionOp) {
611  // We need to declare variables at top if the function has multiple blocks.
612  if (!emitter.shouldDeclareVariablesAtTop() &&
613  functionOp.getBlocks().size() > 1) {
614  return functionOp.emitOpError(
615  "with multiple blocks needs variables declared at top");
616  }
617 
618  CppEmitter::Scope scope(emitter);
619  raw_indented_ostream &os = emitter.ostream();
620  if (failed(emitter.emitTypes(functionOp.getLoc(),
621  functionOp.getFunctionType().getResults())))
622  return failure();
623  os << " " << functionOp.getName();
624 
625  os << "(";
627  functionOp.getArguments(), os,
628  [&](BlockArgument arg) -> LogicalResult {
629  if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
630  return failure();
631  os << " " << emitter.getOrCreateName(arg);
632  return success();
633  })))
634  return failure();
635  os << ") {\n";
636  os.indent();
637  if (emitter.shouldDeclareVariablesAtTop()) {
638  // Declare all variables that hold op results including those from nested
639  // regions.
640  WalkResult result =
641  functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
642  if (isa<emitc::LiteralOp>(op))
643  return WalkResult::skip();
644  for (OpResult result : op->getResults()) {
645  if (failed(emitter.emitVariableDeclaration(
646  result, /*trailingSemicolon=*/true))) {
647  return WalkResult(
648  op->emitError("unable to declare result variable for op"));
649  }
650  }
651  return WalkResult::advance();
652  });
653  if (result.wasInterrupted())
654  return failure();
655  }
656 
657  Region::BlockListType &blocks = functionOp.getBlocks();
658  // Create label names for basic blocks.
659  for (Block &block : blocks) {
660  emitter.getOrCreateName(block);
661  }
662 
663  // Declare variables for basic block arguments.
664  for (Block &block : llvm::drop_begin(blocks)) {
665  for (BlockArgument &arg : block.getArguments()) {
666  if (emitter.hasValueInScope(arg))
667  return functionOp.emitOpError(" block argument #")
668  << arg.getArgNumber() << " is out of scope";
669  if (failed(
670  emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
671  return failure();
672  }
673  os << " " << emitter.getOrCreateName(arg) << ";\n";
674  }
675  }
676 
677  for (Block &block : blocks) {
678  // Only print a label if the block has predecessors.
679  if (!block.hasNoPredecessors()) {
680  if (failed(emitter.emitLabel(block)))
681  return failure();
682  }
683  for (Operation &op : block.getOperations()) {
684  // When generating code for an emitc.if or cf.cond_br op no semicolon
685  // needs to be printed after the closing brace.
686  // When generating code for an emitc.for op, printing a trailing semicolon
687  // is handled within the printOperation function.
688  bool trailingSemicolon =
689  !isa<cf::CondBranchOp, emitc::ForOp, emitc::IfOp, emitc::LiteralOp>(
690  op);
691 
692  if (failed(emitter.emitOperation(
693  op, /*trailingSemicolon=*/trailingSemicolon)))
694  return failure();
695  }
696  }
697  os.unindent() << "}\n";
698  return success();
699 }
700 
701 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
702  : os(os), declareVariablesAtTop(declareVariablesAtTop) {
703  valueInScopeCount.push(0);
704  labelInScopeCount.push(0);
705 }
706 
707 /// Return the existing or a new name for a Value.
708 StringRef CppEmitter::getOrCreateName(Value val) {
709  if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
710  return literal.getValue();
711  if (!valueMapper.count(val))
712  valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
713  return *valueMapper.begin(val);
714 }
715 
716 /// Return the existing or a new label for a Block.
717 StringRef CppEmitter::getOrCreateName(Block &block) {
718  if (!blockMapper.count(&block))
719  blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
720  return *blockMapper.begin(&block);
721 }
722 
723 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
724  switch (val) {
725  case IntegerType::Signless:
726  return false;
727  case IntegerType::Signed:
728  return false;
729  case IntegerType::Unsigned:
730  return true;
731  }
732  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
733 }
734 
735 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
736 
737 bool CppEmitter::hasBlockLabel(Block &block) {
738  return blockMapper.count(&block);
739 }
740 
741 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
742  auto printInt = [&](const APInt &val, bool isUnsigned) {
743  if (val.getBitWidth() == 1) {
744  if (val.getBoolValue())
745  os << "true";
746  else
747  os << "false";
748  } else {
749  SmallString<128> strValue;
750  val.toString(strValue, 10, !isUnsigned, false);
751  os << strValue;
752  }
753  };
754 
755  auto printFloat = [&](const APFloat &val) {
756  if (val.isFinite()) {
757  SmallString<128> strValue;
758  // Use default values of toString except don't truncate zeros.
759  val.toString(strValue, 0, 0, false);
760  switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
761  case llvm::APFloatBase::S_IEEEsingle:
762  os << "(float)";
763  break;
764  case llvm::APFloatBase::S_IEEEdouble:
765  os << "(double)";
766  break;
767  default:
768  break;
769  };
770  os << strValue;
771  } else if (val.isNaN()) {
772  os << "NAN";
773  } else if (val.isInfinity()) {
774  if (val.isNegative())
775  os << "-";
776  os << "INFINITY";
777  }
778  };
779 
780  // Print floating point attributes.
781  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
782  printFloat(fAttr.getValue());
783  return success();
784  }
785  if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
786  os << '{';
787  interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
788  os << '}';
789  return success();
790  }
791 
792  // Print integer attributes.
793  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
794  if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
795  printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
796  return success();
797  }
798  if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
799  printInt(iAttr.getValue(), false);
800  return success();
801  }
802  }
803  if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
804  if (auto iType = dyn_cast<IntegerType>(
805  cast<TensorType>(dense.getType()).getElementType())) {
806  os << '{';
807  interleaveComma(dense, os, [&](const APInt &val) {
808  printInt(val, shouldMapToUnsigned(iType.getSignedness()));
809  });
810  os << '}';
811  return success();
812  }
813  if (auto iType = dyn_cast<IndexType>(
814  cast<TensorType>(dense.getType()).getElementType())) {
815  os << '{';
816  interleaveComma(dense, os,
817  [&](const APInt &val) { printInt(val, false); });
818  os << '}';
819  return success();
820  }
821  }
822 
823  // Print opaque attributes.
824  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
825  os << oAttr.getValue();
826  return success();
827  }
828 
829  // Print symbolic reference attributes.
830  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
831  if (sAttr.getNestedReferences().size() > 1)
832  return emitError(loc, "attribute has more than 1 nested reference");
833  os << sAttr.getRootReference().getValue();
834  return success();
835  }
836 
837  // Print type attributes.
838  if (auto type = dyn_cast<TypeAttr>(attr))
839  return emitType(loc, type.getValue());
840 
841  return emitError(loc, "cannot emit attribute: ") << attr;
842 }
843 
844 LogicalResult CppEmitter::emitOperands(Operation &op) {
845  auto emitOperandName = [&](Value result) -> LogicalResult {
846  auto literalDef = dyn_cast_if_present<LiteralOp>(result.getDefiningOp());
847  if (!literalDef && !hasValueInScope(result))
848  return op.emitOpError() << "operand value not in scope";
849  os << getOrCreateName(result);
850  return success();
851  };
852  return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
853 }
854 
856 CppEmitter::emitOperandsAndAttributes(Operation &op,
857  ArrayRef<StringRef> exclude) {
858  if (failed(emitOperands(op)))
859  return failure();
860  // Insert comma in between operands and non-filtered attributes if needed.
861  if (op.getNumOperands() > 0) {
862  for (NamedAttribute attr : op.getAttrs()) {
863  if (!llvm::is_contained(exclude, attr.getName().strref())) {
864  os << ", ";
865  break;
866  }
867  }
868  }
869  // Emit attributes.
870  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
871  if (llvm::is_contained(exclude, attr.getName().strref()))
872  return success();
873  os << "/* " << attr.getName().getValue() << " */";
874  if (failed(emitAttribute(op.getLoc(), attr.getValue())))
875  return failure();
876  return success();
877  };
878  return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
879 }
880 
881 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
882  if (!hasValueInScope(result)) {
883  return result.getDefiningOp()->emitOpError(
884  "result variable for the operation has not been declared");
885  }
886  os << getOrCreateName(result) << " = ";
887  return success();
888 }
889 
890 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
891  bool trailingSemicolon) {
892  if (hasValueInScope(result)) {
893  return result.getDefiningOp()->emitError(
894  "result variable for the operation already declared");
895  }
896  if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
897  return failure();
898  os << " " << getOrCreateName(result);
899  if (trailingSemicolon)
900  os << ";\n";
901  return success();
902 }
903 
904 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
905  switch (op.getNumResults()) {
906  case 0:
907  break;
908  case 1: {
909  OpResult result = op.getResult(0);
910  if (shouldDeclareVariablesAtTop()) {
911  if (failed(emitVariableAssignment(result)))
912  return failure();
913  } else {
914  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
915  return failure();
916  os << " = ";
917  }
918  break;
919  }
920  default:
921  if (!shouldDeclareVariablesAtTop()) {
922  for (OpResult result : op.getResults()) {
923  if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
924  return failure();
925  }
926  }
927  os << "std::tie(";
928  interleaveComma(op.getResults(), os,
929  [&](Value result) { os << getOrCreateName(result); });
930  os << ") = ";
931  }
932  return success();
933 }
934 
935 LogicalResult CppEmitter::emitLabel(Block &block) {
936  if (!hasBlockLabel(block))
937  return block.getParentOp()->emitError("label for block not found");
938  // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
939  // label instead of using `getOStream`.
940  os.getOStream() << getOrCreateName(block) << ":\n";
941  return success();
942 }
943 
944 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
945  LogicalResult status =
947  // Builtin ops.
948  .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
949  // CF ops.
950  .Case<cf::BranchOp, cf::CondBranchOp>(
951  [&](auto op) { return printOperation(*this, op); })
952  // EmitC ops.
953  .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
954  emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
955  emitc::ConstantOp, emitc::DivOp, emitc::ForOp, emitc::IfOp,
956  emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp,
957  emitc::VariableOp>(
958  [&](auto op) { return printOperation(*this, op); })
959  // Func ops.
960  .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
961  [&](auto op) { return printOperation(*this, op); })
962  // Arithmetic ops.
963  .Case<arith::ConstantOp>(
964  [&](auto op) { return printOperation(*this, op); })
965  .Case<emitc::LiteralOp>([&](auto op) { return success(); })
966  .Default([&](Operation *) {
967  return op.emitOpError("unable to find printer for op");
968  });
969 
970  if (failed(status))
971  return failure();
972 
973  if (isa<emitc::LiteralOp>(op))
974  return success();
975 
976  os << (trailingSemicolon ? ";\n" : "\n");
977  return success();
978 }
979 
980 LogicalResult CppEmitter::emitType(Location loc, Type type) {
981  if (auto iType = dyn_cast<IntegerType>(type)) {
982  switch (iType.getWidth()) {
983  case 1:
984  return (os << "bool"), success();
985  case 8:
986  case 16:
987  case 32:
988  case 64:
989  if (shouldMapToUnsigned(iType.getSignedness()))
990  return (os << "uint" << iType.getWidth() << "_t"), success();
991  else
992  return (os << "int" << iType.getWidth() << "_t"), success();
993  default:
994  return emitError(loc, "cannot emit integer type ") << type;
995  }
996  }
997  if (auto fType = dyn_cast<FloatType>(type)) {
998  switch (fType.getWidth()) {
999  case 32:
1000  return (os << "float"), success();
1001  case 64:
1002  return (os << "double"), success();
1003  default:
1004  return emitError(loc, "cannot emit float type ") << type;
1005  }
1006  }
1007  if (auto iType = dyn_cast<IndexType>(type))
1008  return (os << "size_t"), success();
1009  if (auto tType = dyn_cast<TensorType>(type)) {
1010  if (!tType.hasRank())
1011  return emitError(loc, "cannot emit unranked tensor type");
1012  if (!tType.hasStaticShape())
1013  return emitError(loc, "cannot emit tensor type with non static shape");
1014  os << "Tensor<";
1015  if (failed(emitType(loc, tType.getElementType())))
1016  return failure();
1017  auto shape = tType.getShape();
1018  for (auto dimSize : shape) {
1019  os << ", ";
1020  os << dimSize;
1021  }
1022  os << ">";
1023  return success();
1024  }
1025  if (auto tType = dyn_cast<TupleType>(type))
1026  return emitTupleType(loc, tType.getTypes());
1027  if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1028  os << oType.getValue();
1029  return success();
1030  }
1031  if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1032  if (failed(emitType(loc, pType.getPointee())))
1033  return failure();
1034  os << "*";
1035  return success();
1036  }
1037  return emitError(loc, "cannot emit type ") << type;
1038 }
1039 
1040 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1041  switch (types.size()) {
1042  case 0:
1043  os << "void";
1044  return success();
1045  case 1:
1046  return emitType(loc, types.front());
1047  default:
1048  return emitTupleType(loc, types);
1049  }
1050 }
1051 
1052 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1053  os << "std::tuple<";
1055  types, os, [&](Type type) { return emitType(loc, type); })))
1056  return failure();
1057  os << ">";
1058  return success();
1059 }
1060 
1062  bool declareVariablesAtTop) {
1063  CppEmitter emitter(os, declareVariablesAtTop);
1064  return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1065 }
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator)
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:327
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:80
Block * getSuccessor(unsigned i)
Definition: Block.cpp:253
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:198
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:462
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class provides iteration over the held operations of blocks directly within a region.
Definition: Region.h:134
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
llvm::iplist< Block > BlockListType
Definition: Region.h:44
bool empty()
Definition: Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & unindent()
Decreases the indent and returns this raw_indented_ostream.
raw_indented_ostream & indent()
Increases the indent and returns this raw_indented_ostream.
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false)
Translates the given operation to C++ code.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This iterator enumerates the elements in "forward" order.
Definition: Visitors.h:66
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26