MLIR  20.0.0git
VectorMaskElimination.cpp
Go to the documentation of this file.
1 //===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 
16 using namespace mlir;
17 using namespace mlir::vector;
18 namespace {
19 
20 /// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
21 /// All-true masks can then be eliminated by simple folds.
22 LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
23  vector::CreateMaskOp createMaskOp,
24  VscaleRange vscaleRange) {
25  auto maskType = createMaskOp.getVectorType();
26  auto maskTypeDimScalableFlags = maskType.getScalableDims();
27  auto maskTypeDimSizes = maskType.getShape();
28 
29  struct UnknownMaskDim {
30  size_t position;
31  Value dimSize;
32  };
33 
34  // Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
35  // that are not obviously constant). If any constant dimension is not all-true
36  // bail out early (as this transform only trying to resolve all-true masks).
37  // This avoids doing value-bounds anaylis in cases like:
38  // `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
39  // ...where it is known the mask is not all-true by looking at `%c2`.
40  SmallVector<UnknownMaskDim> unknownDims;
41  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
42  if (auto intSize = getConstantIntValue(dimSize)) {
43  // Mask not all-true for this dim.
44  if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
45  return failure();
46  } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
47  // Mask not all-true for this dim.
48  if (vscaleMultiplier < maskTypeDimSizes[i])
49  return failure();
50  } else {
51  // Unknown (without further analysis).
52  unknownDims.push_back(UnknownMaskDim{i, dimSize});
53  }
54  }
55 
56  for (auto [i, dimSize] : unknownDims) {
57  // Compute the lower bound for the unknown dimension (i.e. the smallest
58  // value it could be).
59  FailureOr<ConstantOrScalableBound> dimLowerBound =
61  dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
63  if (failed(dimLowerBound))
64  return failure();
65  auto dimLowerBoundSize = dimLowerBound->getSize();
66  if (failed(dimLowerBoundSize))
67  return failure();
68  if (dimLowerBoundSize->scalable) {
69  // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
70  // this dim is not all-true.
71  if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
72  return failure();
73  } else {
74  // 2. The lower bound, LB, is a constant.
75  // - If the mask dim size is scalable then this dim is not all-true.
76  if (maskTypeDimScalableFlags[i])
77  return failure();
78  // - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
79  if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
80  return failure();
81  }
82  }
83 
84  // Replace createMaskOp with an all-true constant. This should result in the
85  // mask being removed in most cases (as xfer ops + vector.mask have folds to
86  // remove all-true masks).
87  auto allTrue = rewriter.create<vector::ConstantMaskOp>(
88  createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
89  rewriter.replaceAllUsesWith(createMaskOp, allTrue);
90  return success();
91 }
92 
93 } // namespace
94 
95 namespace mlir::vector {
96 
97 void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
98  std::optional<VscaleRange> vscaleRange) {
99  // TODO: Support fixed-size case. This is less likely to be useful as for
100  // fixed-size code dimensions are all static so masks tend to fold away.
101  if (!vscaleRange)
102  return;
103 
104  OpBuilder::InsertionGuard g(rewriter);
105 
106  // Build worklist so we can safely insert new ops in
107  // `resolveAllTrueCreateMaskOp()`.
109  function.walk([&](vector::CreateMaskOp createMaskOp) {
110  worklist.push_back(createMaskOp);
111  });
112 
113  rewriter.setInsertionPointToStart(&function.front());
114  for (auto mask : worklist)
115  (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
116 }
117 
118 } // namespace mlir::vector
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
Definition: VectorOps.cpp:354
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function, std::optional< VscaleRange > vscaleRange={})
Attempts to eliminate redundant vector masks by replacing them with all-true constants at the top of ...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static FailureOr< ConstantOrScalableBound > computeScalableBound(Value value, std::optional< int64_t > dim, unsigned vscaleMin, unsigned vscaleMax, presburger::BoundType boundType, bool closedUB=true, StopConditionFn stopCondition=nullptr)
Computes a (possibly) scalable bound for a given value.