18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/MathExtras.h"
23 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
24 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
37 static std::pair<APInt, APInt>
getHalves(
const APInt &value,
38 unsigned newBitWidth) {
39 APInt low = value.extractBits(newBitWidth, 0);
40 APInt high = value.extractBits(newBitWidth, newBitWidth);
41 return {std::move(low), std::move(high)};
50 if (type.getShape().size() == 1)
51 return type.getElementType();
53 auto newShape = to_vector(type.getShape());
55 return VectorType::get(newShape, type.getElementType());
63 if (
auto intTy = type.
dyn_cast<IntegerType>()) {
66 auto vecTy = type.
cast<VectorType>();
70 return rewriter.
create<arith::ConstantOp>(loc, attr);
77 unsigned elementBitWidth = 0;
78 if (
auto intTy = type.
dyn_cast<IntegerType>())
79 elementBitWidth = intTy.getWidth();
81 elementBitWidth = type.
cast<VectorType>().getElementTypeBitWidth();
84 APInt(elementBitWidth, value));
96 assert(lastOffset < shape.back() &&
"Offset out of bounds");
99 if (shape.size() == 1)
100 return rewriter.
create<vector::ExtractOp>(loc, input, lastOffset);
103 offsets.back() = lastOffset;
104 auto sizes = llvm::to_vector(shape);
108 return rewriter.
create<vector::ExtractStridedSliceOp>(loc, input, offsets,
114 static std::pair<Value, Value>
131 assert(shape.size() >= 2 &&
"Expected vector with at list two dims");
132 assert(shape.back() == 1 &&
"Expected the last vector dim to be x1");
134 auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
135 return rewriter.
create<vector::ShapeCastOp>(loc, newVecTy, input);
147 auto newShape = llvm::to_vector(vecTy.getShape());
148 newShape.push_back(1);
149 auto newTy = VectorType::get(newShape, vecTy.getElementType());
150 return rewriter.
create<vector::ShapeCastOp>(loc, newTy, input);
158 int64_t lastOffset) {
160 assert(lastOffset < shape.back() &&
"Offset out of bounds");
164 return rewriter.
create<vector::InsertOp>(loc, source, dest, lastOffset);
167 offsets.back() = lastOffset;
169 return rewriter.
create<vector::InsertStridedSliceOp>(loc, source, dest,
180 Location loc, VectorType resultType,
184 assert(!resultShape.empty() &&
"Result expected to have dimentions");
185 assert(resultShape.back() ==
static_cast<int64_t
>(resultComponents.size()) &&
186 "Wrong number of result components");
204 matchAndRewrite(arith::ConstantOp op, OpAdaptor,
206 Type oldType = op.getType();
207 auto newType = getTypeConverter()->convertType(oldType).
cast<VectorType>();
208 unsigned newBitWidth = newType.getElementTypeBitWidth();
211 if (
auto intAttr = oldValue.
dyn_cast<IntegerAttr>()) {
212 auto [low, high] =
getHalves(intAttr.getValue(), newBitWidth);
220 getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
221 int64_t numSplatElems = splatAttr.getNumElements();
223 values.reserve(numSplatElems * 2);
224 for (int64_t i = 0; i < numSplatElems; ++i) {
225 values.push_back(low);
226 values.push_back(high);
235 int64_t numElems = elemsAttr.getNumElements();
237 values.reserve(numElems * 2);
238 for (
const APInt &origVal : elemsAttr.getValues<APInt>()) {
239 auto [low, high] =
getHalves(origVal, newBitWidth);
240 values.push_back(std::move(low));
241 values.push_back(std::move(high));
250 "unhandled constant attribute");
262 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
265 auto newTy = getTypeConverter()
266 ->convertType(op.getType())
270 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
274 auto [lhsElem0, lhsElem1] =
276 auto [rhsElem0, rhsElem1] =
280 rewriter.
create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
282 rewriter.
create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
284 Value high0 = rewriter.
create<arith::AddIOp>(loc, overflowVal, lhsElem1);
285 Value high = rewriter.
create<arith::AddIOp>(loc, high0, rhsElem1);
299 template <
typename BinaryOp>
305 matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
308 auto newTy = this->getTypeConverter()
309 ->convertType(op.getType())
310 .template dyn_cast_or_null<VectorType>();
313 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
315 auto [lhsElem0, lhsElem1] =
317 auto [rhsElem0, rhsElem1] =
320 Value resElem0 = rewriter.
create<BinaryOp>(loc, lhsElem0, rhsElem0);
321 Value resElem1 = rewriter.
create<BinaryOp>(loc, lhsElem1, rhsElem1);
335 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
336 using P = arith::CmpIPredicate;
355 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
358 auto inputTy = getTypeConverter()
359 ->convertType(op.getLhs().getType())
363 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
365 arith::CmpIPredicate highPred = adaptor.getPredicate();
366 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
368 auto [lhsElem0, lhsElem1] =
370 auto [rhsElem0, rhsElem1] =
374 rewriter.
create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
376 rewriter.
create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
380 case arith::CmpIPredicate::eq: {
381 cmpResult = rewriter.
create<arith::AndIOp>(loc, lowCmp, highCmp);
384 case arith::CmpIPredicate::ne: {
385 cmpResult = rewriter.
create<arith::OrIOp>(loc, lowCmp, highCmp);
391 loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
393 rewriter.
create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
398 assert(cmpResult &&
"Unhandled case");
412 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
415 auto newTy = getTypeConverter()
416 ->convertType(op.getType())
420 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
422 auto [lhsElem0, lhsElem1] =
424 auto [rhsElem0, rhsElem1] =
431 rewriter.
create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
432 Value mulLowHi = rewriter.
create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
433 Value mulHiLow = rewriter.
create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
435 Value resLow = mulLowLow.getLow();
437 rewriter.
create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
438 resHi = rewriter.
create<arith::AddIOp>(loc, resHi, mulHiLow);
455 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
458 auto newTy = getTypeConverter()
459 ->convertType(op.getType())
463 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
472 loc, newResultComponentTy, newOperand);
473 Value operandZeroCst =
476 loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
478 rewriter.
create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
495 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
498 auto newTy = getTypeConverter()
499 ->convertType(op.getType())
503 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
511 loc, newResultComponentTy, newOperand);
523 template <
typename SourceOp, arith::CmpIPredicate CmpPred>
528 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
532 Type oldTy = op.getType();
533 auto newTy = this->getTypeConverter()
535 .template dyn_cast_or_null<VectorType>();
538 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
543 rewriter.
create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
554 static bool isIndexOrIndexVector(
Type type) {
555 if (type.
isa<IndexType>())
558 if (
auto vectorTy = type.
dyn_cast<VectorType>())
559 if (vectorTy.getElementType().isa<IndexType>())
565 template <
typename CastOp>
570 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
572 Type resultType = op.getType();
573 if (!isIndexOrIndexVector(resultType))
577 Type inType = op.getIn().getType();
578 auto newInTy = this->getTypeConverter()
579 ->convertType(inType)
580 .template dyn_cast_or_null<VectorType>();
583 loc, llvm::formatv(
"unsupported type: {0}", inType));
593 template <
typename CastOp,
typename ExtensionOp>
598 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
600 Type inType = op.getIn().getType();
601 if (!isIndexOrIndexVector(inType))
605 auto *typeConverter =
606 this->
template getTypeConverter<arith::WideIntEmulationConverter>();
608 Type resultType = op.getType();
609 auto newTy = typeConverter->convertType(resultType)
610 .template dyn_cast_or_null<VectorType>();
613 loc, llvm::formatv(
"unsupported type: {0}", resultType));
617 rewriter.
getIntegerType(typeConverter->getMaxTargetIntBitWidth());
618 if (
auto vecTy = resultType.
dyn_cast<VectorType>())
619 narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
623 Value underlyingVal =
624 rewriter.
create<CastOp>(loc, narrowTy, adaptor.getIn());
638 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
641 auto newTy = getTypeConverter()
642 ->convertType(op.getType())
646 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
648 auto [trueElem0, trueElem1] =
650 auto [falseElem0, falseElem1] =
655 rewriter.
create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
657 rewriter.
create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
673 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
677 Type oldTy = op.getType();
682 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
686 unsigned newBitWidth = newTy.getElementTypeBitWidth();
688 auto [lhsElem0, lhsElem1] =
720 Value illegalElemShift = rewriter.
create<arith::CmpIOp>(
721 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
724 rewriter.
create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
725 Value resElem0 = rewriter.
create<arith::SelectOp>(loc, illegalElemShift,
726 zeroCst, shiftedElem0);
728 Value cappedShiftAmount = rewriter.
create<arith::SelectOp>(
729 loc, illegalElemShift, elemBitWidth, rhsElem0);
730 Value rightShiftAmount =
731 rewriter.
create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
733 rewriter.
create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
734 Value overshotShiftAmount =
735 rewriter.
create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
737 rewriter.
create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
740 rewriter.
create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
741 Value resElem1High = rewriter.
create<arith::SelectOp>(
742 loc, illegalElemShift, zeroCst, shiftedElem1);
743 Value resElem1Low = rewriter.
create<arith::SelectOp>(
744 loc, illegalElemShift, shiftedLeft, shiftedRight);
746 rewriter.
create<arith::OrIOp>(loc, resElem1Low, resElem1High);
763 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
767 Type oldTy = op.getType();
772 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
776 unsigned newBitWidth = newTy.getElementTypeBitWidth();
778 auto [lhsElem0, lhsElem1] =
810 Value illegalElemShift = rewriter.
create<arith::CmpIOp>(
811 loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
814 rewriter.
create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
815 Value resElem0Low = rewriter.
create<arith::SelectOp>(loc, illegalElemShift,
816 zeroCst, shiftedElem0);
818 rewriter.
create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
819 Value resElem1 = rewriter.
create<arith::SelectOp>(loc, illegalElemShift,
820 zeroCst, shiftedElem1);
822 Value cappedShiftAmount = rewriter.
create<arith::SelectOp>(
823 loc, illegalElemShift, elemBitWidth, rhsElem0);
824 Value leftShiftAmount =
825 rewriter.
create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
827 rewriter.
create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
828 Value overshotShiftAmount =
829 rewriter.
create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
831 rewriter.
create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
833 Value resElem0High = rewriter.
create<arith::SelectOp>(
834 loc, illegalElemShift, shiftedRight, shiftedLeft);
836 rewriter.
create<arith::OrIOp>(loc, resElem0Low, resElem0High);
853 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
857 Type oldTy = op.getType();
862 loc, llvm::formatv(
"unsupported type: {0}", op.getType()));
868 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
875 loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
881 Value allSign = rewriter.
create<arith::ExtSIOp>(loc, oldTy, signBit);
884 Value numNonSignExtBits =
885 rewriter.
create<arith::SubIOp>(loc, maxShift, rhsElem0);
888 rewriter.
create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
890 rewriter.
create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
894 rewriter.
create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
895 Value shrsi = rewriter.
create<arith::OrIOp>(loc, shrui, signBits);
899 Value isNoop = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
917 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
922 if (!getTypeConverter()->isLegal(op.getType()))
924 loc, llvm::formatv(
"unsupported truncation result type: {0}",
932 rewriter.
createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
946 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
957 struct EmulateWideIntPass final
958 : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
959 using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
961 void runOnOperation()
override {
962 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
970 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
972 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](
Operation *op) {
973 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
975 auto opLegalCallback = [&typeConverter](
Operation *op) {
976 return typeConverter.isLegal(op);
978 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
980 .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
997 unsigned widestIntSupportedByTarget)
998 : maxIntWidth(widestIntSupportedByTarget) {
999 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1000 "Only power-of-two integers with are supported");
1001 assert(widestIntSupportedByTarget >= 2 &&
"Integer type too narrow");
1007 addConversion([
this](IntegerType ty) -> std::optional<Type> {
1008 unsigned width = ty.getWidth();
1009 if (width <= maxIntWidth)
1013 if (width == 2 * maxIntWidth)
1014 return VectorType::get(2, IntegerType::get(ty.
getContext(), maxIntWidth));
1016 return std::nullopt;
1020 addConversion([
this](VectorType ty) -> std::optional<Type> {
1021 auto intTy = ty.getElementType().dyn_cast<IntegerType>();
1025 unsigned width = intTy.getWidth();
1026 if (width <= maxIntWidth)
1030 if (width == 2 * maxIntWidth) {
1031 auto newShape = to_vector(ty.getShape());
1032 newShape.push_back(2);
1033 return VectorType::get(newShape,
1034 IntegerType::get(ty.
getContext(), maxIntWidth));
1037 return std::nullopt;
1041 addConversion([
this](FunctionType ty) -> std::optional<Type> {
1046 return std::nullopt;
1050 return std::nullopt;
1052 return FunctionType::get(ty.
getContext(), inputs, results);
1059 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1067 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1069 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1070 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1071 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1072 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1073 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
1075 ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1076 ConvertBitwiseBinary<arith::XOrIOp>,
1078 ConvertExtSI, ConvertExtUI, ConvertTruncI,
1080 ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1081 ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1082 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1083 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>>(
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
U dyn_cast_or_null() const
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
This class describes a specific conversion target.
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results)
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
U dyn_cast_or_null() const
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.