33 void populateConvertToEmitCConversionPatterns(
43 dialect->addInterfaces<ArithToEmitCDialectInterface>();
52 class ArithConstantOpConversionPattern
58 matchAndRewrite(arith::ConstantOp arithConst,
59 arith::ConstantOp::Adaptor adaptor,
61 Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
71 Type adaptIntegralTypeSignedness(
Type ty,
bool needsUnsigned) {
72 if (isa<IntegerType>(ty)) {
74 auto signedness = needsUnsigned
75 ? IntegerType::SignednessSemantics::Unsigned
81 if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
100 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
103 if (!isa<FloatType>(adaptor.getRhs().getType())) {
105 "cmpf currently only supported on "
106 "floats, not tensors/vectors thereof");
109 bool unordered =
false;
111 switch (op.getPredicate()) {
112 case arith::CmpFPredicate::AlwaysFalse: {
113 auto constant = rewriter.
create<emitc::ConstantOp>(
119 case arith::CmpFPredicate::OEQ:
121 predicate = emitc::CmpPredicate::eq;
123 case arith::CmpFPredicate::OGT:
125 predicate = emitc::CmpPredicate::gt;
127 case arith::CmpFPredicate::OGE:
129 predicate = emitc::CmpPredicate::ge;
131 case arith::CmpFPredicate::OLT:
133 predicate = emitc::CmpPredicate::lt;
135 case arith::CmpFPredicate::OLE:
137 predicate = emitc::CmpPredicate::le;
139 case arith::CmpFPredicate::ONE:
141 predicate = emitc::CmpPredicate::ne;
143 case arith::CmpFPredicate::ORD: {
145 auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
150 case arith::CmpFPredicate::UEQ:
152 predicate = emitc::CmpPredicate::eq;
154 case arith::CmpFPredicate::UGT:
156 predicate = emitc::CmpPredicate::gt;
158 case arith::CmpFPredicate::UGE:
160 predicate = emitc::CmpPredicate::ge;
162 case arith::CmpFPredicate::ULT:
164 predicate = emitc::CmpPredicate::lt;
166 case arith::CmpFPredicate::ULE:
168 predicate = emitc::CmpPredicate::le;
170 case arith::CmpFPredicate::UNE:
172 predicate = emitc::CmpPredicate::ne;
174 case arith::CmpFPredicate::UNO: {
176 auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
181 case arith::CmpFPredicate::AlwaysTrue: {
182 auto constant = rewriter.
create<emitc::ConstantOp>(
192 rewriter.
create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
193 adaptor.getLhs(), adaptor.getRhs());
197 auto isUnordered = createCheckIsUnordered(
198 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
200 isUnordered, cmpResult);
204 auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
205 adaptor.getLhs(), adaptor.getRhs());
207 isOrdered, cmpResult);
214 Value operand)
const {
216 return rewriter.
create<emitc::CmpOp>(
217 loc, rewriter.
getI1Type(), emitc::CmpPredicate::ne, operand, operand);
222 Value operand)
const {
224 return rewriter.
create<emitc::CmpOp>(
225 loc, rewriter.
getI1Type(), emitc::CmpPredicate::eq, operand, operand);
232 auto firstIsNaN = isNaN(rewriter, loc, first);
233 auto secondIsNaN = isNaN(rewriter, loc, second);
235 firstIsNaN, secondIsNaN);
242 auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
243 auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
245 firstIsNotNaN, secondIsNotNaN);
253 bool needsUnsignedCmp(arith::CmpIPredicate pred)
const {
255 case arith::CmpIPredicate::eq:
256 case arith::CmpIPredicate::ne:
257 case arith::CmpIPredicate::slt:
258 case arith::CmpIPredicate::sle:
259 case arith::CmpIPredicate::sgt:
260 case arith::CmpIPredicate::sge:
262 case arith::CmpIPredicate::ult:
263 case arith::CmpIPredicate::ule:
264 case arith::CmpIPredicate::ugt:
265 case arith::CmpIPredicate::uge:
268 llvm_unreachable(
"unknown cmpi predicate kind");
273 case arith::CmpIPredicate::eq:
274 return emitc::CmpPredicate::eq;
275 case arith::CmpIPredicate::ne:
276 return emitc::CmpPredicate::ne;
277 case arith::CmpIPredicate::slt:
278 case arith::CmpIPredicate::ult:
279 return emitc::CmpPredicate::lt;
280 case arith::CmpIPredicate::sle:
281 case arith::CmpIPredicate::ule:
282 return emitc::CmpPredicate::le;
283 case arith::CmpIPredicate::sgt:
284 case arith::CmpIPredicate::ugt:
285 return emitc::CmpPredicate::gt;
286 case arith::CmpIPredicate::sge:
287 case arith::CmpIPredicate::uge:
288 return emitc::CmpPredicate::ge;
290 llvm_unreachable(
"unknown cmpi predicate kind");
294 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
297 Type type = adaptor.getLhs().getType();
300 op,
"expected integer or size_t/ssize_t/ptrdiff_t type");
303 bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
306 Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
307 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
308 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
320 matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
323 auto adaptedOp = adaptor.getOperand();
324 auto adaptedOpType = adaptedOp.getType();
326 if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
329 "negf currently only supports scalar types, not vectors or tensors");
334 op.getLoc(),
"floating-point type is not supported by EmitC");
343 template <
typename ArithOp,
bool castToUn
signed>
349 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
352 Type opReturnType = this->getTypeConverter()->convertType(op.getType());
353 if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
356 op,
"expected integer or size_t/ssize_t/ptrdiff_t result type");
358 if (adaptor.getOperands().size() != 1) {
360 op,
"CastConversion only supports unary ops");
363 Type operandType = adaptor.getIn().getType();
364 if (!operandType || !(isa<IntegerType>(operandType) ||
367 op,
"expected integer or size_t/ssize_t/ptrdiff_t operand type");
370 if (operandType.
isInteger(1) && !castToUnsigned)
372 "operation not supported on i1 type");
381 auto constOne = rewriter.
create<emitc::ConstantOp>(
382 op.getLoc(), operandType, rewriter.
getOneAttr(attrType));
383 auto oneAndOperand = rewriter.
create<emitc::BitwiseAndOp>(
384 op.getLoc(), operandType, adaptor.getIn(), constOne);
391 (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
394 bool doUnsigned = castToUnsigned || isTruncation;
398 Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
401 Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
402 Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
405 auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
406 castDestType, actualOp);
409 auto result = adaptValueType(cast, rewriter, opReturnType);
416 template <
typename ArithOp>
417 class UnsignedCastConversion :
public CastConversion<ArithOp, true> {
418 using CastConversion<ArithOp, true>::CastConversion;
421 template <
typename ArithOp>
422 class SignedCastConversion :
public CastConversion<ArithOp, false> {
423 using CastConversion<ArithOp, false>::CastConversion;
426 template <
typename ArithOp,
typename EmitCOp>
432 matchAndRewrite(ArithOp arithOp,
typename ArithOp::Adaptor adaptor,
435 Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
438 "converting result type failed");
439 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
440 adaptor.getOperands());
446 template <
class ArithOp,
class EmitCOp>
452 matchAndRewrite(ArithOp uiBinOp,
typename ArithOp::Adaptor adaptor,
454 Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
457 "converting result type failed");
458 if (!isa<IntegerType>(newRetTy)) {
462 adaptIntegralTypeSignedness(newRetTy,
true);
465 "converting result type failed");
466 Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
467 Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
470 rewriter.
create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
472 Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
473 rewriter.
replaceOp(uiBinOp, resultAdapted);
478 template <
typename ArithOp,
typename EmitCOp>
484 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
487 Type type = this->getTypeConverter()->convertType(op.getType());
490 op,
"expected integer or size_t/ssize_t/ptrdiff_t type");
498 Type arithmeticType = type;
500 !bitEnumContainsAll(op.getOverflowFlags(),
501 arith::IntegerOverflowFlags::nsw)) {
508 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
509 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
511 Value arithmeticResult = rewriter.template create<EmitCOp>(
512 op.getLoc(), arithmeticType, lhs, rhs);
514 Value result = adaptValueType(arithmeticResult, rewriter, type);
521 template <
typename ArithOp,
typename EmitCOp>
527 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
530 Type type = this->getTypeConverter()->convertType(op.getType());
531 if (!isa_and_nonnull<IntegerType>(type)) {
534 "expected integer type, vector/tensor support not yet implemented");
545 Type arithmeticType =
546 adaptIntegralTypeSignedness(type,
true);
548 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
549 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
551 Value arithmeticResult = rewriter.template create<EmitCOp>(
552 op.getLoc(), arithmeticType, lhs, rhs);
554 Value result = adaptValueType(arithmeticResult, rewriter, type);
561 template <
typename ArithOp,
typename EmitCOp,
bool isUn
signedOp>
567 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
570 Type type = this->getTypeConverter()->convertType(op.getType());
573 op,
"expected integer or size_t/ssize_t/ptrdiff_t type");
580 Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
582 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
584 Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
586 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
593 emitc::CallOpaqueOp sizeOfCall = rewriter.
create<emitc::CallOpaqueOp>(
595 width = rewriter.
create<emitc::MulOp>(op.getLoc(), rhsType, eight,
596 sizeOfCall.getResult(0));
598 width = rewriter.
create<emitc::ConstantOp>(
599 op.getLoc(), rhsType,
604 op.getLoc(), rewriter.
getI1Type(), emitc::CmpPredicate::lt, rhs, width);
608 op.getLoc(), arithmeticType,
609 (isa<IntegerType>(arithmeticType)
613 emitc::ExpressionOp ternary = rewriter.
create<emitc::ExpressionOp>(
614 op.getLoc(), arithmeticType,
false);
615 Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
618 Value arithmeticResult =
619 rewriter.
create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
620 Value resultOrPoison = rewriter.
create<emitc::ConditionalOp>(
621 op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
622 rewriter.
create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
625 Value result = adaptValueType(ternary, rewriter, type);
632 template <
typename ArithOp,
typename EmitCOp>
633 class SignedShiftOpConversion final
634 :
public ShiftOpConversion<ArithOp, EmitCOp, false> {
635 using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
638 template <
typename ArithOp,
typename EmitCOp>
639 class UnsignedShiftOpConversion final
640 :
public ShiftOpConversion<ArithOp, EmitCOp, true> {
641 using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
649 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
652 Type dstType = getTypeConverter()->convertType(selectOp.getType());
656 if (!adaptor.getCondition().getType().isInteger(1))
659 "can only be converted if condition is a scalar of type i1");
662 adaptor.getOperands());
669 template <
typename CastOp>
676 matchAndRewrite(CastOp castOp,
typename CastOp::Adaptor adaptor,
679 Type operandType = adaptor.getIn().getType();
682 "unsupported cast source type");
684 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
692 "unsupported cast destination type");
696 Type actualResultType = dstType;
697 if (isa<arith::FPToUIOp>(castOp)) {
704 castOp.getLoc(), actualResultType, adaptor.getOperands());
706 if (isa<arith::FPToUIOp>(castOp)) {
707 result = rewriter.
create<emitc::CastOp>(castOp.getLoc(), dstType, result);
716 template <
typename CastOp>
723 matchAndRewrite(CastOp castOp,
typename CastOp::Adaptor adaptor,
726 Type operandType = adaptor.getIn().getType();
729 "unsupported cast source type");
731 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
737 "unsupported cast destination type");
741 Type actualOperandType = operandType;
742 if (isa<arith::UIToFPOp>(castOp)) {
747 Value fpCastOperand = adaptor.getIn();
748 if (actualOperandType != operandType) {
749 fpCastOperand = rewriter.template create<emitc::CastOp>(
750 castOp.getLoc(), actualOperandType, fpCastOperand);
759 template <
typename CastOp>
766 matchAndRewrite(CastOp castOp,
typename CastOp::Adaptor adaptor,
769 Type operandType = adaptor.getIn().getType();
772 "unsupported cast source type");
773 if (
auto roundingModeOp =
774 dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
776 if (roundingModeOp.getRoundingModeAttr())
780 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
786 "unsupported cast destination type");
788 Value fpCastOperand = adaptor.getIn();
809 ArithConstantOpConversionPattern,
810 ArithOpConversion<arith::AddFOp, emitc::AddOp>,
811 ArithOpConversion<arith::DivFOp, emitc::DivOp>,
812 ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
813 ArithOpConversion<arith::MulFOp, emitc::MulOp>,
814 ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
815 ArithOpConversion<arith::SubFOp, emitc::SubOp>,
816 BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
817 BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
818 IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
819 IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
820 IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
821 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
822 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
823 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
824 UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
825 SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
826 UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
832 UnsignedCastConversion<arith::TruncIOp>,
833 SignedCastConversion<arith::ExtSIOp>,
834 UnsignedCastConversion<arith::ExtUIOp>,
835 SignedCastConversion<arith::IndexCastOp>,
836 UnsignedCastConversion<arith::IndexCastUIOp>,
837 ItoFCastOpConversion<arith::SIToFPOp>,
838 ItoFCastOpConversion<arith::UIToFPOp>,
839 FtoICastOpConversion<arith::FPToSIOp>,
840 FtoICastOpConversion<arith::FPToUIOp>,
841 FpCastOpConversion<arith::ExtFOp>,
842 FpCastOpConversion<arith::TruncFOp>
843 >(typeConverter, ctx);
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
TypedAttr getOneAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
void registerConvertArithToEmitCInterface(DialectRegistry &registry)
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)