MLIR  22.0.0git
VectorMaskElimination.cpp
Go to the documentation of this file.
1 //===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 
14 using namespace mlir;
15 using namespace mlir::vector;
16 namespace {
17 
18 /// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
19 /// All-true masks can then be eliminated by simple folds.
20 LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
21  vector::CreateMaskOp createMaskOp,
22  VscaleRange vscaleRange) {
23  auto maskType = createMaskOp.getVectorType();
24  auto maskTypeDimScalableFlags = maskType.getScalableDims();
25  auto maskTypeDimSizes = maskType.getShape();
26 
27  struct UnknownMaskDim {
28  size_t position;
29  Value dimSize;
30  };
31 
32  // Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
33  // that are not obviously constant). If any constant dimension is not all-true
34  // bail out early (as this transform only trying to resolve all-true masks).
35  // This avoids doing value-bounds anaylis in cases like:
36  // `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
37  // ...where it is known the mask is not all-true by looking at `%c2`.
38  SmallVector<UnknownMaskDim> unknownDims;
39  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
40  if (auto intSize = getConstantIntValue(dimSize)) {
41  // Mask not all-true for this dim.
42  if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
43  return failure();
44  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
45  // Mask not all-true for this dim.
46  if (vscaleMultiplier < maskTypeDimSizes[i])
47  return failure();
48  } else {
49  // Unknown (without further analysis).
50  unknownDims.push_back(UnknownMaskDim{i, dimSize});
51  }
52  }
53 
54  for (auto [i, dimSize] : unknownDims) {
55  // Compute the lower bound for the unknown dimension (i.e. the smallest
56  // value it could be).
57  FailureOr<ConstantOrScalableBound> dimLowerBound =
59  dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
61  if (failed(dimLowerBound))
62  return failure();
63  auto dimLowerBoundSize = dimLowerBound->getSize();
64  if (failed(dimLowerBoundSize))
65  return failure();
66  if (dimLowerBoundSize->scalable) {
67  // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
68  // this dim is not all-true.
69  if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
70  return failure();
71  } else {
72  // 2. The lower bound, LB, is a constant.
73  // - If the mask dim size is scalable then this dim is not all-true.
74  if (maskTypeDimScalableFlags[i])
75  return failure();
76  // - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
77  if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
78  return failure();
79  }
80  }
81 
82  // Replace createMaskOp with an all-true constant. This should result in the
83  // mask being removed in most cases (as xfer ops + vector.mask have folds to
84  // remove all-true masks).
85  auto allTrue = vector::ConstantMaskOp::create(
86  rewriter, createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
87  rewriter.replaceAllUsesWith(createMaskOp, allTrue);
88  return success();
89 }
90 
91 } // namespace
92 
93 namespace mlir::vector {
94 
95 void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
96  std::optional<VscaleRange> vscaleRange) {
97  // TODO: Support fixed-size case. This is less likely to be useful as for
98  // fixed-size code dimensions are all static so masks tend to fold away.
99  if (!vscaleRange)
100  return;
101 
102  OpBuilder::InsertionGuard g(rewriter);
103 
104  // Build worklist so we can safely insert new ops in
105  // `resolveAllTrueCreateMaskOp()`.
107  function.walk([&](vector::CreateMaskOp createMaskOp) {
108  worklist.push_back(createMaskOp);
109  });
110 
111  rewriter.setInsertionPointToStart(&function.front());
112  for (auto mask : worklist)
113  (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
114 }
115 
116 } // namespace mlir::vector
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:764
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If `value` is a constant multiple of `vector.vscale` (e.g. `%cst * vector.vscale`), return the multiplier (`%cst`); otherwise return `std::nullopt`.
Definition: VectorOps.cpp:384
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function, std::optional< VscaleRange > vscaleRange={})
Attempts to eliminate redundant vector masks by replacing them with all-true constants at the top of the function (which should let the masks fold away).
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static FailureOr< ConstantOrScalableBound > computeScalableBound(Value value, std::optional< int64_t > dim, unsigned vscaleMin, unsigned vscaleMax, presburger::BoundType boundType, bool closedUB=true, StopConditionFn stopCondition=nullptr)
Computes a (possibly) scalable bound for a given value.