19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
27 #define DEBUG_TYPE "translate-to-cpp"
37 typename NullaryFunctor>
40 UnaryFunctor eachFn, NullaryFunctor betweenFn) {
43 if (
failed(eachFn(*begin)))
46 for (; begin != end; ++begin) {
48 if (
failed(eachFn(*begin)))
54 template <
typename Container,
typename UnaryFunctor,
typename NullaryFunctor>
57 NullaryFunctor betweenFn) {
61 template <
typename Container,
typename UnaryFunctor>
64 UnaryFunctor eachFn) {
71 explicit CppEmitter(raw_ostream &os,
bool declareVariablesAtTop);
97 bool trailingSemicolon);
119 StringRef getOrCreateName(
Value val);
122 StringRef getOrCreateName(
Block &block);
125 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
129 Scope(CppEmitter &emitter)
130 : valueMapperScope(emitter.valueMapper),
131 blockMapperScope(emitter.blockMapper), emitter(emitter) {
132 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
133 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
136 emitter.valueInScopeCount.pop();
137 emitter.labelInScopeCount.pop();
141 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
142 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
147 bool hasValueInScope(
Value val);
150 bool hasBlockLabel(
Block &block);
157 bool shouldDeclareVariablesAtTop() {
return declareVariablesAtTop; };
160 using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
161 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
169 bool declareVariablesAtTop;
172 ValueMapper valueMapper;
175 BlockMapper blockMapper;
179 std::stack<int64_t> valueInScopeCount;
180 std::stack<int64_t> labelInScopeCount;
190 if (emitter.shouldDeclareVariablesAtTop()) {
192 if (
auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
193 if (oAttr.getValue().empty())
197 if (
failed(emitter.emitVariableAssignment(result)))
199 return emitter.emitAttribute(operation->
getLoc(), value);
203 if (
auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
204 if (oAttr.getValue().empty())
206 return emitter.emitVariableDeclaration(result,
211 if (
failed(emitter.emitAssignPrefix(*operation)))
213 return emitter.emitAttribute(operation->
getLoc(), value);
217 emitc::ConstantOp constantOp) {
218 Operation *operation = constantOp.getOperation();
225 emitc::VariableOp variableOp) {
226 Operation *operation = variableOp.getOperation();
233 arith::ConstantOp constantOp) {
234 Operation *operation = constantOp.getOperation();
241 func::ConstantOp constantOp) {
242 Operation *operation = constantOp.getOperation();
243 Attribute value = constantOp.getValueAttr();
249 emitc::AssignOp assignOp) {
250 auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
251 OpResult result = variableOp->getResult(0);
253 if (
failed(emitter.emitVariableAssignment(result)))
256 emitter.ostream() << emitter.getOrCreateName(assignOp.getValue());
263 StringRef binaryOperator) {
264 raw_ostream &os = emitter.ostream();
266 if (
failed(emitter.emitAssignPrefix(*operation)))
268 os << emitter.getOrCreateName(operation->
getOperand(0));
269 os <<
" " << binaryOperator;
270 os <<
" " << emitter.getOrCreateName(operation->
getOperand(1));
276 Operation *operation = addOp.getOperation();
282 Operation *operation = divOp.getOperation();
288 Operation *operation = mulOp.getOperation();
294 Operation *operation = remOp.getOperation();
300 Operation *operation = subOp.getOperation();
306 Operation *operation = cmpOp.getOperation();
308 StringRef binaryOperator;
310 switch (cmpOp.getPredicate()) {
311 case emitc::CmpPredicate::eq:
312 binaryOperator =
"==";
314 case emitc::CmpPredicate::ne:
315 binaryOperator =
"!=";
317 case emitc::CmpPredicate::lt:
318 binaryOperator =
"<";
320 case emitc::CmpPredicate::le:
321 binaryOperator =
"<=";
323 case emitc::CmpPredicate::gt:
324 binaryOperator =
">";
326 case emitc::CmpPredicate::ge:
327 binaryOperator =
">=";
329 case emitc::CmpPredicate::three_way:
330 binaryOperator =
"<=>";
338 cf::BranchOp branchOp) {
339 raw_ostream &os = emitter.ostream();
343 llvm::zip(branchOp.getOperands(), successor.
getArguments())) {
344 Value &operand = std::get<0>(pair);
346 os << emitter.getOrCreateName(argument) <<
" = "
347 << emitter.getOrCreateName(operand) <<
";\n";
351 if (!(emitter.hasBlockLabel(successor)))
352 return branchOp.emitOpError(
"unable to find label for successor block");
353 os << emitter.getOrCreateName(successor);
358 cf::CondBranchOp condBranchOp) {
360 Block &trueSuccessor = *condBranchOp.getTrueDest();
361 Block &falseSuccessor = *condBranchOp.getFalseDest();
363 os <<
"if (" << emitter.getOrCreateName(condBranchOp.getCondition())
369 for (
auto pair : llvm::zip(condBranchOp.getTrueOperands(),
371 Value &operand = std::get<0>(pair);
373 os << emitter.getOrCreateName(argument) <<
" = "
374 << emitter.getOrCreateName(operand) <<
";\n";
378 if (!(emitter.hasBlockLabel(trueSuccessor))) {
379 return condBranchOp.emitOpError(
"unable to find label for successor block");
381 os << emitter.getOrCreateName(trueSuccessor) <<
";\n";
385 for (
auto pair : llvm::zip(condBranchOp.getFalseOperands(),
387 Value &operand = std::get<0>(pair);
389 os << emitter.getOrCreateName(argument) <<
" = "
390 << emitter.getOrCreateName(operand) <<
";\n";
394 if (!(emitter.hasBlockLabel(falseSuccessor))) {
395 return condBranchOp.emitOpError()
396 <<
"unable to find label for successor block";
398 os << emitter.getOrCreateName(falseSuccessor) <<
";\n";
404 if (
failed(emitter.emitAssignPrefix(*callOp.getOperation())))
407 raw_ostream &os = emitter.ostream();
408 os << callOp.getCallee() <<
"(";
409 if (
failed(emitter.emitOperands(*callOp.getOperation())))
416 emitc::CallOpaqueOp callOpaqueOp) {
417 raw_ostream &os = emitter.ostream();
418 Operation &op = *callOpaqueOp.getOperation();
420 if (
failed(emitter.emitAssignPrefix(op)))
422 os << callOpaqueOp.getCallee();
425 if (
auto t = dyn_cast<IntegerAttr>(attr)) {
427 if (t.getType().isIndex()) {
428 int64_t idx = t.getInt();
432 if (!literalDef && !emitter.hasValueInScope(operand))
434 << idx <<
"'s value not defined in scope";
435 os << emitter.getOrCreateName(operand);
445 if (callOpaqueOp.getTemplateArgs()) {
456 callOpaqueOp.getArgs()
458 : emitter.emitOperands(op);
466 emitc::ApplyOp applyOp) {
467 raw_ostream &os = emitter.ostream();
470 if (
failed(emitter.emitAssignPrefix(op)))
472 os << applyOp.getApplicableOperator();
473 os << emitter.getOrCreateName(applyOp.getOperand());
479 raw_ostream &os = emitter.ostream();
482 if (
failed(emitter.emitAssignPrefix(op)))
488 os << emitter.getOrCreateName(castOp.getOperand());
494 emitc::IncludeOp includeOp) {
495 raw_ostream &os = emitter.ostream();
498 if (includeOp.getIsStandardInclude())
499 os <<
"<" << includeOp.getInclude() <<
">";
501 os <<
"\"" << includeOp.getInclude() <<
"\"";
512 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
515 os << emitter.getOrCreateName(forOp.getInductionVar());
517 os << emitter.getOrCreateName(forOp.getLowerBound());
519 os << emitter.getOrCreateName(forOp.getInductionVar());
521 os << emitter.getOrCreateName(forOp.getUpperBound());
523 os << emitter.getOrCreateName(forOp.getInductionVar());
525 os << emitter.getOrCreateName(forOp.getStep());
529 Region &forRegion = forOp.getRegion();
530 auto regionOps = forRegion.
getOps();
533 for (
auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
534 if (
failed(emitter.emitOperation(*it,
true)))
548 auto emitAllExceptLast = [&emitter](
Region ®ion) {
550 for (; std::next(it) != end; ++it) {
551 if (
failed(emitter.emitOperation(*it,
true)))
554 assert(isa<emitc::YieldOp>(*it) &&
555 "Expected last operation in the region to be emitc::yield");
560 if (
failed(emitter.emitOperands(*ifOp.getOperation())))
564 if (
failed(emitAllExceptLast(ifOp.getThenRegion())))
568 Region &elseRegion = ifOp.getElseRegion();
569 if (!elseRegion.
empty()) {
572 if (
failed(emitAllExceptLast(elseRegion)))
581 func::ReturnOp returnOp) {
582 raw_ostream &os = emitter.ostream();
584 switch (returnOp.getNumOperands()) {
588 os <<
" " << emitter.getOrCreateName(returnOp.getOperand(0));
589 return success(emitter.hasValueInScope(returnOp.getOperand(0)));
591 os <<
" std::make_tuple(";
592 if (
failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
600 CppEmitter::Scope scope(emitter);
603 if (
failed(emitter.emitOperation(op,
false)))
610 func::FuncOp functionOp) {
612 if (!emitter.shouldDeclareVariablesAtTop() &&
613 functionOp.getBlocks().size() > 1) {
614 return functionOp.emitOpError(
615 "with multiple blocks needs variables declared at top");
618 CppEmitter::Scope scope(emitter);
620 if (
failed(emitter.emitTypes(functionOp.getLoc(),
621 functionOp.getFunctionType().getResults())))
623 os <<
" " << functionOp.getName();
627 functionOp.getArguments(), os,
629 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
631 os <<
" " << emitter.getOrCreateName(arg);
637 if (emitter.shouldDeclareVariablesAtTop()) {
642 if (isa<emitc::LiteralOp>(op))
645 if (
failed(emitter.emitVariableDeclaration(
648 op->
emitError(
"unable to declare result variable for op"));
659 for (
Block &block : blocks) {
660 emitter.getOrCreateName(block);
664 for (
Block &block : llvm::drop_begin(blocks)) {
666 if (emitter.hasValueInScope(arg))
667 return functionOp.emitOpError(
" block argument #")
670 emitter.emitType(block.getParentOp()->getLoc(), arg.
getType()))) {
673 os <<
" " << emitter.getOrCreateName(arg) <<
";\n";
677 for (
Block &block : blocks) {
679 if (!block.hasNoPredecessors()) {
680 if (
failed(emitter.emitLabel(block)))
683 for (
Operation &op : block.getOperations()) {
688 bool trailingSemicolon =
689 !isa<cf::CondBranchOp, emitc::ForOp, emitc::IfOp, emitc::LiteralOp>(
692 if (
failed(emitter.emitOperation(
693 op, trailingSemicolon)))
701 CppEmitter::CppEmitter(raw_ostream &os,
bool declareVariablesAtTop)
702 : os(os), declareVariablesAtTop(declareVariablesAtTop) {
703 valueInScopeCount.push(0);
704 labelInScopeCount.push(0);
708 StringRef CppEmitter::getOrCreateName(
Value val) {
709 if (
auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.
getDefiningOp()))
710 return literal.getValue();
711 if (!valueMapper.count(val))
712 valueMapper.insert(val, formatv(
"v{0}", ++valueInScopeCount.top()));
713 return *valueMapper.begin(val);
717 StringRef CppEmitter::getOrCreateName(
Block &block) {
718 if (!blockMapper.count(&block))
719 blockMapper.insert(&block, formatv(
"label{0}", ++labelInScopeCount.top()));
720 return *blockMapper.begin(&block);
723 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
725 case IntegerType::Signless:
729 case IntegerType::Unsigned:
732 llvm_unreachable(
"Unexpected IntegerType::SignednessSemantics");
735 bool CppEmitter::hasValueInScope(
Value val) {
return valueMapper.count(val); }
737 bool CppEmitter::hasBlockLabel(
Block &block) {
738 return blockMapper.count(&block);
742 auto printInt = [&](
const APInt &val,
bool isUnsigned) {
743 if (val.getBitWidth() == 1) {
744 if (val.getBoolValue())
750 val.toString(strValue, 10, !isUnsigned,
false);
755 auto printFloat = [&](
const APFloat &val) {
756 if (val.isFinite()) {
759 val.toString(strValue, 0, 0,
false);
760 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
761 case llvm::APFloatBase::S_IEEEsingle:
764 case llvm::APFloatBase::S_IEEEdouble:
771 }
else if (val.isNaN()) {
773 }
else if (val.isInfinity()) {
774 if (val.isNegative())
781 if (
auto fAttr = dyn_cast<FloatAttr>(attr)) {
782 printFloat(fAttr.getValue());
785 if (
auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
787 interleaveComma(dense, os, [&](
const APFloat &val) { printFloat(val); });
793 if (
auto iAttr = dyn_cast<IntegerAttr>(attr)) {
794 if (
auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
795 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
798 if (
auto iType = dyn_cast<IndexType>(iAttr.getType())) {
799 printInt(iAttr.getValue(),
false);
803 if (
auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
804 if (
auto iType = dyn_cast<IntegerType>(
805 cast<TensorType>(dense.getType()).getElementType())) {
807 interleaveComma(dense, os, [&](
const APInt &val) {
808 printInt(val, shouldMapToUnsigned(iType.getSignedness()));
813 if (
auto iType = dyn_cast<IndexType>(
814 cast<TensorType>(dense.getType()).getElementType())) {
816 interleaveComma(dense, os,
817 [&](
const APInt &val) { printInt(val,
false); });
824 if (
auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
825 os << oAttr.getValue();
830 if (
auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
831 if (sAttr.getNestedReferences().size() > 1)
832 return emitError(loc,
"attribute has more than 1 nested reference");
833 os << sAttr.getRootReference().getValue();
838 if (
auto type = dyn_cast<TypeAttr>(attr))
839 return emitType(loc, type.getValue());
841 return emitError(loc,
"cannot emit attribute: ") << attr;
846 auto literalDef = dyn_cast_if_present<LiteralOp>(result.getDefiningOp());
847 if (!literalDef && !hasValueInScope(result))
848 return op.
emitOpError() <<
"operand value not in scope";
849 os << getOrCreateName(result);
856 CppEmitter::emitOperandsAndAttributes(
Operation &op,
858 if (
failed(emitOperands(op)))
863 if (!llvm::is_contained(exclude, attr.getName().strref())) {
871 if (llvm::is_contained(exclude, attr.getName().strref()))
873 os <<
"/* " << attr.getName().getValue() <<
" */";
874 if (
failed(emitAttribute(op.
getLoc(), attr.getValue())))
882 if (!hasValueInScope(result)) {
884 "result variable for the operation has not been declared");
886 os << getOrCreateName(result) <<
" = ";
891 bool trailingSemicolon) {
892 if (hasValueInScope(result)) {
894 "result variable for the operation already declared");
898 os <<
" " << getOrCreateName(result);
899 if (trailingSemicolon)
910 if (shouldDeclareVariablesAtTop()) {
911 if (
failed(emitVariableAssignment(result)))
914 if (
failed(emitVariableDeclaration(result,
false)))
921 if (!shouldDeclareVariablesAtTop()) {
923 if (
failed(emitVariableDeclaration(result,
true)))
929 [&](
Value result) { os << getOrCreateName(result); });
936 if (!hasBlockLabel(block))
940 os.getOStream() << getOrCreateName(block) <<
":\n";
948 .Case<ModuleOp>([&](
auto op) {
return printOperation(*
this, op); })
950 .Case<cf::BranchOp, cf::CondBranchOp>(
953 .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
954 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
955 emitc::ConstantOp, emitc::DivOp, emitc::ForOp, emitc::IfOp,
956 emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp,
960 .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
963 .Case<arith::ConstantOp>(
965 .Case<emitc::LiteralOp>([&](
auto op) {
return success(); })
967 return op.
emitOpError(
"unable to find printer for op");
973 if (isa<emitc::LiteralOp>(op))
976 os << (trailingSemicolon ?
";\n" :
"\n");
981 if (
auto iType = dyn_cast<IntegerType>(type)) {
982 switch (iType.getWidth()) {
984 return (os <<
"bool"),
success();
989 if (shouldMapToUnsigned(iType.getSignedness()))
990 return (os <<
"uint" << iType.getWidth() <<
"_t"),
success();
992 return (os <<
"int" << iType.getWidth() <<
"_t"),
success();
994 return emitError(loc,
"cannot emit integer type ") << type;
997 if (
auto fType = dyn_cast<FloatType>(type)) {
998 switch (fType.getWidth()) {
1000 return (os <<
"float"),
success();
1002 return (os <<
"double"),
success();
1004 return emitError(loc,
"cannot emit float type ") << type;
1007 if (
auto iType = dyn_cast<IndexType>(type))
1008 return (os <<
"size_t"),
success();
1009 if (
auto tType = dyn_cast<TensorType>(type)) {
1010 if (!tType.hasRank())
1011 return emitError(loc,
"cannot emit unranked tensor type");
1012 if (!tType.hasStaticShape())
1013 return emitError(loc,
"cannot emit tensor type with non static shape");
1015 if (
failed(emitType(loc, tType.getElementType())))
1017 auto shape = tType.getShape();
1018 for (
auto dimSize : shape) {
1025 if (
auto tType = dyn_cast<TupleType>(type))
1026 return emitTupleType(loc, tType.getTypes());
1027 if (
auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1028 os << oType.getValue();
1031 if (
auto pType = dyn_cast<emitc::PointerType>(type)) {
1032 if (
failed(emitType(loc, pType.getPointee())))
1037 return emitError(loc,
"cannot emit type ") << type;
1041 switch (types.size()) {
1046 return emitType(loc, types.front());
1048 return emitTupleType(loc, types);
1053 os <<
"std::tuple<";
1055 types, os, [&](
Type type) {
return emitType(loc, type); })))
1062 bool declareVariablesAtTop) {
1063 CppEmitter emitter(os, declareVariablesAtTop);
1064 return emitter.emitOperation(*op,
false);
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator)
static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Block * getSuccessor(unsigned i)
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
NamedAttribute represents a combination of a name and an Attribute value.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class provides iteration over the held operations of blocks directly within a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
llvm::iplist< Block > BlockListType
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
raw_ostream subclass that simplifies indenting a sequence of code.
raw_indented_ostream & unindent()
Decreases the indent and returns this raw_indented_ostream.
raw_indented_ostream & indent()
Increases the indent and returns this raw_indented_ostream.
LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop=false)
Translates the given operation to C++ code.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This iterator enumerates the elements in "forward" order.
This class represents an efficient way to signal success or failure.