26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
32 #define GEN_PASS_DEF_ARITHINTNARROWING
33 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
43 template <
typename SourceOp>
44 struct NarrowingPattern : OpRewritePattern<SourceOp> {
45 NarrowingPattern(MLIRContext *ctx,
const ArithIntNarrowingOptions &
options,
46 PatternBenefit benefit = 1)
47 : OpRewritePattern<SourceOp>(ctx, benefit),
48 supportedBitwidths(
options.bitwidthsSupported.begin(),
49 options.bitwidthsSupported.end()) {
50 assert(!supportedBitwidths.empty() &&
"Invalid options");
51 assert(!llvm::is_contained(supportedBitwidths, 0) &&
"Invalid bitwidth");
52 llvm::sort(supportedBitwidths);
56 getNarrowestCompatibleBitwidth(
unsigned bitsRequired)
const {
57 for (
unsigned candidate : supportedBitwidths)
58 if (candidate >= bitsRequired)
65 FailureOr<Type> getNarrowType(
unsigned bitsRequired,
Type origTy)
const {
67 FailureOr<unsigned> bestBitwidth =
68 getNarrowestCompatibleBitwidth(bitsRequired);
73 if (!isa<IntegerType>(elemTy))
77 if (newElemTy == elemTy)
83 if (
auto shapedTy = dyn_cast<ShapedType>(origTy))
84 if (dyn_cast<IntegerType>(shapedTy.getElementType()))
85 return shapedTy.clone(shapedTy.getShape(), newElemTy);
96 FailureOr<unsigned> calculateBitsRequired(
Type type) {
99 return intTy.getWidth();
// Which integer extension an op performs: Sign corresponds to arith.extsi,
// Zero to arith.extui (see ExtensionOp::from).
104 enum class ExtensionKind { Sign, Zero };
115 static FailureOr<ExtensionOp> from(Operation *op) {
116 if (dyn_cast_or_null<arith::ExtSIOp>(op))
117 return ExtensionOp{op, ExtensionKind::Sign};
118 if (dyn_cast_or_null<arith::ExtUIOp>(op))
119 return ExtensionOp{op, ExtensionKind::Zero};
// ExtensionOp is a small copyable handle; it is passed by value, e.g. to
// isSupported(ExtensionOp) and createNarrowInsert(..., ExtensionOp insValue).
124 ExtensionOp(
const ExtensionOp &) =
default;
// Defaulted copy assignment, matching the defaulted copy constructor.
125 ExtensionOp &operator=(
const ExtensionOp &) =
default;
128 Operation *recreate(PatternRewriter &rewriter, Location loc,
Type newType,
130 if (kind == ExtensionKind::Sign)
131 return rewriter.create<arith::ExtSIOp>(loc, newType, in);
133 return rewriter.create<arith::ExtUIOp>(loc, newType, in);
137 void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
139 assert(toReplace->getNumResults() == 1);
140 Type newType = toReplace->getResult(0).getType();
141 Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
142 rewriter.replaceOp(toReplace, newOp->getResult(0));
// Returns whether the wrapped op is a sign (extsi) or zero (extui) extension.
145 ExtensionKind getKind() {
return kind; }
// Returns the extended value, i.e. result #0 of the wrapped op.
147 Value getResult() {
return op->getResult(0); }
// Returns the value being extended, i.e. operand #0 of the wrapped op.
148 Value getIn() {
return op->getOperand(0); }
// Returns the (wider) type produced by the extension.
150 Type getType() {
return getResult().getType(); }
// Returns the (narrower) type of the value before extension.
152 Type getInType() {
return getIn().getType(); }
156 ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
158 assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) &&
"Not an extension op");
// The wrapped extension op; the constructor asserts it is arith.extsi or
// arith.extui.
160 Operation *op =
nullptr;
// Which extension `op` performs.
161 ExtensionKind kind = {};
165 unsigned calculateBitsRequired(
const APInt &value,
166 ExtensionKind lookThroughExtension) {
169 if (lookThroughExtension == ExtensionKind::Zero)
170 return std::max(value.getActiveBits(), 1u);
173 if (value.isNonNegative())
174 return value.getActiveBits() + 1;
177 if (value.isMinSignedValue())
178 return value.getBitWidth();
182 return value.getBitWidth() - value.getNumSignBits() + 1;
188 FailureOr<unsigned> calculateBitsRequired(
Value value,
189 ExtensionKind lookThroughExtension) {
192 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
193 return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);
195 if (
auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
196 if (elemsAttr.getElementType().isIntOrIndex()) {
197 if (elemsAttr.isSplat())
198 return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
199 lookThroughExtension);
201 unsigned maxBits = 1;
202 for (
const APInt &elemValue : elemsAttr.getValues<APInt>())
204 maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
210 if (lookThroughExtension == ExtensionKind::Sign) {
211 if (
auto sext = value.getDefiningOp<arith::ExtSIOp>())
212 return calculateBitsRequired(sext.getIn().getType());
213 }
else if (lookThroughExtension == ExtensionKind::Zero) {
214 if (
auto zext = value.getDefiningOp<arith::ExtUIOp>())
215 return calculateBitsRequired(
zext.getIn().getType());
219 return calculateBitsRequired(value.getType());
234 template <
typename BinaryOp>
235 struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
236 using NarrowingPattern<BinaryOp>::NarrowingPattern;
241 virtual unsigned getResultBitsProduced(
unsigned operandBits)
const = 0;
245 virtual bool isSupported(ExtensionOp)
const {
return true; }
247 LogicalResult matchAndRewrite(BinaryOp op,
248 PatternRewriter &rewriter)
const final {
249 Type origTy = op.getType();
250 FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
256 FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
257 if (
failed(ext) || !isSupported(*ext))
260 FailureOr<unsigned> lhsBitsRequired =
261 calculateBitsRequired(ext->getIn(), ext->getKind());
262 if (
failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
265 FailureOr<unsigned> rhsBitsRequired =
266 calculateBitsRequired(op.getRhs(), ext->getKind());
267 if (
failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
272 unsigned commonBitsRequired =
273 getResultBitsProduced(
std::max(*lhsBitsRequired, *rhsBitsRequired));
274 FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
275 if (
failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
278 Location loc = op.getLoc();
280 rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
282 rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
283 Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
284 ext->recreateAndReplace(rewriter, op, newAdd);
293 struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
294 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
298 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
299 return operandBits + 1;
307 struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
308 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
311 bool isSupported(ExtensionOp ext)
const override {
312 return ext.getKind() == ExtensionKind::Sign;
317 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
318 return operandBits + 1;
326 struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
327 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
331 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
332 return 2 * operandBits;
340 struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
341 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
344 bool isSupported(ExtensionOp ext)
const override {
345 return ext.getKind() == ExtensionKind::Sign;
350 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
351 return operandBits + 1;
359 struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
360 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
363 bool isSupported(ExtensionOp ext)
const override {
364 return ext.getKind() == ExtensionKind::Zero;
368 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
377 template <
typename MinMaxOp, ExtensionKind Kind>
378 struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
379 using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;
381 bool isSupported(ExtensionOp ext)
const override {
382 return ext.getKind() ==
Kind;
387 unsigned getResultBitsProduced(
unsigned operandBits)
const override {
// Min/max narrowing instantiations. Signed ops only look through sign
// extensions and unsigned ops only through zero extensions, because
// MinMaxPattern::isSupported requires ext.getKind() == Kind.
391 using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
392 using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
393 using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
394 using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;
400 template <
typename IToFPOp, ExtensionKind Extension>
401 struct IToFPPattern final : NarrowingPattern<IToFPOp> {
402 using NarrowingPattern<IToFPOp>::NarrowingPattern;
404 LogicalResult matchAndRewrite(IToFPOp op,
405 PatternRewriter &rewriter)
const override {
406 FailureOr<unsigned> narrowestWidth =
407 calculateBitsRequired(op.getIn(), Extension);
408 if (
failed(narrowestWidth))
411 FailureOr<Type> narrowTy =
412 this->getNarrowType(*narrowestWidth, op.getIn().getType());
416 Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy,
418 rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
// Int-to-float narrowing. The input width needed by sitofp is computed by
// looking through sign extensions, and for uitofp through zero extensions.
422 using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
423 using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
432 template <
typename CastOp, ExtensionKind Kind>
433 struct IndexCastPattern final : NarrowingPattern<CastOp> {
434 using NarrowingPattern<CastOp>::NarrowingPattern;
436 LogicalResult matchAndRewrite(CastOp op,
437 PatternRewriter &rewriter)
const override {
438 Value in = op.getIn();
440 if (!isa<IndexType>(in.getType()))
457 assert(*lb <= *ub &&
"Invalid bounds");
458 unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb),
Kind);
459 unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub),
Kind);
460 unsigned bitsRequired =
std::max(lbBitsRequired, ubBitsRequired);
462 IntegerType resultTy = cast<IntegerType>(op.getType());
463 if (resultTy.getWidth() <= bitsRequired)
466 FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
470 Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());
472 if (
Kind == ExtensionKind::Sign)
473 rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
475 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
// Index-cast narrowing. The narrowed index_cast result is re-extended with
// arith.extsi, and the narrowed index_castui result with arith.extui
// (selected by Kind in IndexCastPattern).
479 using IndexCastSIPattern =
480 IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
481 using IndexCastUIPattern =
482 IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
488 struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
489 using NarrowingPattern::NarrowingPattern;
491 LogicalResult matchAndRewrite(vector::BroadcastOp op,
492 PatternRewriter &rewriter)
const override {
493 FailureOr<ExtensionOp> ext =
494 ExtensionOp::from(op.getSource().getDefiningOp());
498 VectorType origTy = op.getResultVectorType();
500 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
502 rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
503 ext->recreateAndReplace(rewriter, op, newBroadcast);
508 struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
509 using NarrowingPattern::NarrowingPattern;
511 LogicalResult matchAndRewrite(vector::ExtractOp op,
512 PatternRewriter &rewriter)
const override {
513 FailureOr<ExtensionOp> ext =
514 ExtensionOp::from(op.getVector().getDefiningOp());
518 Value newExtract = rewriter.create<vector::ExtractOp>(
519 op.getLoc(), ext->getIn(), op.getMixedPosition());
520 ext->recreateAndReplace(rewriter, op, newExtract);
525 struct ExtensionOverExtractElement final
526 : NarrowingPattern<vector::ExtractElementOp> {
527 using NarrowingPattern::NarrowingPattern;
529 LogicalResult matchAndRewrite(vector::ExtractElementOp op,
530 PatternRewriter &rewriter)
const override {
531 FailureOr<ExtensionOp> ext =
532 ExtensionOp::from(op.getVector().getDefiningOp());
536 Value newExtract = rewriter.create<vector::ExtractElementOp>(
537 op.getLoc(), ext->getIn(), op.getPosition());
538 ext->recreateAndReplace(rewriter, op, newExtract);
543 struct ExtensionOverExtractStridedSlice final
544 : NarrowingPattern<vector::ExtractStridedSliceOp> {
545 using NarrowingPattern::NarrowingPattern;
547 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
548 PatternRewriter &rewriter)
const override {
549 FailureOr<ExtensionOp> ext =
550 ExtensionOp::from(op.getVector().getDefiningOp());
554 VectorType origTy = op.getType();
555 VectorType extractTy =
556 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
557 Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
558 op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
560 ext->recreateAndReplace(rewriter, op, newExtract);
566 template <
typename InsertionOp>
567 struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
568 using NarrowingPattern<InsertionOp>::NarrowingPattern;
572 virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
573 InsertionOp origInsert,
575 Value narrowDest)
const = 0;
577 LogicalResult matchAndRewrite(InsertionOp op,
578 PatternRewriter &rewriter)
const final {
579 FailureOr<ExtensionOp> ext =
580 ExtensionOp::from(op.getSource().getDefiningOp());
584 FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
587 ext->recreateAndReplace(rewriter, op, *newInsert);
591 FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
592 PatternRewriter &rewriter,
593 ExtensionOp insValue)
const {
600 FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
601 if (
failed(origBitsRequired))
606 FailureOr<unsigned> destBitsRequired =
607 calculateBitsRequired(op.getDest(), insValue.getKind());
608 if (
failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
611 FailureOr<unsigned> insertedBitsRequired =
612 calculateBitsRequired(insValue.getIn(), insValue.getKind());
613 if (
failed(insertedBitsRequired) ||
614 *insertedBitsRequired >= *origBitsRequired)
619 unsigned newInsertionBits =
620 std::max(*destBitsRequired, *insertedBitsRequired);
621 FailureOr<Type> newVecTy =
622 this->getNarrowType(newInsertionBits, op.getType());
623 if (
failed(newVecTy) || *newVecTy == op.getType())
626 FailureOr<Type> newInsertedValueTy =
627 this->getNarrowType(newInsertionBits, insValue.getType());
628 if (
failed(newInsertedValueTy))
631 Location loc = op.getLoc();
632 Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
633 loc, *newInsertedValueTy, insValue.getResult());
635 rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
636 return createInsertionOp(rewriter, op, narrowValue, narrowDest);
640 struct ExtensionOverInsert final
641 : ExtensionOverInsertionPattern<vector::InsertOp> {
642 using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
644 vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
645 vector::InsertOp origInsert,
647 Value narrowDest)
const override {
648 return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
650 origInsert.getMixedPosition());
654 struct ExtensionOverInsertElement final
655 : ExtensionOverInsertionPattern<vector::InsertElementOp> {
656 using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
658 vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
659 vector::InsertElementOp origInsert,
661 Value narrowDest)
const override {
662 return rewriter.create<vector::InsertElementOp>(
663 origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
667 struct ExtensionOverInsertStridedSlice final
668 : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
669 using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
671 vector::InsertStridedSliceOp
672 createInsertionOp(PatternRewriter &rewriter,
673 vector::InsertStridedSliceOp origInsert,
Value narrowValue,
674 Value narrowDest)
const override {
675 return rewriter.create<vector::InsertStridedSliceOp>(
676 origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
677 origInsert.getStrides());
681 struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
682 using NarrowingPattern::NarrowingPattern;
684 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
685 PatternRewriter &rewriter)
const override {
686 FailureOr<ExtensionOp> ext =
687 ExtensionOp::from(op.getSource().getDefiningOp());
691 VectorType origTy = op.getResultVectorType();
693 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
695 rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
696 ext->recreateAndReplace(rewriter, op, newCast);
701 struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
702 using NarrowingPattern::NarrowingPattern;
704 LogicalResult matchAndRewrite(vector::TransposeOp op,
705 PatternRewriter &rewriter)
const override {
706 FailureOr<ExtensionOp> ext =
707 ExtensionOp::from(op.getVector().getDefiningOp());
711 VectorType origTy = op.getResultVectorType();
713 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
714 Value newTranspose = rewriter.create<vector::TransposeOp>(
715 op.getLoc(), newTy, ext->getIn(), op.getPermutation());
716 ext->recreateAndReplace(rewriter, op, newTranspose);
721 struct ExtensionOverFlatTranspose final
722 : NarrowingPattern<vector::FlatTransposeOp> {
723 using NarrowingPattern::NarrowingPattern;
725 LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
726 PatternRewriter &rewriter)
const override {
727 FailureOr<ExtensionOp> ext =
728 ExtensionOp::from(op.getMatrix().getDefiningOp());
732 VectorType origTy = op.getType();
734 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
735 Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
736 op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
737 op.getColumnsAttr());
738 ext->recreateAndReplace(rewriter, op, newTranspose);
747 struct ArithIntNarrowingPass final
748 : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
749 using ArithIntNarrowingBase::ArithIntNarrowingBase;
751 void runOnOperation()
override {
752 if (bitwidthsSupported.empty() ||
753 llvm::is_contained(bitwidthsSupported, 0)) {
755 return signalPassFailure();
758 Operation *op = getOperation();
759 MLIRContext *ctx = op->getContext();
760 RewritePatternSet patterns(ctx);
762 patterns, ArithIntNarrowingOptions{bitwidthsSupported});
777 patterns.
add<ExtensionOverBroadcast, ExtensionOverExtract,
778 ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
779 ExtensionOverInsert, ExtensionOverInsertElement,
780 ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
781 ExtensionOverTranspose, ExtensionOverFlatTranspose>(
784 patterns.
add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
785 DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
786 MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
static uint64_t zext(uint32_t arg)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, Value value, std::optional< int64_t > dim=std::nullopt, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given affine map, where dims and symbols are bound to the given oper...
void populateArithIntNarrowingPatterns(RewritePatternSet &patterns, const ArithIntNarrowingOptions &options)
Add patterns for integer bitwidth narrowing.
@ Type
An inlay hint for a type annotation.
Kind
An enumeration of the kinds of predicates.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.