//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTNARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

namespace mlir::arith {
namespace {
//===----------------------------------------------------------------------===//
// Common Helpers
//===----------------------------------------------------------------------===//

/// The base for integer bitwidth narrowing patterns.
template <typename SourceOp>
struct NarrowingPattern : OpRewritePattern<SourceOp> {
  NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(ctx, benefit),
        supportedBitwidths(options.bitwidthsSupported.begin(),
                           options.bitwidthsSupported.end()) {
    assert(!supportedBitwidths.empty() && "Invalid options");
    assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
    llvm::sort(supportedBitwidths);
  }

  FailureOr<unsigned>
  getNarrowestCompatibleBitwidth(unsigned bitsRequired) const {
    for (unsigned candidate : supportedBitwidths)
      if (candidate >= bitsRequired)
        return candidate;

    return failure();
  }

  /// Returns the narrowest supported type that fits `bitsRequired`.
  FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const {
    assert(origTy);
    FailureOr<unsigned> bestBitwidth =
        getNarrowestCompatibleBitwidth(bitsRequired);
    if (failed(bestBitwidth))
      return failure();

    Type elemTy = getElementTypeOrSelf(origTy);
    if (!isa<IntegerType>(elemTy))
      return failure();

    auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
    if (newElemTy == elemTy)
      return failure();

    if (origTy == elemTy)
      return newElemTy;

    if (auto shapedTy = dyn_cast<ShapedType>(origTy))
      if (dyn_cast<IntegerType>(shapedTy.getElementType()))
        return shapedTy.clone(shapedTy.getShape(), newElemTy);

    return failure();
  }
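
  // Worked example (illustrative, not from the original source): with
  // supportedBitwidths = {8, 16, 32}, a requirement of 9 bits narrows to i16;
  // for a shaped type such as vector<4xi32>, only the element type changes,
  // yielding vector<4xi16>. A requirement of 33 bits exceeds every candidate,
  // so both helpers return failure.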

private:
  // Supported integer bitwidths in ascending order.
  llvm::SmallVector<unsigned, 6> supportedBitwidths;
};

/// Returns the integer bitwidth required to represent `type`.
FailureOr<unsigned> calculateBitsRequired(Type type) {
  assert(type);
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type)))
    return intTy.getWidth();

  return failure();
}

enum class ExtensionKind { Sign, Zero };

/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
/// the exact op type. Exposes helper functions to query the types, operands,
/// and the result. This is so that we can handle both extension kinds without
/// needing to use templates or branching.
class ExtensionOp {
public:
  /// Attempts to create a new extension op from `op`. Returns an extension op
  /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
  /// otherwise.
  static FailureOr<ExtensionOp> from(Operation *op) {
    if (dyn_cast_or_null<arith::ExtSIOp>(op))
      return ExtensionOp{op, ExtensionKind::Sign};
    if (dyn_cast_or_null<arith::ExtUIOp>(op))
      return ExtensionOp{op, ExtensionKind::Zero};

    return failure();
  }

  ExtensionOp(const ExtensionOp &) = default;
  ExtensionOp &operator=(const ExtensionOp &) = default;

  /// Creates a new extension op of the same kind.
  Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
                      Value in) {
    if (kind == ExtensionKind::Sign)
      return rewriter.create<arith::ExtSIOp>(loc, newType, in);

    return rewriter.create<arith::ExtUIOp>(loc, newType, in);
  }

  /// Replaces `toReplace` with a new extension op of the same kind.
  void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
                          Value in) {
    assert(toReplace->getNumResults() == 1);
    Type newType = toReplace->getResult(0).getType();
    Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
    rewriter.replaceOp(toReplace, newOp->getResult(0));
  }

  ExtensionKind getKind() { return kind; }

  Value getResult() { return op->getResult(0); }
  Value getIn() { return op->getOperand(0); }

  Type getType() { return getResult().getType(); }
  Type getElementType() { return getElementTypeOrSelf(getType()); }
  Type getInType() { return getIn().getType(); }
  Type getInElementType() { return getElementTypeOrSelf(getInType()); }

private:
  ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
    assert(op);
    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
  }
  Operation *op = nullptr;
  ExtensionKind kind = {};
};
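
// Illustrative example (not from the original source): given IR such as
//   %ext = arith.extsi %a : i8 to i32
// `ExtensionOp::from(ext.getDefiningOp())` (where `ext` is the Value `%ext`)
// succeeds with kind `ExtensionKind::Sign`; `getIn()` returns `%a`,
// `getInElementType()` returns `i8`, and `getElementType()` returns `i32`.
// For any op other than `arith.extsi`/`arith.extui`, `from` returns failure.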

/// Returns the integer bitwidth required to represent `value`.
unsigned calculateBitsRequired(const APInt &value,
                               ExtensionKind lookThroughExtension) {
  // For unsigned values, we only need the active bits. As a special case, zero
  // requires one bit.
  if (lookThroughExtension == ExtensionKind::Zero)
    return std::max(value.getActiveBits(), 1u);

  // If a signed value is nonnegative, we need one extra bit for the sign.
  if (value.isNonNegative())
    return value.getActiveBits() + 1;

  // For the signed min, we need all the bits.
  if (value.isMinSignedValue())
    return value.getBitWidth();

  // For negative values, we need all the non-sign bits and one extra bit for
  // the sign.
  return value.getBitWidth() - value.getNumSignBits() + 1;
}
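
// Worked examples (illustrative, not from the original source) for 32-bit
// inputs:
//   * 255 with ExtensionKind::Zero -> 8 (the active bits).
//   * 255 with ExtensionKind::Sign -> 9 (8 magnitude bits plus a sign bit).
//   * -1 with ExtensionKind::Sign  -> 1 (every bit is a sign bit).
//   * INT32_MIN with ExtensionKind::Sign -> 32 (all bits are needed).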

/// Returns the integer bitwidth required to represent `value`.
/// Looks through either sign- or zero-extension as specified by
/// `lookThroughExtension`.
FailureOr<unsigned> calculateBitsRequired(Value value,
                                          ExtensionKind lookThroughExtension) {
  // Handle constants.
  if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
      return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);

    if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
      if (elemsAttr.getElementType().isIntOrIndex()) {
        if (elemsAttr.isSplat())
          return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
                                       lookThroughExtension);

        unsigned maxBits = 1;
        for (const APInt &elemValue : elemsAttr.getValues<APInt>())
          maxBits = std::max(
              maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
        return maxBits;
      }
    }
  }

  if (lookThroughExtension == ExtensionKind::Sign) {
    if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
      return calculateBitsRequired(sext.getIn().getType());
  } else if (lookThroughExtension == ExtensionKind::Zero) {
    if (auto zext = value.getDefiningOp<arith::ExtUIOp>())
      return calculateBitsRequired(zext.getIn().getType());
  }

  // If nothing else worked, fall back to the bitwidth of the value's type.
  return calculateBitsRequired(value.getType());
}
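
// Illustrative examples (not from the original source):
//   * For `%v = arith.extsi %a : i8 to i32`, querying %v with
//     ExtensionKind::Sign returns 8, the width of the pre-extension type.
//   * For a constant splat `dense<5> : vector<4xi32>` with ExtensionKind::Zero,
//     the result is 3, the active bits of 5.
//   * Otherwise, the width of the value's (element) type is returned.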

/// Base pattern for arith binary ops.
/// Example:
/// ```
/// %lhs = arith.extsi %a : i8 to i32
/// %rhs = arith.extsi %b : i8 to i32
/// %r = arith.addi %lhs, %rhs : i32
/// ==>
/// %lhs = arith.extsi %a : i8 to i16
/// %rhs = arith.extsi %b : i8 to i16
/// %add = arith.addi %lhs, %rhs : i16
/// %r = arith.extsi %add : i16 to i32
/// ```
template <typename BinaryOp>
struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
  using NarrowingPattern<BinaryOp>::NarrowingPattern;

  /// Returns the number of bits required to represent the full result,
  /// assuming that both operands are `operandBits`-wide. Derived classes must
  /// implement this, taking into account `BinaryOp` semantics.
  virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;

  /// Customization point for patterns that should only apply with
  /// zero/sign-extension ops as arguments.
  virtual bool isSupported(ExtensionOp) const { return true; }

  LogicalResult matchAndRewrite(BinaryOp op,
                                PatternRewriter &rewriter) const final {
    Type origTy = op.getType();
    FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
    if (failed(resultBits))
      return failure();

    // For the optimization to apply, we expect the lhs to be an extension op,
    // and the rhs to be either the same kind of extension op or a constant.
    FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
    if (failed(ext) || !isSupported(*ext))
      return failure();

    FailureOr<unsigned> lhsBitsRequired =
        calculateBitsRequired(ext->getIn(), ext->getKind());
    if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
      return failure();

    FailureOr<unsigned> rhsBitsRequired =
        calculateBitsRequired(op.getRhs(), ext->getKind());
    if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
      return failure();

    // Negotiate a common bit requirement for both lhs and rhs, accounting for
    // the result requiring more bits than the operands.
    unsigned commonBitsRequired =
        getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired));
    FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
    if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
      return failure();

    Location loc = op.getLoc();
    Value newLhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
    Value newRhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
    Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
    ext->recreateAndReplace(rewriter, op, newAdd);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AddIOp Pattern
//===----------------------------------------------------------------------===//

struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // Addition may require one extra bit for the result.
  // Example: `UINT8_MAX + 1 == 255 + 1 == 256`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};

//===----------------------------------------------------------------------===//
// SubIOp Pattern
//===----------------------------------------------------------------------===//

struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // This optimization only applies to signed arguments.
  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == ExtensionKind::Sign;
  }

  // Subtraction may require one extra bit for the result.
  // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};

//===----------------------------------------------------------------------===//
// MulIOp Pattern
//===----------------------------------------------------------------------===//

struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // Multiplication may require up to double the operand bits.
  // Example: `UINT8_MAX * UINT8_MAX == 255 * 255 == 65025`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return 2 * operandBits;
  }
};

//===----------------------------------------------------------------------===//
// DivSIOp Pattern
//===----------------------------------------------------------------------===//

struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // This optimization only applies to signed arguments.
  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == ExtensionKind::Sign;
  }

  // Unlike multiplication, signed division requires only one more result bit.
  // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};

//===----------------------------------------------------------------------===//
// DivUIOp Pattern
//===----------------------------------------------------------------------===//

struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // This optimization only applies to unsigned arguments.
  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == ExtensionKind::Zero;
  }

  // Unsigned division does not require any extra result bits.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits;
  }
};

//===----------------------------------------------------------------------===//
// Min/Max Patterns
//===----------------------------------------------------------------------===//

template <typename MinMaxOp, ExtensionKind Kind>
struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
  using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;

  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == Kind;
  }

  // Min/max returns one of the arguments and does not require any extra result
  // bits.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits;
  }
};
using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;

//===----------------------------------------------------------------------===//
// *IToFPOp Patterns
//===----------------------------------------------------------------------===//

template <typename IToFPOp, ExtensionKind Extension>
struct IToFPPattern final : NarrowingPattern<IToFPOp> {
  using NarrowingPattern<IToFPOp>::NarrowingPattern;

  LogicalResult matchAndRewrite(IToFPOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<unsigned> narrowestWidth =
        calculateBitsRequired(op.getIn(), Extension);
    if (failed(narrowestWidth))
      return failure();

    FailureOr<Type> narrowTy =
        this->getNarrowType(*narrowestWidth, op.getIn().getType());
    if (failed(narrowTy))
      return failure();

    Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy,
                                                         op.getIn());
    rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
    return success();
  }
};
using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
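
// Illustrative example (not from the original source), assuming i8 is among
// the supported bitwidths:
//   %e = arith.extsi %a : i8 to i32
//   %f = arith.sitofp %e : i32 to f32
// narrows to:
//   %f = arith.sitofp %a : i8 to f32
// because the truncation of %e back to i8 folds away with the extension.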

//===----------------------------------------------------------------------===//
// Index Cast Patterns
//===----------------------------------------------------------------------===//

// These rely on the `ValueBounds` interface for index values. For example, we
// can often statically tell index value bounds of loop induction variables.

template <typename CastOp, ExtensionKind Kind>
struct IndexCastPattern final : NarrowingPattern<CastOp> {
  using NarrowingPattern<CastOp>::NarrowingPattern;

  LogicalResult matchAndRewrite(CastOp op,
                                PatternRewriter &rewriter) const override {
    Value in = op.getIn();
    // We only support scalar index -> integer casts.
    if (!isa<IndexType>(in.getType()))
      return failure();

    // Check the lower bound in both the signed and unsigned cast case. We
    // conservatively assume that even unsigned casts may be performed on
    // negative indices.
    FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound(
        presburger::BoundType::LB, in);
    if (failed(lb))
      return failure();

    FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
        presburger::BoundType::UB, in,
        /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(ub))
      return failure();

    assert(*lb <= *ub && "Invalid bounds");
    unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind);
    unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind);
    unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired);

    IntegerType resultTy = cast<IntegerType>(op.getType());
    if (resultTy.getWidth() <= bitsRequired)
      return failure();

    FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
    if (failed(narrowTy))
      return failure();

    Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());

    if (Kind == ExtensionKind::Sign)
      rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
    else
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
    return success();
  }
};
using IndexCastSIPattern =
    IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
using IndexCastUIPattern =
    IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
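
// Illustrative example (not from the original source): if value bounds
// analysis proves an index %iv lies in [0, 100] and i8 is a supported
// bitwidth, then
//   %c = arith.index_cast %iv : index to i32
// can be rewritten as a narrower cast followed by a sign-extension:
//   %n = arith.index_cast %iv : index to i8
//   %c = arith.extsi %n : i8 to i32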

//===----------------------------------------------------------------------===//
// Patterns to Commute Extension Ops
//===----------------------------------------------------------------------===//

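// Illustrative example (not from the original source): commuting the
// extension over a broadcast,
//   %e = arith.extsi %a : i8 to i32
//   %b = vector.broadcast %e : i32 to vector<4xi32>
// becomes
//   %b = vector.broadcast %a : i8 to vector<4xi8>
//   %e = arith.extsi %b : vector<4xi8> to vector<4xi32>
// so that the narrowing patterns above can then see the extension directly.
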
struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newBroadcast =
        rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, newBroadcast);
    return success();
  }
};

struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    Value newExtract = rewriter.create<vector::ExtractOp>(
        op.getLoc(), ext->getIn(), op.getMixedPosition());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
  }
};

struct ExtensionOverExtractElement final
    : NarrowingPattern<vector::ExtractElementOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractElementOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    Value newExtract = rewriter.create<vector::ExtractElementOp>(
        op.getLoc(), ext->getIn(), op.getPosition());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
  }
};

struct ExtensionOverExtractStridedSlice final
    : NarrowingPattern<vector::ExtractStridedSliceOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getType();
    VectorType extractTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
        op.getStrides());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
  }
};

/// Base pattern for `vector.insert` narrowing patterns.
template <typename InsertionOp>
struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
  using NarrowingPattern<InsertionOp>::NarrowingPattern;

  /// Derived classes must provide a function to create the matching insertion
  /// op based on the original op and new arguments.
  virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
                                        InsertionOp origInsert,
                                        Value narrowValue,
                                        Value narrowDest) const = 0;

  LogicalResult matchAndRewrite(InsertionOp op,
                                PatternRewriter &rewriter) const final {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
    if (failed(newInsert))
      return failure();
    ext->recreateAndReplace(rewriter, op, *newInsert);
    return success();
  }

  FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
                                            PatternRewriter &rewriter,
                                            ExtensionOp insValue) const {
    // Calculate the operand and result bitwidths. We can only apply narrowing
    // when the inserted source value and destination vector require fewer bits
    // than the result. Because the source and destination may have different
    // bitwidth requirements, we have to find a common narrow bitwidth that is
    // greater than or equal to the operand bitwidth requirements and still
    // narrower than the result.
    FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
    if (failed(origBitsRequired))
      return failure();

    // TODO: We could relax this check by disregarding bitwidth requirements of
    // elements that we know will be replaced by the insertion.
    FailureOr<unsigned> destBitsRequired =
        calculateBitsRequired(op.getDest(), insValue.getKind());
    if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
      return failure();

    FailureOr<unsigned> insertedBitsRequired =
        calculateBitsRequired(insValue.getIn(), insValue.getKind());
    if (failed(insertedBitsRequired) ||
        *insertedBitsRequired >= *origBitsRequired)
      return failure();

    // Find a narrower element type that satisfies the bitwidth requirements of
    // both the source and the destination values.
    unsigned newInsertionBits =
        std::max(*destBitsRequired, *insertedBitsRequired);
    FailureOr<Type> newVecTy =
        this->getNarrowType(newInsertionBits, op.getType());
    if (failed(newVecTy) || *newVecTy == op.getType())
      return failure();

    FailureOr<Type> newInsertedValueTy =
        this->getNarrowType(newInsertionBits, insValue.getType());
    if (failed(newInsertedValueTy))
      return failure();

    Location loc = op.getLoc();
    Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
        loc, *newInsertedValueTy, insValue.getResult());
    Value narrowDest =
        rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
    return createInsertionOp(rewriter, op, narrowValue, narrowDest);
  }
};
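
// Illustrative example (not from the original source), assuming i16 is a
// supported bitwidth:
//   %src = arith.extsi %v : vector<4xi8> to vector<4xi32>
//   %dst = arith.extsi %d : vector<8x4xi16> to vector<8x4xi32>
//   %ins = vector.insert %src, %dst [0] : vector<4xi32> into vector<8x4xi32>
// Both operands need at most 16 bits while the result is 32 bits wide, so the
// source and destination are truncated (the truncations fold with the
// extensions), the insertion happens on vector<8x4xi16>, and a single
// extension of the insert result is emitted afterwards.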

struct ExtensionOverInsert final
    : ExtensionOverInsertionPattern<vector::InsertOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
                                     vector::InsertOp origInsert,
                                     Value narrowValue,
                                     Value narrowDest) const override {
    return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
                                             narrowDest,
                                             origInsert.getMixedPosition());
  }
};

struct ExtensionOverInsertElement final
    : ExtensionOverInsertionPattern<vector::InsertElementOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
                                            vector::InsertElementOp origInsert,
                                            Value narrowValue,
                                            Value narrowDest) const override {
    return rewriter.create<vector::InsertElementOp>(
        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
  }
};

struct ExtensionOverInsertStridedSlice final
    : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertStridedSliceOp
  createInsertionOp(PatternRewriter &rewriter,
                    vector::InsertStridedSliceOp origInsert, Value narrowValue,
                    Value narrowDest) const override {
    return rewriter.create<vector::InsertStridedSliceOp>(
        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
        origInsert.getStrides());
  }
};

struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newCast =
        rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, newCast);
    return success();
  }
};

struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        op.getLoc(), newTy, ext->getIn(), op.getPermutation());
    ext->recreateAndReplace(rewriter, op, newTranspose);
    return success();
  }
};

struct ExtensionOverFlatTranspose final
    : NarrowingPattern<vector::FlatTransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getMatrix().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
        op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
        op.getColumnsAttr());
    ext->recreateAndReplace(rewriter, op, newTranspose);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//

struct ArithIntNarrowingPass final
    : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
  using ArithIntNarrowingBase::ArithIntNarrowingBase;

  void runOnOperation() override {
    if (bitwidthsSupported.empty() ||
        llvm::is_contained(bitwidthsSupported, 0)) {
      // Invalid pass options.
      return signalPassFailure();
    }

    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    RewritePatternSet patterns(ctx);
    populateArithIntNarrowingPatterns(
        patterns, ArithIntNarrowingOptions{bitwidthsSupported});
    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Public API
//===----------------------------------------------------------------------===//

void populateArithIntNarrowingPatterns(
    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
  // Add commute patterns with a higher benefit. This is to expose more
  // optimization opportunities to narrowing patterns.
  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
               ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
               ExtensionOverInsert, ExtensionOverInsertElement,
               ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
               ExtensionOverTranspose, ExtensionOverFlatTranspose>(
      patterns.getContext(), options, PatternBenefit(2));

  patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
               DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
               MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
               IndexCastUIPattern>(patterns.getContext(), options);
}
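
// Illustrative usage sketch (not part of the original source): outside of the
// ArithIntNarrowing pass, these rewrites can be driven the same way the pass
// above does it, e.g.:
//
//   SmallVector<unsigned> bitwidths = {8, 16, 32};
//   RewritePatternSet patterns(ctx);
//   populateArithIntNarrowingPatterns(patterns,
//                                     ArithIntNarrowingOptions{bitwidths});
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));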

} // namespace mlir::arith