33 void populateConvertToEmitCConversionPatterns(
43 dialect->addInterfaces<ArithToEmitCDialectInterface>();
52 class ArithConstantOpConversionPattern
58 matchAndRewrite(arith::ConstantOp arithConst,
59 arith::ConstantOp::Adaptor adaptor,
61 Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
71 Type adaptIntegralTypeSignedness(
Type ty,
bool needsUnsigned) {
72 if (isa<IntegerType>(ty)) {
74 auto signedness = needsUnsigned
75 ? IntegerType::SignednessSemantics::Unsigned
81 if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
100 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
103 if (!isa<FloatType>(adaptor.getRhs().getType())) {
105 "cmpf currently only supported on "
106 "floats, not tensors/vectors thereof");
109 bool unordered =
false;
111 switch (op.getPredicate()) {
112 case arith::CmpFPredicate::AlwaysFalse: {
114 emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.
getI1Type(),
119 case arith::CmpFPredicate::OEQ:
121 predicate = emitc::CmpPredicate::eq;
123 case arith::CmpFPredicate::OGT:
125 predicate = emitc::CmpPredicate::gt;
127 case arith::CmpFPredicate::OGE:
129 predicate = emitc::CmpPredicate::ge;
131 case arith::CmpFPredicate::OLT:
133 predicate = emitc::CmpPredicate::lt;
135 case arith::CmpFPredicate::OLE:
137 predicate = emitc::CmpPredicate::le;
139 case arith::CmpFPredicate::ONE:
141 predicate = emitc::CmpPredicate::ne;
143 case arith::CmpFPredicate::ORD: {
145 auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
150 case arith::CmpFPredicate::UEQ:
152 predicate = emitc::CmpPredicate::eq;
154 case arith::CmpFPredicate::UGT:
156 predicate = emitc::CmpPredicate::gt;
158 case arith::CmpFPredicate::UGE:
160 predicate = emitc::CmpPredicate::ge;
162 case arith::CmpFPredicate::ULT:
164 predicate = emitc::CmpPredicate::lt;
166 case arith::CmpFPredicate::ULE:
168 predicate = emitc::CmpPredicate::le;
170 case arith::CmpFPredicate::UNE:
172 predicate = emitc::CmpPredicate::ne;
174 case arith::CmpFPredicate::UNO: {
176 auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
181 case arith::CmpFPredicate::AlwaysTrue: {
183 emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.
getI1Type(),
192 emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate,
193 adaptor.getLhs(), adaptor.getRhs());
197 auto isUnordered = createCheckIsUnordered(
198 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
200 isUnordered, cmpResult);
204 auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
205 adaptor.getLhs(), adaptor.getRhs());
207 isOrdered, cmpResult);
214 Value operand)
const {
216 return emitc::CmpOp::create(rewriter, loc, rewriter.
getI1Type(),
217 emitc::CmpPredicate::ne, operand, operand);
222 Value operand)
const {
224 return emitc::CmpOp::create(rewriter, loc, rewriter.
getI1Type(),
225 emitc::CmpPredicate::eq, operand, operand);
232 auto firstIsNaN = isNaN(rewriter, loc, first);
233 auto secondIsNaN = isNaN(rewriter, loc, second);
234 return emitc::LogicalOrOp::create(rewriter, loc, rewriter.
getI1Type(),
235 firstIsNaN, secondIsNaN);
242 auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
243 auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
244 return emitc::LogicalAndOp::create(rewriter, loc, rewriter.
getI1Type(),
245 firstIsNotNaN, secondIsNotNaN);
253 bool needsUnsignedCmp(arith::CmpIPredicate pred)
const {
255 case arith::CmpIPredicate::eq:
256 case arith::CmpIPredicate::ne:
257 case arith::CmpIPredicate::slt:
258 case arith::CmpIPredicate::sle:
259 case arith::CmpIPredicate::sgt:
260 case arith::CmpIPredicate::sge:
262 case arith::CmpIPredicate::ult:
263 case arith::CmpIPredicate::ule:
264 case arith::CmpIPredicate::ugt:
265 case arith::CmpIPredicate::uge:
268 llvm_unreachable(
"unknown cmpi predicate kind");
273 case arith::CmpIPredicate::eq:
274 return emitc::CmpPredicate::eq;
275 case arith::CmpIPredicate::ne:
276 return emitc::CmpPredicate::ne;
277 case arith::CmpIPredicate::slt:
278 case arith::CmpIPredicate::ult:
279 return emitc::CmpPredicate::lt;
280 case arith::CmpIPredicate::sle:
281 case arith::CmpIPredicate::ule:
282 return emitc::CmpPredicate::le;
283 case arith::CmpIPredicate::sgt:
284 case arith::CmpIPredicate::ugt:
285 return emitc::CmpPredicate::gt;
286 case arith::CmpIPredicate::sge:
287 case arith::CmpIPredicate::uge:
288 return emitc::CmpPredicate::ge;
290 llvm_unreachable(
"unknown cmpi predicate kind");
294 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
297 Type type = adaptor.getLhs().getType();
300 op,
"expected integer or size_t/ssize_t/ptrdiff_t type");
303 bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
306 Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
307 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
308 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
320 matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
323 auto adaptedOp = adaptor.getOperand();
324 auto adaptedOpType = adaptedOp.getType();
326 if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
329 "negf currently only supports scalar types, not vectors or tensors");
334 op.getLoc(),
"floating-point type is not supported by EmitC");
343 template <
typename ArithOp,
bool castToUn
signed>
349 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
352 Type opReturnType = this->getTypeConverter()->convertType(op.getType());
353 if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
356 op,
"expected integer or size_t/ssize_t/ptrdiff_t result type");
358 if (adaptor.getOperands().size() != 1) {
360 op,
"CastConversion only supports unary ops");
363 Type operandType = adaptor.getIn().getType();
364 if (!operandType || !(isa<IntegerType>(operandType) ||
367 op,
"expected integer or size_t/ssize_t/ptrdiff_t operand type");
370 if (operandType.
isInteger(1) && !castToUnsigned)
372 "operation not supported on i1 type");
381 auto constOne = emitc::ConstantOp::create(
382 rewriter, op.getLoc(), operandType, rewriter.
getOneAttr(attrType));
383 auto oneAndOperand = emitc::BitwiseAndOp::create(
384 rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne);
391 (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
394 bool doUnsigned = castToUnsigned || isTruncation;
398 Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
401 Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
402 Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
406 emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
409 auto result = adaptValueType(cast, rewriter, opReturnType);
416 template <
typename ArithOp>
417 class UnsignedCastConversion :
public CastConversion<ArithOp, true> {
418 using CastConversion<ArithOp, true>::CastConversion;
421 template <
typename ArithOp>
422 class SignedCastConversion :
public CastConversion<ArithOp, false> {
423 using CastConversion<ArithOp, false>::CastConversion;
426 template <
typename ArithOp,
typename EmitCOp>
432 matchAndRewrite(ArithOp arithOp,
typename ArithOp::Adaptor adaptor,
435 Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
438 "converting result type failed");
439 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
440 adaptor.getOperands());
446 template <
class ArithOp,
class EmitCOp>
452 matchAndRewrite(ArithOp uiBinOp,
typename ArithOp::Adaptor adaptor,
454 Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
457 "converting result type failed");
458 if (!isa<IntegerType>(newRetTy)) {
462 adaptIntegralTypeSignedness(newRetTy,
true);
465 "converting result type failed");
466 Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
467 Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
469 auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType,
471 Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
472 rewriter.
replaceOp(uiBinOp, resultAdapted);
477 template <
typename ArithOp,
typename EmitCOp>
483 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
486 Type type = this->getTypeConverter()->convertType(op.getType());
489 op,
"expected integer or size_t/ssize_t/ptrdiff_t type");
497 Type arithmeticType = type;
499 !bitEnumContainsAll(op.getOverflowFlags(),
500 arith::IntegerOverflowFlags::nsw)) {
507 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
508 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
510 Value arithmeticResult =
511 EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
513 Value result = adaptValueType(arithmeticResult, rewriter, type);
520 template <
typename ArithOp,
typename EmitCOp>
526 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
529 Type type = this->getTypeConverter()->convertType(op.getType());
530 if (!isa_and_nonnull<IntegerType>(type)) {
533 "expected integer type, vector/tensor support not yet implemented");
544 Type arithmeticType =
545 adaptIntegralTypeSignedness(type,
true);
547 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
548 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
550 Value arithmeticResult =
551 EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
553 Value result = adaptValueType(arithmeticResult, rewriter, type);
560 template <
typename ArithOp,
typename EmitCOp,
bool isUn
signedOp>
566 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
569 Type type = this->getTypeConverter()->convertType(op.getType());
572 op,
"expected integer or size_t/ssize_t/ptrdiff_t type");
579 Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
581 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
583 Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
585 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
590 Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType,
592 emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(
594 width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight,
595 sizeOfCall.getResult(0));
597 width = emitc::ConstantOp::create(
598 rewriter, op.getLoc(), rhsType,
603 emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.
getI1Type(),
604 emitc::CmpPredicate::lt, rhs, width);
607 Value poison = emitc::ConstantOp::create(
608 rewriter, op.getLoc(), arithmeticType,
609 (isa<IntegerType>(arithmeticType)
613 emitc::ExpressionOp ternary = emitc::ExpressionOp::create(
614 rewriter, op.getLoc(), arithmeticType,
false);
615 Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
618 Value arithmeticResult =
619 EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
620 Value resultOrPoison =
621 emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType,
622 excessCheck, arithmeticResult, poison);
623 emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
626 Value result = adaptValueType(ternary, rewriter, type);
633 template <
typename ArithOp,
typename EmitCOp>
634 class SignedShiftOpConversion final
635 :
public ShiftOpConversion<ArithOp, EmitCOp, false> {
636 using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
639 template <
typename ArithOp,
typename EmitCOp>
640 class UnsignedShiftOpConversion final
641 :
public ShiftOpConversion<ArithOp, EmitCOp, true> {
642 using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
650 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
653 Type dstType = getTypeConverter()->convertType(selectOp.getType());
657 if (!adaptor.getCondition().getType().isInteger(1))
660 "can only be converted if condition is a scalar of type i1");
663 adaptor.getOperands());
670 template <
typename CastOp>
677 matchAndRewrite(CastOp castOp,
typename CastOp::Adaptor adaptor,
680 Type operandType = adaptor.getIn().getType();
683 "unsupported cast source type");
685 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
693 "unsupported cast destination type");
697 Type actualResultType = dstType;
698 if (isa<arith::FPToUIOp>(castOp)) {
704 Value result = emitc::CastOp::create(
705 rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands());
707 if (isa<arith::FPToUIOp>(castOp)) {
709 emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result);
718 template <
typename CastOp>
725 matchAndRewrite(CastOp castOp,
typename CastOp::Adaptor adaptor,
728 Type operandType = adaptor.getIn().getType();
731 "unsupported cast source type");
733 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
739 "unsupported cast destination type");
743 Type actualOperandType = operandType;
744 if (isa<arith::UIToFPOp>(castOp)) {
749 Value fpCastOperand = adaptor.getIn();
750 if (actualOperandType != operandType) {
751 fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
752 actualOperandType, fpCastOperand);
761 template <
typename CastOp>
768 matchAndRewrite(CastOp castOp,
typename CastOp::Adaptor adaptor,
771 Type operandType = adaptor.getIn().getType();
774 "unsupported cast source type");
775 if (
auto roundingModeOp =
776 dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
778 if (roundingModeOp.getRoundingModeAttr())
782 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
788 "unsupported cast destination type");
790 Value fpCastOperand = adaptor.getIn();
811 ArithConstantOpConversionPattern,
812 ArithOpConversion<arith::AddFOp, emitc::AddOp>,
813 ArithOpConversion<arith::DivFOp, emitc::DivOp>,
814 ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
815 ArithOpConversion<arith::MulFOp, emitc::MulOp>,
816 ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
817 ArithOpConversion<arith::SubFOp, emitc::SubOp>,
818 BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
819 BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
820 IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
821 IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
822 IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
823 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
824 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
825 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
826 UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
827 SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
828 UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
834 UnsignedCastConversion<arith::TruncIOp>,
835 SignedCastConversion<arith::ExtSIOp>,
836 UnsignedCastConversion<arith::ExtUIOp>,
837 SignedCastConversion<arith::IndexCastOp>,
838 UnsignedCastConversion<arith::IndexCastUIOp>,
839 ItoFCastOpConversion<arith::SIToFPOp>,
840 ItoFCastOpConversion<arith::UIToFPOp>,
841 FtoICastOpConversion<arith::FPToSIOp>,
842 FtoICastOpConversion<arith::FPToUIOp>,
843 FpCastOpConversion<arith::ExtFOp>,
844 FpCastOpConversion<arith::TruncFOp>
845 >(typeConverter, ctx);
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
TypedAttr getOneAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
void registerConvertArithToEmitCInterface(DialectRegistry &registry)
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)