MLIR  20.0.0git
IntNarrowing.cpp
Go to the documentation of this file.
1 //===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>
29 
30 namespace mlir::arith {
31 #define GEN_PASS_DEF_ARITHINTNARROWING
32 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
33 } // namespace mlir::arith
34 
35 namespace mlir::arith {
36 namespace {
37 //===----------------------------------------------------------------------===//
38 // Common Helpers
39 //===----------------------------------------------------------------------===//
40 
41 /// The base for integer bitwidth narrowing patterns.
42 template <typename SourceOp>
43 struct NarrowingPattern : OpRewritePattern<SourceOp> {
44  NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
45  PatternBenefit benefit = 1)
46  : OpRewritePattern<SourceOp>(ctx, benefit),
47  supportedBitwidths(options.bitwidthsSupported.begin(),
48  options.bitwidthsSupported.end()) {
49  assert(!supportedBitwidths.empty() && "Invalid options");
50  assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
51  llvm::sort(supportedBitwidths);
52  }
53 
54  FailureOr<unsigned>
55  getNarrowestCompatibleBitwidth(unsigned bitsRequired) const {
56  for (unsigned candidate : supportedBitwidths)
57  if (candidate >= bitsRequired)
58  return candidate;
59 
60  return failure();
61  }
62 
63  /// Returns the narrowest supported type that fits `bitsRequired`.
64  FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const {
65  assert(origTy);
66  FailureOr<unsigned> bestBitwidth =
67  getNarrowestCompatibleBitwidth(bitsRequired);
68  if (failed(bestBitwidth))
69  return failure();
70 
71  Type elemTy = getElementTypeOrSelf(origTy);
72  if (!isa<IntegerType>(elemTy))
73  return failure();
74 
75  auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
76  if (newElemTy == elemTy)
77  return failure();
78 
79  if (origTy == elemTy)
80  return newElemTy;
81 
82  if (auto shapedTy = dyn_cast<ShapedType>(origTy))
83  if (dyn_cast<IntegerType>(shapedTy.getElementType()))
84  return shapedTy.clone(shapedTy.getShape(), newElemTy);
85 
86  return failure();
87  }
88 
89 private:
90  // Supported integer bitwidths in the ascending order.
91  llvm::SmallVector<unsigned, 6> supportedBitwidths;
92 };
93 
94 /// Returns the integer bitwidth required to represent `type`.
95 FailureOr<unsigned> calculateBitsRequired(Type type) {
96  assert(type);
97  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type)))
98  return intTy.getWidth();
99 
100  return failure();
101 }
102 
/// The flavor of integer extension: `arith.extsi` (Sign) or `arith.extui`
/// (Zero).
enum class ExtensionKind { Sign, Zero };
104 
105 /// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
106 /// the exact op type. Exposes helper functions to query the types, operands,
107 /// and the result. This is so that we can handle both extension kinds without
108 /// needing to use templates or branching.
109 class ExtensionOp {
110 public:
111  /// Attemps to create a new extension op from `op`. Returns an extension op
112  /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
113  /// otherwise.
114  static FailureOr<ExtensionOp> from(Operation *op) {
115  if (dyn_cast_or_null<arith::ExtSIOp>(op))
116  return ExtensionOp{op, ExtensionKind::Sign};
117  if (dyn_cast_or_null<arith::ExtUIOp>(op))
118  return ExtensionOp{op, ExtensionKind::Zero};
119 
120  return failure();
121  }
122 
123  ExtensionOp(const ExtensionOp &) = default;
124  ExtensionOp &operator=(const ExtensionOp &) = default;
125 
126  /// Creates a new extension op of the same kind.
127  Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
128  Value in) {
129  if (kind == ExtensionKind::Sign)
130  return rewriter.create<arith::ExtSIOp>(loc, newType, in);
131 
132  return rewriter.create<arith::ExtUIOp>(loc, newType, in);
133  }
134 
135  /// Replaces `toReplace` with a new extension op of the same kind.
136  void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
137  Value in) {
138  assert(toReplace->getNumResults() == 1);
139  Type newType = toReplace->getResult(0).getType();
140  Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
141  rewriter.replaceOp(toReplace, newOp->getResult(0));
142  }
143 
144  ExtensionKind getKind() { return kind; }
145 
146  Value getResult() { return op->getResult(0); }
147  Value getIn() { return op->getOperand(0); }
148 
149  Type getType() { return getResult().getType(); }
151  Type getInType() { return getIn().getType(); }
152  Type getInElementType() { return getElementTypeOrSelf(getInType()); }
153 
154 private:
155  ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
156  assert(op);
157  assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
158  }
159  Operation *op = nullptr;
160  ExtensionKind kind = {};
161 };
162 
163 /// Returns the integer bitwidth required to represent `value`.
164 unsigned calculateBitsRequired(const APInt &value,
165  ExtensionKind lookThroughExtension) {
166  // For unsigned values, we only need the active bits. As a special case, zero
167  // requires one bit.
168  if (lookThroughExtension == ExtensionKind::Zero)
169  return std::max(value.getActiveBits(), 1u);
170 
171  // If a signed value is nonnegative, we need one extra bit for the sign.
172  if (value.isNonNegative())
173  return value.getActiveBits() + 1;
174 
175  // For the signed min, we need all the bits.
176  if (value.isMinSignedValue())
177  return value.getBitWidth();
178 
179  // For negative values, we need all the non-sign bits and one extra bit for
180  // the sign.
181  return value.getBitWidth() - value.getNumSignBits() + 1;
182 }
183 
184 /// Returns the integer bitwidth required to represent `value`.
185 /// Looks through either sign- or zero-extension as specified by
186 /// `lookThroughExtension`.
187 FailureOr<unsigned> calculateBitsRequired(Value value,
188  ExtensionKind lookThroughExtension) {
189  // Handle constants.
190  if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) {
191  if (auto intAttr = dyn_cast<IntegerAttr>(attr))
192  return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);
193 
194  if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
195  if (elemsAttr.getElementType().isIntOrIndex()) {
196  if (elemsAttr.isSplat())
197  return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
198  lookThroughExtension);
199 
200  unsigned maxBits = 1;
201  for (const APInt &elemValue : elemsAttr.getValues<APInt>())
202  maxBits = std::max(
203  maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
204  return maxBits;
205  }
206  }
207  }
208 
209  if (lookThroughExtension == ExtensionKind::Sign) {
210  if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
211  return calculateBitsRequired(sext.getIn().getType());
212  } else if (lookThroughExtension == ExtensionKind::Zero) {
213  if (auto zext = value.getDefiningOp<arith::ExtUIOp>())
214  return calculateBitsRequired(zext.getIn().getType());
215  }
216 
217  // If nothing else worked, return the type requirements for this element type.
218  return calculateBitsRequired(value.getType());
219 }
220 
221 /// Base pattern for arith binary ops.
222 /// Example:
223 /// ```
224 /// %lhs = arith.extsi %a : i8 to i32
225 /// %rhs = arith.extsi %b : i8 to i32
226 /// %r = arith.addi %lhs, %rhs : i32
227 /// ==>
228 /// %lhs = arith.extsi %a : i8 to i16
229 /// %rhs = arith.extsi %b : i8 to i16
230 /// %add = arith.addi %lhs, %rhs : i16
231 /// %r = arith.extsi %add : i16 to i32
232 /// ```
233 template <typename BinaryOp>
234 struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
235  using NarrowingPattern<BinaryOp>::NarrowingPattern;
236 
237  /// Returns the number of bits required to represent the full result, assuming
238  /// that both operands are `operandBits`-wide. Derived classes must implement
239  /// this, taking into account `BinaryOp` semantics.
240  virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;
241 
242  /// Customization point for patterns that should only apply with
243  /// zero/sign-extension ops as arguments.
244  virtual bool isSupported(ExtensionOp) const { return true; }
245 
246  LogicalResult matchAndRewrite(BinaryOp op,
247  PatternRewriter &rewriter) const final {
248  Type origTy = op.getType();
249  FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
250  if (failed(resultBits))
251  return failure();
252 
253  // For the optimization to apply, we expect the lhs to be an extension op,
254  // and for the rhs to either be the same extension op or a constant.
255  FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
256  if (failed(ext) || !isSupported(*ext))
257  return failure();
258 
259  FailureOr<unsigned> lhsBitsRequired =
260  calculateBitsRequired(ext->getIn(), ext->getKind());
261  if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
262  return failure();
263 
264  FailureOr<unsigned> rhsBitsRequired =
265  calculateBitsRequired(op.getRhs(), ext->getKind());
266  if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
267  return failure();
268 
269  // Negotiate a common bit requirements for both lhs and rhs, accounting for
270  // the result requiring more bits than the operands.
271  unsigned commonBitsRequired =
272  getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired));
273  FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
274  if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
275  return failure();
276 
277  Location loc = op.getLoc();
278  Value newLhs =
279  rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
280  Value newRhs =
281  rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
282  Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
283  ext->recreateAndReplace(rewriter, op, newAdd);
284  return success();
285  }
286 };
287 
288 //===----------------------------------------------------------------------===//
289 // AddIOp Pattern
290 //===----------------------------------------------------------------------===//
291 
292 struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
293  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
294 
295  // Addition may require one extra bit for the result.
296  // Example: `UINT8_MAX + 1 == 255 + 1 == 256`.
297  unsigned getResultBitsProduced(unsigned operandBits) const override {
298  return operandBits + 1;
299  }
300 };
301 
302 //===----------------------------------------------------------------------===//
303 // SubIOp Pattern
304 //===----------------------------------------------------------------------===//
305 
306 struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
307  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
308 
309  // This optimization only applies to signed arguments.
310  bool isSupported(ExtensionOp ext) const override {
311  return ext.getKind() == ExtensionKind::Sign;
312  }
313 
314  // Subtraction may require one extra bit for the result.
315  // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`.
316  unsigned getResultBitsProduced(unsigned operandBits) const override {
317  return operandBits + 1;
318  }
319 };
320 
321 //===----------------------------------------------------------------------===//
322 // MulIOp Pattern
323 //===----------------------------------------------------------------------===//
324 
325 struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
326  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
327 
328  // Multiplication may require up double the operand bits.
329  // Example: `UNT8_MAX * UINT8_MAX == 255 * 255 == 65025`.
330  unsigned getResultBitsProduced(unsigned operandBits) const override {
331  return 2 * operandBits;
332  }
333 };
334 
335 //===----------------------------------------------------------------------===//
336 // DivSIOp Pattern
337 //===----------------------------------------------------------------------===//
338 
339 struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
340  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
341 
342  // This optimization only applies to signed arguments.
343  bool isSupported(ExtensionOp ext) const override {
344  return ext.getKind() == ExtensionKind::Sign;
345  }
346 
347  // Unlike multiplication, signed division requires only one more result bit.
348  // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`.
349  unsigned getResultBitsProduced(unsigned operandBits) const override {
350  return operandBits + 1;
351  }
352 };
353 
354 //===----------------------------------------------------------------------===//
355 // DivUIOp Pattern
356 //===----------------------------------------------------------------------===//
357 
358 struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
359  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
360 
361  // This optimization only applies to unsigned arguments.
362  bool isSupported(ExtensionOp ext) const override {
363  return ext.getKind() == ExtensionKind::Zero;
364  }
365 
366  // Unsigned division does not require any extra result bits.
367  unsigned getResultBitsProduced(unsigned operandBits) const override {
368  return operandBits;
369  }
370 };
371 
372 //===----------------------------------------------------------------------===//
373 // Min/Max Patterns
374 //===----------------------------------------------------------------------===//
375 
376 template <typename MinMaxOp, ExtensionKind Kind>
377 struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
378  using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;
379 
380  bool isSupported(ExtensionOp ext) const override {
381  return ext.getKind() == Kind;
382  }
383 
384  // Min/max returns one of the arguments and does not require any extra result
385  // bits.
386  unsigned getResultBitsProduced(unsigned operandBits) const override {
387  return operandBits;
388  }
389 };
390 using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
391 using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
392 using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
393 using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;
394 
395 //===----------------------------------------------------------------------===//
396 // *IToFPOp Patterns
397 //===----------------------------------------------------------------------===//
398 
399 template <typename IToFPOp, ExtensionKind Extension>
400 struct IToFPPattern final : NarrowingPattern<IToFPOp> {
401  using NarrowingPattern<IToFPOp>::NarrowingPattern;
402 
403  LogicalResult matchAndRewrite(IToFPOp op,
404  PatternRewriter &rewriter) const override {
405  FailureOr<unsigned> narrowestWidth =
406  calculateBitsRequired(op.getIn(), Extension);
407  if (failed(narrowestWidth))
408  return failure();
409 
410  FailureOr<Type> narrowTy =
411  this->getNarrowType(*narrowestWidth, op.getIn().getType());
412  if (failed(narrowTy))
413  return failure();
414 
415  Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy,
416  op.getIn());
417  rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
418  return success();
419  }
420 };
421 using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
422 using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
423 
424 //===----------------------------------------------------------------------===//
425 // Index Cast Patterns
426 //===----------------------------------------------------------------------===//
427 
428 // These rely on the `ValueBounds` interface for index values. For example, we
429 // can often statically tell index value bounds of loop induction variables.
430 
431 template <typename CastOp, ExtensionKind Kind>
432 struct IndexCastPattern final : NarrowingPattern<CastOp> {
433  using NarrowingPattern<CastOp>::NarrowingPattern;
434 
435  LogicalResult matchAndRewrite(CastOp op,
436  PatternRewriter &rewriter) const override {
437  Value in = op.getIn();
438  // We only support scalar index -> integer casts.
439  if (!isa<IndexType>(in.getType()))
440  return failure();
441 
442  // Check the lower bound in both the signed and unsigned cast case. We
443  // conservatively assume that even unsigned casts may be performed on
444  // negative indices.
445  FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound(
447  if (failed(lb))
448  return failure();
449 
450  FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
452  /*stopCondition=*/nullptr, /*closedUB=*/true);
453  if (failed(ub))
454  return failure();
455 
456  assert(*lb <= *ub && "Invalid bounds");
457  unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind);
458  unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind);
459  unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired);
460 
461  IntegerType resultTy = cast<IntegerType>(op.getType());
462  if (resultTy.getWidth() <= bitsRequired)
463  return failure();
464 
465  FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
466  if (failed(narrowTy))
467  return failure();
468 
469  Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());
470 
471  if (Kind == ExtensionKind::Sign)
472  rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
473  else
474  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
475  return success();
476  }
477 };
478 using IndexCastSIPattern =
479  IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
480 using IndexCastUIPattern =
481  IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
482 
483 //===----------------------------------------------------------------------===//
484 // Patterns to Commute Extension Ops
485 //===----------------------------------------------------------------------===//
486 
487 struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
488  using NarrowingPattern::NarrowingPattern;
489 
490  LogicalResult matchAndRewrite(vector::BroadcastOp op,
491  PatternRewriter &rewriter) const override {
492  FailureOr<ExtensionOp> ext =
493  ExtensionOp::from(op.getSource().getDefiningOp());
494  if (failed(ext))
495  return failure();
496 
497  VectorType origTy = op.getResultVectorType();
498  VectorType newTy =
499  origTy.cloneWith(origTy.getShape(), ext->getInElementType());
500  Value newBroadcast =
501  rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
502  ext->recreateAndReplace(rewriter, op, newBroadcast);
503  return success();
504  }
505 };
506 
507 struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
508  using NarrowingPattern::NarrowingPattern;
509 
510  LogicalResult matchAndRewrite(vector::ExtractOp op,
511  PatternRewriter &rewriter) const override {
512  FailureOr<ExtensionOp> ext =
513  ExtensionOp::from(op.getVector().getDefiningOp());
514  if (failed(ext))
515  return failure();
516 
517  Value newExtract = rewriter.create<vector::ExtractOp>(
518  op.getLoc(), ext->getIn(), op.getMixedPosition());
519  ext->recreateAndReplace(rewriter, op, newExtract);
520  return success();
521  }
522 };
523 
524 struct ExtensionOverExtractElement final
525  : NarrowingPattern<vector::ExtractElementOp> {
526  using NarrowingPattern::NarrowingPattern;
527 
528  LogicalResult matchAndRewrite(vector::ExtractElementOp op,
529  PatternRewriter &rewriter) const override {
530  FailureOr<ExtensionOp> ext =
531  ExtensionOp::from(op.getVector().getDefiningOp());
532  if (failed(ext))
533  return failure();
534 
535  Value newExtract = rewriter.create<vector::ExtractElementOp>(
536  op.getLoc(), ext->getIn(), op.getPosition());
537  ext->recreateAndReplace(rewriter, op, newExtract);
538  return success();
539  }
540 };
541 
542 struct ExtensionOverExtractStridedSlice final
543  : NarrowingPattern<vector::ExtractStridedSliceOp> {
544  using NarrowingPattern::NarrowingPattern;
545 
546  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
547  PatternRewriter &rewriter) const override {
548  FailureOr<ExtensionOp> ext =
549  ExtensionOp::from(op.getVector().getDefiningOp());
550  if (failed(ext))
551  return failure();
552 
553  VectorType origTy = op.getType();
554  VectorType extractTy =
555  origTy.cloneWith(origTy.getShape(), ext->getInElementType());
556  Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
557  op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
558  op.getStrides());
559  ext->recreateAndReplace(rewriter, op, newExtract);
560  return success();
561  }
562 };
563 
564 /// Base pattern for `vector.insert` narrowing patterns.
565 template <typename InsertionOp>
566 struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
567  using NarrowingPattern<InsertionOp>::NarrowingPattern;
568 
569  /// Derived classes must provide a function to create the matching insertion
570  /// op based on the original op and new arguments.
571  virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
572  InsertionOp origInsert,
573  Value narrowValue,
574  Value narrowDest) const = 0;
575 
576  LogicalResult matchAndRewrite(InsertionOp op,
577  PatternRewriter &rewriter) const final {
578  FailureOr<ExtensionOp> ext =
579  ExtensionOp::from(op.getSource().getDefiningOp());
580  if (failed(ext))
581  return failure();
582 
583  FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
584  if (failed(newInsert))
585  return failure();
586  ext->recreateAndReplace(rewriter, op, *newInsert);
587  return success();
588  }
589 
590  FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
591  PatternRewriter &rewriter,
592  ExtensionOp insValue) const {
593  // Calculate the operand and result bitwidths. We can only apply narrowing
594  // when the inserted source value and destination vector require fewer bits
595  // than the result. Because the source and destination may have different
596  // bitwidths requirements, we have to find the common narrow bitwidth that
597  // is greater equal to the operand bitwidth requirements and still narrower
598  // than the result.
599  FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
600  if (failed(origBitsRequired))
601  return failure();
602 
603  // TODO: We could relax this check by disregarding bitwidth requirements of
604  // elements that we know will be replaced by the insertion.
605  FailureOr<unsigned> destBitsRequired =
606  calculateBitsRequired(op.getDest(), insValue.getKind());
607  if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
608  return failure();
609 
610  FailureOr<unsigned> insertedBitsRequired =
611  calculateBitsRequired(insValue.getIn(), insValue.getKind());
612  if (failed(insertedBitsRequired) ||
613  *insertedBitsRequired >= *origBitsRequired)
614  return failure();
615 
616  // Find a narrower element type that satisfies the bitwidth requirements of
617  // both the source and the destination values.
618  unsigned newInsertionBits =
619  std::max(*destBitsRequired, *insertedBitsRequired);
620  FailureOr<Type> newVecTy =
621  this->getNarrowType(newInsertionBits, op.getType());
622  if (failed(newVecTy) || *newVecTy == op.getType())
623  return failure();
624 
625  FailureOr<Type> newInsertedValueTy =
626  this->getNarrowType(newInsertionBits, insValue.getType());
627  if (failed(newInsertedValueTy))
628  return failure();
629 
630  Location loc = op.getLoc();
631  Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
632  loc, *newInsertedValueTy, insValue.getResult());
633  Value narrowDest =
634  rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
635  return createInsertionOp(rewriter, op, narrowValue, narrowDest);
636  }
637 };
638 
639 struct ExtensionOverInsert final
640  : ExtensionOverInsertionPattern<vector::InsertOp> {
641  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
642 
643  vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
644  vector::InsertOp origInsert,
645  Value narrowValue,
646  Value narrowDest) const override {
647  return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
648  narrowDest,
649  origInsert.getMixedPosition());
650  }
651 };
652 
653 struct ExtensionOverInsertElement final
654  : ExtensionOverInsertionPattern<vector::InsertElementOp> {
655  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
656 
657  vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
658  vector::InsertElementOp origInsert,
659  Value narrowValue,
660  Value narrowDest) const override {
661  return rewriter.create<vector::InsertElementOp>(
662  origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
663  }
664 };
665 
666 struct ExtensionOverInsertStridedSlice final
667  : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
668  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
669 
670  vector::InsertStridedSliceOp
671  createInsertionOp(PatternRewriter &rewriter,
672  vector::InsertStridedSliceOp origInsert, Value narrowValue,
673  Value narrowDest) const override {
674  return rewriter.create<vector::InsertStridedSliceOp>(
675  origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
676  origInsert.getStrides());
677  }
678 };
679 
680 struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
681  using NarrowingPattern::NarrowingPattern;
682 
683  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
684  PatternRewriter &rewriter) const override {
685  FailureOr<ExtensionOp> ext =
686  ExtensionOp::from(op.getSource().getDefiningOp());
687  if (failed(ext))
688  return failure();
689 
690  VectorType origTy = op.getResultVectorType();
691  VectorType newTy =
692  origTy.cloneWith(origTy.getShape(), ext->getInElementType());
693  Value newCast =
694  rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
695  ext->recreateAndReplace(rewriter, op, newCast);
696  return success();
697  }
698 };
699 
700 struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
701  using NarrowingPattern::NarrowingPattern;
702 
703  LogicalResult matchAndRewrite(vector::TransposeOp op,
704  PatternRewriter &rewriter) const override {
705  FailureOr<ExtensionOp> ext =
706  ExtensionOp::from(op.getVector().getDefiningOp());
707  if (failed(ext))
708  return failure();
709 
710  VectorType origTy = op.getResultVectorType();
711  VectorType newTy =
712  origTy.cloneWith(origTy.getShape(), ext->getInElementType());
713  Value newTranspose = rewriter.create<vector::TransposeOp>(
714  op.getLoc(), newTy, ext->getIn(), op.getPermutation());
715  ext->recreateAndReplace(rewriter, op, newTranspose);
716  return success();
717  }
718 };
719 
720 struct ExtensionOverFlatTranspose final
721  : NarrowingPattern<vector::FlatTransposeOp> {
722  using NarrowingPattern::NarrowingPattern;
723 
724  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
725  PatternRewriter &rewriter) const override {
726  FailureOr<ExtensionOp> ext =
727  ExtensionOp::from(op.getMatrix().getDefiningOp());
728  if (failed(ext))
729  return failure();
730 
731  VectorType origTy = op.getType();
732  VectorType newTy =
733  origTy.cloneWith(origTy.getShape(), ext->getInElementType());
734  Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
735  op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
736  op.getColumnsAttr());
737  ext->recreateAndReplace(rewriter, op, newTranspose);
738  return success();
739  }
740 };
741 
742 //===----------------------------------------------------------------------===//
743 // Pass Definitions
744 //===----------------------------------------------------------------------===//
745 
746 struct ArithIntNarrowingPass final
747  : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
748  using ArithIntNarrowingBase::ArithIntNarrowingBase;
749 
750  void runOnOperation() override {
751  if (bitwidthsSupported.empty() ||
752  llvm::is_contained(bitwidthsSupported, 0)) {
753  // Invalid pass options.
754  return signalPassFailure();
755  }
756 
757  Operation *op = getOperation();
758  MLIRContext *ctx = op->getContext();
759  RewritePatternSet patterns(ctx);
761  patterns, ArithIntNarrowingOptions{
762  llvm::to_vector_of<unsigned>(bitwidthsSupported)});
763  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
764  signalPassFailure();
765  }
766 };
767 } // namespace
768 
769 //===----------------------------------------------------------------------===//
770 // Public API
771 //===----------------------------------------------------------------------===//
772 
774  RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
775  // Add commute patterns with a higher benefit. This is to expose more
776  // optimization opportunities to narrowing patterns.
777  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
778  ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
779  ExtensionOverInsert, ExtensionOverInsertElement,
780  ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
781  ExtensionOverTranspose, ExtensionOverFlatTranspose>(
782  patterns.getContext(), options, PatternBenefit(2));
783 
784  patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
785  DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
786  MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
787  IndexCastUIPattern>(patterns.getContext(), options);
788 }
789 
790 } // namespace mlir::arith
static uint64_t zext(uint32_t arg)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
void populateArithIntNarrowingPatterns(RewritePatternSet &patterns, const ArithIntNarrowingOptions &options)
Add patterns for integer bitwidth narrowing.
@ Type
An inlay hint that for a type annotation.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369