30 class ArithConstantOpConversionPattern
36 matchAndRewrite(arith::ConstantOp arithConst,
37 arith::ConstantOp::Adaptor adaptor,
39 Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
49 Type adaptIntegralTypeSignedness(
Type ty,
bool needsUnsigned) {
50 if (isa<IntegerType>(ty)) {
52 auto signedness = needsUnsigned
53 ? IntegerType::SignednessSemantics::Unsigned
59 if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
78 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
81 if (!isa<FloatType>(adaptor.getRhs().getType())) {
83 "cmpf currently only supported on "
84 "floats, not tensors/vectors thereof");
87 bool unordered =
false;
89 switch (op.getPredicate()) {
90 case arith::CmpFPredicate::AlwaysFalse: {
91 auto constant = rewriter.
create<emitc::ConstantOp>(
97 case arith::CmpFPredicate::OEQ:
99 predicate = emitc::CmpPredicate::eq;
101 case arith::CmpFPredicate::OGT:
103 predicate = emitc::CmpPredicate::gt;
105 case arith::CmpFPredicate::OGE:
107 predicate = emitc::CmpPredicate::ge;
109 case arith::CmpFPredicate::OLT:
111 predicate = emitc::CmpPredicate::lt;
113 case arith::CmpFPredicate::OLE:
115 predicate = emitc::CmpPredicate::le;
117 case arith::CmpFPredicate::ONE:
119 predicate = emitc::CmpPredicate::ne;
121 case arith::CmpFPredicate::ORD: {
123 auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
128 case arith::CmpFPredicate::UEQ:
130 predicate = emitc::CmpPredicate::eq;
132 case arith::CmpFPredicate::UGT:
134 predicate = emitc::CmpPredicate::gt;
136 case arith::CmpFPredicate::UGE:
138 predicate = emitc::CmpPredicate::ge;
140 case arith::CmpFPredicate::ULT:
142 predicate = emitc::CmpPredicate::lt;
144 case arith::CmpFPredicate::ULE:
146 predicate = emitc::CmpPredicate::le;
148 case arith::CmpFPredicate::UNE:
150 predicate = emitc::CmpPredicate::ne;
152 case arith::CmpFPredicate::UNO: {
154 auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
159 case arith::CmpFPredicate::AlwaysTrue: {
160 auto constant = rewriter.
create<emitc::ConstantOp>(
170 rewriter.
create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
171 adaptor.getLhs(), adaptor.getRhs());
175 auto isUnordered = createCheckIsUnordered(
176 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
178 isUnordered, cmpResult);
182 auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
183 adaptor.getLhs(), adaptor.getRhs());
185 isOrdered, cmpResult);
192 Value operand)
const {
194 return rewriter.
create<emitc::CmpOp>(
195 loc, rewriter.
getI1Type(), emitc::CmpPredicate::ne, operand, operand);
200 Value operand)
const {
202 return rewriter.
create<emitc::CmpOp>(
203 loc, rewriter.
getI1Type(), emitc::CmpPredicate::eq, operand, operand);
210 auto firstIsNaN = isNaN(rewriter, loc, first);
211 auto secondIsNaN = isNaN(rewriter, loc, second);
213 firstIsNaN, secondIsNaN);
220 auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
221 auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
223 firstIsNotNaN, secondIsNotNaN);
231 bool needsUnsignedCmp(arith::CmpIPredicate pred)
const {
233 case arith::CmpIPredicate::eq:
234 case arith::CmpIPredicate::ne:
235 case arith::CmpIPredicate::slt:
236 case arith::CmpIPredicate::sle:
237 case arith::CmpIPredicate::sgt:
238 case arith::CmpIPredicate::sge:
240 case arith::CmpIPredicate::ult:
241 case arith::CmpIPredicate::ule:
242 case arith::CmpIPredicate::ugt:
243 case arith::CmpIPredicate::uge:
246 llvm_unreachable(
"unknown cmpi predicate kind");
251 case arith::CmpIPredicate::eq:
252 return emitc::CmpPredicate::eq;
253 case arith::CmpIPredicate::ne:
254 return emitc::CmpPredicate::ne;
255 case arith::CmpIPredicate::slt:
256 case arith::CmpIPredicate::ult:
257 return emitc::CmpPredicate::lt;
258 case arith::CmpIPredicate::sle:
259 case arith::CmpIPredicate::ule:
260 return emitc::CmpPredicate::le;
261 case arith::CmpIPredicate::sgt:
262 case arith::CmpIPredicate::ugt:
263 return emitc::CmpPredicate::gt;
264 case arith::CmpIPredicate::sge:
265 case arith::CmpIPredicate::uge:
266 return emitc::CmpPredicate::ge;
268 llvm_unreachable(
"unknown cmpi predicate kind");
272 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
275 Type type = adaptor.getLhs().getType();
278 op,
"expected integer or size_t/ssize_t/ptrdiff_t type");
281 bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
284 Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
285 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
286 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
298 matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
301 auto adaptedOp = adaptor.getOperand();
302 auto adaptedOpType = adaptedOp.getType();
304 if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
307 "negf currently only supports scalar types, not vectors or tensors");
312 op.getLoc(),
"floating-point type is not supported by EmitC");
321 template <
typename ArithOp,
bool castToUn
signed>
327 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
330 Type opReturnType = this->getTypeConverter()->convertType(op.getType());
331 if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
334 op,
"expected integer or size_t/ssize_t/ptrdiff_t result type");
336 if (adaptor.getOperands().size() != 1) {
338 op,
"CastConversion only supports unary ops");
341 Type operandType = adaptor.getIn().getType();
342 if (!operandType || !(isa<IntegerType>(operandType) ||
345 op,
"expected integer or size_t/ssize_t/ptrdiff_t operand type");
348 if (operandType.
isInteger(1) && !castToUnsigned)
350 "operation not supported on i1 type");
359 auto constOne = rewriter.
create<emitc::ConstantOp>(
360 op.getLoc(), operandType, rewriter.
getOneAttr(attrType));
361 auto oneAndOperand = rewriter.
create<emitc::BitwiseAndOp>(
362 op.getLoc(), operandType, adaptor.getIn(), constOne);
369 (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
372 bool doUnsigned = castToUnsigned || isTruncation;
376 Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
379 Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
380 Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
383 auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
384 castDestType, actualOp);
387 auto result = adaptValueType(cast, rewriter, opReturnType);
394 template <
typename ArithOp>
395 class UnsignedCastConversion :
public CastConversion<ArithOp, true> {
396 using CastConversion<ArithOp, true>::CastConversion;
399 template <
typename ArithOp>
400 class SignedCastConversion :
public CastConversion<ArithOp, false> {
401 using CastConversion<ArithOp, false>::CastConversion;
404 template <
typename ArithOp,
typename EmitCOp>
410 matchAndRewrite(ArithOp arithOp,
typename ArithOp::Adaptor adaptor,
413 Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
416 "converting result type failed");
417 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
418 adaptor.getOperands());
424 template <
class ArithOp,
class EmitCOp>
430 matchAndRewrite(ArithOp uiBinOp,
typename ArithOp::Adaptor adaptor,
432 Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
435 "converting result type failed");
436 if (!isa<IntegerType>(newRetTy)) {
440 adaptIntegralTypeSignedness(newRetTy,
true);
443 "converting result type failed");
444 Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
445 Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
448 rewriter.
create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
450 Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
451 rewriter.
replaceOp(uiBinOp, resultAdapted);
456 template <
typename ArithOp,
typename EmitCOp>
462 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
465 Type type = this->getTypeConverter()->convertType(op.getType());
468 op,
"expected integer or size_t/ssize_t/ptrdiff_t type");
476 Type arithmeticType = type;
478 !bitEnumContainsAll(op.getOverflowFlags(),
479 arith::IntegerOverflowFlags::nsw)) {
486 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
487 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
489 Value arithmeticResult = rewriter.template create<EmitCOp>(
490 op.getLoc(), arithmeticType, lhs, rhs);
492 Value result = adaptValueType(arithmeticResult, rewriter, type);
499 template <
typename ArithOp,
typename EmitCOp>
505 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
508 Type type = this->getTypeConverter()->convertType(op.getType());
509 if (!isa_and_nonnull<IntegerType>(type)) {
512 "expected integer type, vector/tensor support not yet implemented");
523 Type arithmeticType =
524 adaptIntegralTypeSignedness(type,
true);
526 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
527 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
529 Value arithmeticResult = rewriter.template create<EmitCOp>(
530 op.getLoc(), arithmeticType, lhs, rhs);
532 Value result = adaptValueType(arithmeticResult, rewriter, type);
539 template <
typename ArithOp,
typename EmitCOp,
bool isUn
signedOp>
545 matchAndRewrite(ArithOp op,
typename ArithOp::Adaptor adaptor,
548 Type type = this->getTypeConverter()->convertType(op.getType());
551 op,
"expected integer or size_t/ssize_t/ptrdiff_t type");
558 Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
560 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
562 Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
564 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
571 emitc::CallOpaqueOp sizeOfCall = rewriter.
create<emitc::CallOpaqueOp>(
573 width = rewriter.
create<emitc::MulOp>(op.getLoc(), rhsType, eight,
574 sizeOfCall.getResult(0));
576 width = rewriter.
create<emitc::ConstantOp>(
577 op.getLoc(), rhsType,
582 op.getLoc(), rewriter.
getI1Type(), emitc::CmpPredicate::lt, rhs, width);
586 op.getLoc(), arithmeticType,
587 (isa<IntegerType>(arithmeticType)
591 emitc::ExpressionOp ternary = rewriter.
create<emitc::ExpressionOp>(
592 op.getLoc(), arithmeticType,
false);
593 Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
596 Value arithmeticResult =
597 rewriter.
create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
598 Value resultOrPoison = rewriter.
create<emitc::ConditionalOp>(
599 op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
600 rewriter.
create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
603 Value result = adaptValueType(ternary, rewriter, type);
610 template <
typename ArithOp,
typename EmitCOp>
611 class SignedShiftOpConversion final
612 :
public ShiftOpConversion<ArithOp, EmitCOp, false> {
613 using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
616 template <
typename ArithOp,
typename EmitCOp>
617 class UnsignedShiftOpConversion final
618 :
public ShiftOpConversion<ArithOp, EmitCOp, true> {
619 using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
627 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
630 Type dstType = getTypeConverter()->convertType(selectOp.getType());
634 if (!adaptor.getCondition().getType().isInteger(1))
637 "can only be converted if condition is a scalar of type i1");
640 adaptor.getOperands());
647 template <
typename CastOp>
654 matchAndRewrite(CastOp castOp,
typename CastOp::Adaptor adaptor,
657 Type operandType = adaptor.getIn().getType();
660 "unsupported cast source type");
662 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
670 "unsupported cast destination type");
674 Type actualResultType = dstType;
675 if (isa<arith::FPToUIOp>(castOp)) {
682 castOp.getLoc(), actualResultType, adaptor.getOperands());
684 if (isa<arith::FPToUIOp>(castOp)) {
685 result = rewriter.
create<emitc::CastOp>(castOp.getLoc(), dstType, result);
694 template <
typename CastOp>
701 matchAndRewrite(CastOp castOp,
typename CastOp::Adaptor adaptor,
704 Type operandType = adaptor.getIn().getType();
707 "unsupported cast source type");
709 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
715 "unsupported cast destination type");
719 Type actualOperandType = operandType;
720 if (isa<arith::UIToFPOp>(castOp)) {
725 Value fpCastOperand = adaptor.getIn();
726 if (actualOperandType != operandType) {
727 fpCastOperand = rewriter.template create<emitc::CastOp>(
728 castOp.getLoc(), actualOperandType, fpCastOperand);
750 ArithConstantOpConversionPattern,
751 ArithOpConversion<arith::AddFOp, emitc::AddOp>,
752 ArithOpConversion<arith::DivFOp, emitc::DivOp>,
753 ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
754 ArithOpConversion<arith::MulFOp, emitc::MulOp>,
755 ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
756 ArithOpConversion<arith::SubFOp, emitc::SubOp>,
757 BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
758 BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
759 IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
760 IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
761 IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
762 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
763 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
764 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
765 UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
766 SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
767 UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
773 UnsignedCastConversion<arith::TruncIOp>,
774 SignedCastConversion<arith::ExtSIOp>,
775 UnsignedCastConversion<arith::ExtUIOp>,
776 SignedCastConversion<arith::IndexCastOp>,
777 UnsignedCastConversion<arith::IndexCastUIOp>,
778 ItoFCastOpConversion<arith::SIToFPOp>,
779 ItoFCastOpConversion<arith::UIToFPOp>,
780 FtoICastOpConversion<arith::FPToSIOp>,
781 FtoICastOpConversion<arith::FPToUIOp>
782 >(typeConverter, ctx);
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
TypedAttr getOneAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
bool isPointerWideType(mlir::Type type)
Determines whether type is a emitc.size_t/ssize_t type.
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)