25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
31 #define GEN_PASS_DEF_ARITHINTNARROWING
32 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
/// Rewrite-pattern base for `SourceOp` that remembers the integer bitwidths
/// listed in `options.bitwidthsSupported`.
42 template <
typename SourceOp>
43 struct NarrowingPattern : OpRewritePattern<SourceOp> {
/// Forwards `ctx` and `benefit` to the `OpRewritePattern` base and copies the
/// supported bitwidths out of `options`.
44 NarrowingPattern(MLIRContext *ctx,
const ArithIntNarrowingOptions &
options,
45 PatternBenefit benefit = 1)
46 : OpRewritePattern<SourceOp>(ctx, benefit),
47 supportedBitwidths(
options.bitwidthsSupported.begin(),
48 options.bitwidthsSupported.end()) {
// The list of widths must be non-empty and must not contain a zero width.
49 assert(!supportedBitwidths.empty() &&
"Invalid options");
50 assert(!llvm::is_contained(supportedBitwidths, 0) &&
"Invalid bitwidth");
// Keep the widths sorted (default comparator, i.e. ascending) so that they
// can be scanned from the narrowest to the widest.
51 llvm::sort(supportedBitwidths);
// Walks `supportedBitwidths` in stored order and tests each candidate with
// `candidate >= bitsRequired`.
// NOTE(review): the return type and the return statements are not visible in
// this excerpt; presumably the first candidate passing the test is returned
// and failure is reported when none does -- confirm against the full file.
55 getNarrowestCompatibleBitwidth(
unsigned bitsRequired)
const {
56 for (
unsigned candidate : supportedBitwidths)
57 if (candidate >= bitsRequired)
/// Computes a narrowed counterpart of `origTy` for values that need
/// `bitsRequired` bits. The width is chosen through
/// `getNarrowestCompatibleBitwidth`.
/// NOTE(review): the definitions of `elemTy` / `newElemTy` and the bodies of
/// the early-exit checks are on lines not visible in this excerpt.
64 FailureOr<Type> getNarrowType(
unsigned bitsRequired,
Type origTy)
const {
66 FailureOr<unsigned> bestBitwidth =
67 getNarrowestCompatibleBitwidth(bitsRequired);
// Check: was a supported width found at all?
68 if (failed(bestBitwidth))
// Check: only integer element types are handled.
72 if (!isa<IntegerType>(elemTy))
// Check: the candidate element type is the same as the current one.
76 if (newElemTy == elemTy)
// For shaped types with an integer element type, keep the shape and swap in
// the new element type.
82 if (
auto shapedTy = dyn_cast<ShapedType>(origTy))
83 if (dyn_cast<IntegerType>(shapedTy.getElementType()))
84 return shapedTy.clone(shapedTy.getShape(), newElemTy);
/// Type-based overload: reports the number of bits associated with `type`.
/// The visible path returns the width of an integer type `intTy`.
/// NOTE(review): how `intTy` is obtained from `type`, and the result for
/// other types, are on lines not visible in this excerpt.
95 FailureOr<unsigned> calculateBitsRequired(
Type type) {
98 return intTy.getWidth();
/// Kind of integer extension: `Sign` is paired with `arith::ExtSIOp` and
/// `Zero` with `arith::ExtUIOp` by `ExtensionOp::from`.
103 enum class ExtensionKind { Sign, Zero };
/// Wraps `op` as an `ExtensionOp`: an `arith::ExtSIOp` yields kind `Sign`, an
/// `arith::ExtUIOp` yields kind `Zero`. `dyn_cast_or_null` is used, so a null
/// `op` is tolerated by the checks.
/// NOTE(review): the fall-through return is not visible in this excerpt;
/// presumably it reports failure -- confirm.
114 static FailureOr<ExtensionOp> from(Operation *op) {
115 if (dyn_cast_or_null<arith::ExtSIOp>(op))
116 return ExtensionOp{op, ExtensionKind::Sign};
117 if (dyn_cast_or_null<arith::ExtUIOp>(op))
118 return ExtensionOp{op, ExtensionKind::Zero};
// Copy construction and copy assignment use the compiler-generated
// member-wise versions.
123 ExtensionOp(
const ExtensionOp &) =
default;
124 ExtensionOp &operator=(
const ExtensionOp &) =
default;
/// Builds a new extension op at `loc` producing `newType` from `in`: an
/// `arith::ExtSIOp` when this wrapper's kind is `Sign`; the other visible
/// return builds an `arith::ExtUIOp`.
/// NOTE(review): the declaration of the `in` parameter and the line between
/// the two returns are not visible in this excerpt.
127 Operation *recreate(PatternRewriter &rewriter, Location loc,
Type newType,
129 if (kind == ExtensionKind::Sign)
130 return rewriter.create<arith::ExtSIOp>(loc, newType, in);
132 return rewriter.create<arith::ExtUIOp>(loc, newType, in);
/// Replaces the single-result op `toReplace` with a newly created extension
/// (see `recreate`) of `in` to the result type of `toReplace`, reusing the
/// location of `toReplace`.
/// NOTE(review): the declaration of the `in` parameter is on a line not
/// visible in this excerpt.
136 void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
// Only ops with exactly one result are supported.
138 assert(toReplace->getNumResults() == 1);
139 Type newType = toReplace->getResult(0).getType();
140 Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
141 rewriter.replaceOp(toReplace, newOp->getResult(0));
// Extension kind (sign or zero) recorded for the wrapped op.
144 ExtensionKind getKind() {
return kind; }
// Result #0 of the wrapped op, i.e. the extended value.
146 Value getResult() {
return op->getResult(0); }
// Operand #0 of the wrapped op, i.e. the value being extended.
147 Value getIn() {
return op->getOperand(0); }
// Type of the value being extended.
151 Type getInType() {
return getIn().getType(); }
// Stores the wrapped operation together with its extension kind. `op` must
// be an `arith::ExtSIOp` or an `arith::ExtUIOp`.
155 ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
157 assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) &&
"Not an extension op");
// Wrapped extension operation; null by default.
159 Operation *op =
nullptr;
// Kind of the wrapped extension; value-initialized by default.
160 ExtensionKind kind = {};
/// Constant overload: computes a bit count for `value` under the extension
/// kind `lookThroughExtension`.
164 unsigned calculateBitsRequired(
const APInt &value,
165 ExtensionKind lookThroughExtension) {
// Zero extension: the active bits suffice, but never report fewer than 1.
168 if (lookThroughExtension == ExtensionKind::Zero)
169 return std::max(value.getActiveBits(), 1u);
// Remaining cases are for the non-zero (sign) extension kind.
// Non-negative value: active bits plus one extra bit.
172 if (value.isNonNegative())
173 return value.getActiveBits() + 1;
// Minimum signed value: the full bit width is reported.
176 if (value.isMinSignedValue())
177 return value.getBitWidth();
// Other negative values: drop the run of sign bits reported by
// `getNumSignBits()` and keep one bit for the sign.
181 return value.getBitWidth() - value.getNumSignBits() + 1;
/// Value overload: computes a bit count for `value` under the extension kind
/// `lookThroughExtension`. Constant attributes are measured by their values,
/// a matching extension op is looked through to its input type, and otherwise
/// the type of `value` itself is used.
/// NOTE(review): `attr` is defined on a line not visible in this excerpt.
187 FailureOr<unsigned> calculateBitsRequired(
Value value,
188 ExtensionKind lookThroughExtension) {
// Scalar integer constant: delegate to the APInt overload.
191 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
192 return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);
// Dense elements constant with an integer or index element type.
194 if (
auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
195 if (elemsAttr.getElementType().isIntOrIndex()) {
// A splat is measured through its single value.
196 if (elemsAttr.isSplat())
197 return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
198 lookThroughExtension);
// Non-splat: start from 1 and fold in the per-element bit counts.
// NOTE(review): the call combining `maxBits` with each element's result is
// on a line not visible here (presumably a max) -- confirm.
200 unsigned maxBits = 1;
201 for (
const APInt &elemValue : elemsAttr.getValues<APInt>())
203 maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
// Look through an extension of the requested kind and use the bit count of
// its input type.
209 if (lookThroughExtension == ExtensionKind::Sign) {
210 if (
auto sext = value.getDefiningOp<arith::ExtSIOp>())
211 return calculateBitsRequired(sext.getIn().getType());
212 }
else if (lookThroughExtension == ExtensionKind::Zero) {
213 if (
auto zext = value.getDefiningOp<arith::ExtUIOp>())
214 return calculateBitsRequired(
zext.getIn().getType());
// Fallback: use the type of `value` itself.
218 return calculateBitsRequired(value.getType());
/// Shared narrowing logic for binary ops whose left-hand side is produced by
/// an extension op. Subclasses provide the result-width rule and may restrict
/// the accepted extension kinds.
233 template <
typename BinaryOp>
234 struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
235 using NarrowingPattern<BinaryOp>::NarrowingPattern;
/// Maps the bit count of the operands to the bit count fed into
/// `getNarrowType` for the result; must be supplied by subclasses.
240 virtual unsigned getResultBitsProduced(
unsigned operandBits)
const = 0;
/// Customization point to reject extension ops; accepts everything by
/// default.
244 virtual bool isSupported(ExtensionOp)
const {
return true; }
/// NOTE(review): the bodies of the early-exit checks and the declarations of
/// `newLhs` / `newRhs` are on lines not visible in this excerpt.
246 LogicalResult matchAndRewrite(BinaryOp op,
247 PatternRewriter &rewriter)
const final {
// Bit count of the original result type.
248 Type origTy = op.getType();
249 FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
250 if (failed(resultBits))
// The LHS must come from an extension op accepted by `isSupported`.
255 FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
256 if (failed(ext) || !isSupported(*ext))
// Both operands are measured with the extension kind of the LHS; each is
// checked against the original result width.
259 FailureOr<unsigned> lhsBitsRequired =
260 calculateBitsRequired(ext->getIn(), ext->getKind());
261 if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
264 FailureOr<unsigned> rhsBitsRequired =
265 calculateBitsRequired(op.getRhs(), ext->getKind());
266 if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
// Derive the bits needed for the result from the wider operand and pick a
// narrow type for it; that type is also checked against the original width.
271 unsigned commonBitsRequired =
272 getResultBitsProduced(
std::max(*lhsBitsRequired, *rhsBitsRequired));
273 FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
274 if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
// Truncate the original operands to the narrow type, build the op on the
// narrow values, then replace `op` with an extension (same kind as the LHS
// extension) of the narrow result.
277 Location loc = op.getLoc();
279 rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
281 rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
282 Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
283 ext->recreateAndReplace(rewriter, op, newAdd);
/// Narrowing pattern for `arith::AddIOp`; accepts every extension kind (the
/// base-class default).
292 struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
293 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
// The result is given one bit more than the operands.
297 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
298 return operandBits + 1;
/// Narrowing pattern for `arith::SubIOp`.
306 struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
307 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
// Only sign extensions are accepted.
310 bool isSupported(ExtensionOp ext)
const override {
311 return ext.getKind() == ExtensionKind::Sign;
// The result is given one bit more than the operands.
316 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
317 return operandBits + 1;
/// Narrowing pattern for `arith::MulIOp`; accepts every extension kind (the
/// base-class default).
325 struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
326 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
// The result is given twice the bits of the operands.
330 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
331 return 2 * operandBits;
/// Narrowing pattern for `arith::DivSIOp`.
339 struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
340 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
// Only sign extensions are accepted.
343 bool isSupported(ExtensionOp ext)
const override {
344 return ext.getKind() == ExtensionKind::Sign;
// The result is given one bit more than the operands.
349 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
350 return operandBits + 1;
/// Narrowing pattern for `arith::DivUIOp`.
358 struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
359 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
// Only zero extensions are accepted.
362 bool isSupported(ExtensionOp ext)
const override {
363 return ext.getKind() == ExtensionKind::Zero;
// NOTE(review): the body of this override is not visible in this excerpt.
367 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
/// Narrowing pattern shared by the integer min/max ops; `Kind` fixes the only
/// extension kind the pattern accepts.
376 template <
typename MinMaxOp, ExtensionKind Kind>
377 struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
378 using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;
// The extension must be of the kind given by the template parameter.
380 bool isSupported(ExtensionOp ext)
const override {
381 return ext.getKind() ==
Kind;
// NOTE(review): the body of this override is not visible in this excerpt.
386 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
// Signed min/max are paired with sign extension, unsigned min/max with zero
// extension.
390 using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
391 using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
392 using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
393 using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;
/// Narrows the integer input of an int-to-float conversion op `IToFPOp`.
/// `Extension` selects how the input's bit requirement is computed.
399 template <
typename IToFPOp, ExtensionKind Extension>
400 struct IToFPPattern final : NarrowingPattern<IToFPOp> {
401 using NarrowingPattern<IToFPOp>::NarrowingPattern;
/// NOTE(review): the bodies of the early-exit checks and the last argument of
/// the truncation call are on lines not visible in this excerpt.
403 LogicalResult matchAndRewrite(IToFPOp op,
404 PatternRewriter &rewriter)
const override {
// Bits needed by the input under the given extension kind.
405 FailureOr<unsigned> narrowestWidth =
406 calculateBitsRequired(op.getIn(), Extension);
407 if (failed(narrowestWidth))
// Narrow counterpart of the input type.
410 FailureOr<Type> narrowTy =
411 this->getNarrowType(*narrowestWidth, op.getIn().getType());
412 if (failed(narrowTy))
// Truncate to the narrow type and redo the conversion with the original
// result type.
415 Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy,
417 rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
// `arith::SIToFPOp` is paired with sign extension, `arith::UIToFPOp` with
// zero extension.
421 using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
422 using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
/// Narrows the integer result of an index cast `CastOp` based on bounds `lb`
/// and `ub` of the input, then extends back to the original result type
/// according to `Kind`.
/// NOTE(review): the computation of `lb` / `ub` and the bodies of the
/// early-exit checks are on lines not visible in this excerpt.
431 template <
typename CastOp, ExtensionKind Kind>
432 struct IndexCastPattern final : NarrowingPattern<CastOp> {
433 using NarrowingPattern<CastOp>::NarrowingPattern;
435 LogicalResult matchAndRewrite(CastOp op,
436 PatternRewriter &rewriter)
const override {
437 Value in = op.getIn();
// Check on the input being of index type.
439 if (!isa<IndexType>(in.getType()))
// Bits needed to hold both bounds, each treated as a 64-bit constant.
456 assert(*lb <= *ub &&
"Invalid bounds");
457 unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb),
Kind);
458 unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub),
Kind);
459 unsigned bitsRequired =
std::max(lbBitsRequired, ubBitsRequired);
// Check: the result type is not wider than what the bounds require.
461 IntegerType resultTy = cast<IntegerType>(op.getType());
462 if (resultTy.getWidth() <= bitsRequired)
465 FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
466 if (failed(narrowTy))
// Cast to the narrow type, then replace `op` with an extension back to the
// original result type: `ExtSIOp` for `Sign`; the other visible replacement
// uses `ExtUIOp`.
469 Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());
471 if (
Kind == ExtensionKind::Sign)
472 rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
474 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
// `arith::IndexCastOp` is paired with sign extension, `arith::IndexCastUIOp`
// with zero extension.
478 using IndexCastSIPattern =
479 IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
480 using IndexCastUIPattern =
481 IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
/// Moves an extension from the source of a `vector::BroadcastOp` to its
/// result: the un-extended input is broadcast and the extension is recreated
/// on the new broadcast.
/// NOTE(review): the failure check on `ext` and the declarations of `newTy` /
/// `newBroadcast` are on lines not visible in this excerpt.
487 struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
488 using NarrowingPattern::NarrowingPattern;
490 LogicalResult matchAndRewrite(vector::BroadcastOp op,
491 PatternRewriter &rewriter)
const override {
492 FailureOr<ExtensionOp> ext =
493 ExtensionOp::from(op.getSource().getDefiningOp());
// Same shape as the original result, element type of the extension's input.
497 VectorType origTy = op.getResultVectorType();
499 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
501 rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
502 ext->recreateAndReplace(rewriter, op, newBroadcast);
/// Moves an extension from the vector operand of a `vector::ExtractOp` to its
/// result: the extract is redone on the un-extended input at the same
/// position and the extension is recreated on the extracted value.
/// NOTE(review): the failure check on `ext` is on a line not visible in this
/// excerpt.
507 struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
508 using NarrowingPattern::NarrowingPattern;
510 LogicalResult matchAndRewrite(vector::ExtractOp op,
511 PatternRewriter &rewriter)
const override {
512 FailureOr<ExtensionOp> ext =
513 ExtensionOp::from(op.getVector().getDefiningOp());
517 Value newExtract = rewriter.create<vector::ExtractOp>(
518 op.getLoc(), ext->getIn(), op.getMixedPosition());
519 ext->recreateAndReplace(rewriter, op, newExtract);
/// Moves an extension from the vector operand of a
/// `vector::ExtractElementOp` to its result: the element is extracted from
/// the un-extended input at the same position and then extended.
/// NOTE(review): the failure check on `ext` is on a line not visible in this
/// excerpt.
524 struct ExtensionOverExtractElement final
525 : NarrowingPattern<vector::ExtractElementOp> {
526 using NarrowingPattern::NarrowingPattern;
528 LogicalResult matchAndRewrite(vector::ExtractElementOp op,
529 PatternRewriter &rewriter)
const override {
530 FailureOr<ExtensionOp> ext =
531 ExtensionOp::from(op.getVector().getDefiningOp());
535 Value newExtract = rewriter.create<vector::ExtractElementOp>(
536 op.getLoc(), ext->getIn(), op.getPosition());
537 ext->recreateAndReplace(rewriter, op, newExtract);
/// Moves an extension from the vector operand of a
/// `vector::ExtractStridedSliceOp` to its result: the slice is taken from the
/// un-extended input and the extension is recreated on the slice.
/// NOTE(review): the failure check on `ext` and the trailing arguments of the
/// new slice op are on lines not visible in this excerpt.
542 struct ExtensionOverExtractStridedSlice final
543 : NarrowingPattern<vector::ExtractStridedSliceOp> {
544 using NarrowingPattern::NarrowingPattern;
546 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
547 PatternRewriter &rewriter)
const override {
548 FailureOr<ExtensionOp> ext =
549 ExtensionOp::from(op.getVector().getDefiningOp());
// Same shape as the original result, element type of the extension's input.
553 VectorType origTy = op.getType();
554 VectorType extractTy =
555 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
556 Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
557 op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
559 ext->recreateAndReplace(rewriter, op, newExtract);
/// Shared logic for vector insertion ops whose inserted value (`getSource()`)
/// is produced by an extension op: the insertion is redone on narrower types
/// and the extension is recreated on its result.
565 template <
typename InsertionOp>
566 struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
567 using NarrowingPattern<InsertionOp>::NarrowingPattern;
/// Hook implemented by subclasses to build the concrete narrow insertion op
/// from the narrowed value and destination.
/// NOTE(review): one parameter line (the narrow value) is not visible in this
/// excerpt.
571 virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
572 InsertionOp origInsert,
574 Value narrowDest)
const = 0;
/// NOTE(review): the failure check on `ext` and the bodies of the early exits
/// are on lines not visible in this excerpt.
576 LogicalResult matchAndRewrite(InsertionOp op,
577 PatternRewriter &rewriter)
const final {
578 FailureOr<ExtensionOp> ext =
579 ExtensionOp::from(op.getSource().getDefiningOp());
583 FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
584 if (failed(newInsert))
586 ext->recreateAndReplace(rewriter, op, *newInsert);
/// Builds the narrow insertion for `op`, given the extension `insValue` that
/// defines the inserted value.
/// NOTE(review): the bodies of the early-exit checks and the declaration of
/// `narrowDest` are on lines not visible in this excerpt.
590 FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
591 PatternRewriter &rewriter,
592 ExtensionOp insValue)
const {
// Bit count of the original result type.
599 FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
600 if (failed(origBitsRequired))
// Destination and inserted value are both measured with the kind of the
// extension; each is checked against the original bit count.
605 FailureOr<unsigned> destBitsRequired =
606 calculateBitsRequired(op.getDest(), insValue.getKind());
607 if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
610 FailureOr<unsigned> insertedBitsRequired =
611 calculateBitsRequired(insValue.getIn(), insValue.getKind());
612 if (failed(insertedBitsRequired) ||
613 *insertedBitsRequired >= *origBitsRequired)
// Use the larger of the two requirements for both the vector type and the
// inserted value's type.
618 unsigned newInsertionBits =
619 std::max(*destBitsRequired, *insertedBitsRequired);
620 FailureOr<Type> newVecTy =
621 this->getNarrowType(newInsertionBits, op.getType());
622 if (failed(newVecTy) || *newVecTy == op.getType())
625 FailureOr<Type> newInsertedValueTy =
626 this->getNarrowType(newInsertionBits, insValue.getType());
627 if (failed(newInsertedValueTy))
// Truncate the extended value and the destination to the narrow types, then
// let the subclass build the insertion op.
630 Location loc = op.getLoc();
631 Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
632 loc, *newInsertedValueTy, insValue.getResult());
634 rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
635 return createInsertionOp(rewriter, op, narrowValue, narrowDest);
/// Insertion pattern specialization for `vector::InsertOp`.
639 struct ExtensionOverInsert final
640 : ExtensionOverInsertionPattern<vector::InsertOp> {
641 using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
// Builds a `vector::InsertOp` at the location and mixed position of
// `origInsert`, inserting `narrowValue`.
// NOTE(review): the `narrowValue` parameter line and one argument line of the
// create call are not visible in this excerpt.
643 vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
644 vector::InsertOp origInsert,
646 Value narrowDest)
const override {
647 return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
649 origInsert.getMixedPosition());
/// Insertion pattern specialization for `vector::InsertElementOp`.
653 struct ExtensionOverInsertElement final
654 : ExtensionOverInsertionPattern<vector::InsertElementOp> {
655 using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
// Builds a `vector::InsertElementOp` inserting `narrowValue` into
// `narrowDest` at the location and position of `origInsert`.
// NOTE(review): the `narrowValue` parameter line is not visible in this
// excerpt.
657 vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
658 vector::InsertElementOp origInsert,
660 Value narrowDest)
const override {
661 return rewriter.create<vector::InsertElementOp>(
662 origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
/// Insertion pattern specialization for `vector::InsertStridedSliceOp`.
666 struct ExtensionOverInsertStridedSlice final
667 : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
668 using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
// Builds a `vector::InsertStridedSliceOp` inserting `narrowValue` into
// `narrowDest`, reusing the location, offsets and strides of `origInsert`.
670 vector::InsertStridedSliceOp
671 createInsertionOp(PatternRewriter &rewriter,
672 vector::InsertStridedSliceOp origInsert,
Value narrowValue,
673 Value narrowDest)
const override {
674 return rewriter.create<vector::InsertStridedSliceOp>(
675 origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
676 origInsert.getStrides());
/// Moves an extension from the source of a `vector::ShapeCastOp` to its
/// result: the un-extended input is shape-cast and the extension is recreated
/// on the new cast.
/// NOTE(review): the failure check on `ext` and the declarations of `newTy` /
/// `newCast` are on lines not visible in this excerpt.
680 struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
681 using NarrowingPattern::NarrowingPattern;
683 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
684 PatternRewriter &rewriter)
const override {
685 FailureOr<ExtensionOp> ext =
686 ExtensionOp::from(op.getSource().getDefiningOp());
// Same shape as the original result, element type of the extension's input.
690 VectorType origTy = op.getResultVectorType();
692 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
694 rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
695 ext->recreateAndReplace(rewriter, op, newCast);
/// Moves an extension from the vector operand of a `vector::TransposeOp` to
/// its result: the un-extended input is transposed with the same permutation
/// and the extension is recreated on the new transpose.
/// NOTE(review): the failure check on `ext` and the declaration of `newTy` are
/// on lines not visible in this excerpt.
700 struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
701 using NarrowingPattern::NarrowingPattern;
703 LogicalResult matchAndRewrite(vector::TransposeOp op,
704 PatternRewriter &rewriter)
const override {
705 FailureOr<ExtensionOp> ext =
706 ExtensionOp::from(op.getVector().getDefiningOp());
// Same shape as the original result, element type of the extension's input.
710 VectorType origTy = op.getResultVectorType();
712 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
713 Value newTranspose = rewriter.create<vector::TransposeOp>(
714 op.getLoc(), newTy, ext->getIn(), op.getPermutation());
715 ext->recreateAndReplace(rewriter, op, newTranspose);
/// Moves an extension from the matrix operand of a `vector::FlatTransposeOp`
/// to its result: the un-extended input is transposed with the same rows and
/// columns attributes and the extension is recreated on the new transpose.
/// NOTE(review): the failure check on `ext` and the declaration of `newTy` are
/// on lines not visible in this excerpt.
720 struct ExtensionOverFlatTranspose final
721 : NarrowingPattern<vector::FlatTransposeOp> {
722 using NarrowingPattern::NarrowingPattern;
724 LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
725 PatternRewriter &rewriter)
const override {
726 FailureOr<ExtensionOp> ext =
727 ExtensionOp::from(op.getMatrix().getDefiningOp());
// Same shape as the original result, element type of the extension's input.
731 VectorType origTy = op.getType();
733 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
734 Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
735 op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
736 op.getColumnsAttr());
737 ext->recreateAndReplace(rewriter, op, newTranspose);
/// Pass wrapper around the integer narrowing patterns, configured through the
/// `bitwidthsSupported` option.
/// NOTE(review): the call that receives `patterns` and the options, and the
/// code that applies the patterns, are on lines not visible in this excerpt.
746 struct ArithIntNarrowingPass final
747 : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
748 using ArithIntNarrowingBase::ArithIntNarrowingBase;
750 void runOnOperation()
override {
// An empty bitwidth list, or one containing 0, makes the pass fail.
751 if (bitwidthsSupported.empty() ||
752 llvm::is_contained(bitwidthsSupported, 0)) {
754 return signalPassFailure();
757 Operation *op = getOperation();
758 MLIRContext *ctx = op->getContext();
759 RewritePatternSet patterns(ctx);
// The pass option is converted into a vector of `unsigned` to build the
// `ArithIntNarrowingOptions` passed along with `patterns`.
761 patterns, ArithIntNarrowingOptions{
762 llvm::to_vector_of<unsigned>(bitwidthsSupported)});
777 patterns.
add<ExtensionOverBroadcast, ExtensionOverExtract,
778 ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
779 ExtensionOverInsert, ExtensionOverInsertElement,
780 ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
781 ExtensionOverTranspose, ExtensionOverFlatTranspose>(
784 patterns.
add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
785 DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
786 MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
static uint64_t zext(uint32_t arg)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
void populateArithIntNarrowingPatterns(RewritePatternSet &patterns, const ArithIntNarrowingOptions &options)
Add patterns for integer bitwidth narrowing.
@ Type
An inlay hint that is for a type annotation.
Kind
An enumeration of the kinds of predicates.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.