MLIR 22.0.0git
VectorMaskElimination.cpp
Go to the documentation of this file.
//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

13
14using namespace mlir;
15using namespace mlir::vector;
16namespace {
17
18/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
19/// All-true masks can then be eliminated by simple folds.
20LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
21 vector::CreateMaskOp createMaskOp,
22 VscaleRange vscaleRange) {
23 auto maskType = createMaskOp.getVectorType();
24 auto maskTypeDimScalableFlags = maskType.getScalableDims();
25 auto maskTypeDimSizes = maskType.getShape();
26
27 struct UnknownMaskDim {
28 size_t position;
29 Value dimSize;
30 };
31
32 // Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
33 // that are not obviously constant). If any constant dimension is not all-true
34 // bail out early (as this transform only trying to resolve all-true masks).
35 // This avoids doing value-bounds anaylis in cases like:
36 // `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
37 // ...where it is known the mask is not all-true by looking at `%c2`.
39 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
40 if (auto intSize = getConstantIntValue(dimSize)) {
41 // Mask not all-true for this dim.
42 if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
43 return failure();
44 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
45 // Mask not all-true for this dim.
46 if (vscaleMultiplier < maskTypeDimSizes[i])
47 return failure();
48 } else {
49 // Unknown (without further analysis).
50 unknownDims.push_back(UnknownMaskDim{i, dimSize});
51 }
52 }
53
54 for (auto [i, dimSize] : unknownDims) {
55 // Compute the lower bound for the unknown dimension (i.e. the smallest
56 // value it could be).
57 FailureOr<ConstantOrScalableBound> dimLowerBound =
59 dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
61 if (failed(dimLowerBound))
62 return failure();
63 auto dimLowerBoundSize = dimLowerBound->getSize();
64 if (failed(dimLowerBoundSize))
65 return failure();
66 if (dimLowerBoundSize->scalable) {
67 // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
68 // this dim is not all-true.
69 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
70 return failure();
71 } else {
72 // 2. The lower bound, LB, is a constant.
73 // - If the mask dim size is scalable then this dim is not all-true.
74 if (maskTypeDimScalableFlags[i])
75 return failure();
76 // - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
77 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
78 return failure();
79 }
80 }
81
82 // Replace createMaskOp with an all-true constant. This should result in the
83 // mask being removed in most cases (as xfer ops + vector.mask have folds to
84 // remove all-true masks).
85 auto allTrue = vector::ConstantMaskOp::create(
86 rewriter, createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
87 rewriter.replaceAllUsesWith(createMaskOp, allTrue);
88 return success();
89}
90
91} // namespace
92
93namespace mlir::vector {
94
95void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
96 std::optional<VscaleRange> vscaleRange) {
97 // TODO: Support fixed-size case. This is less likely to be useful as for
98 // fixed-size code dimensions are all static so masks tend to fold away.
99 if (!vscaleRange)
100 return;
101
102 OpBuilder::InsertionGuard g(rewriter);
103
104 // Build worklist so we can safely insert new ops in
105 // `resolveAllTrueCreateMaskOp()`.
107 function.walk([&](vector::CreateMaskOp createMaskOp) {
108 worklist.push_back(createMaskOp);
109 });
110
111 rewriter.setInsertionPointToStart(&function.front());
112 for (auto mask : worklist)
113 (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
114}
115
116} // namespace mlir::vector
return success()
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function, std::optional< VscaleRange > vscaleRange={})
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fastpath and a s...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static FailureOr< ConstantOrScalableBound > computeScalableBound(Value value, std::optional< int64_t > dim, unsigned vscaleMin, unsigned vscaleMax, presburger::BoundType boundType, bool closedUB=true, const StopConditionFn &stopCondition=nullptr)
Computes a (possibly) scalable bound for a given value.