MLIR 22.0.0git
LowerVectorMask.cpp
Go to the documentation of this file.
1//===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.mask' operation.
11//
12//===----------------------------------------------------------------------===//
13
21
22#define DEBUG_TYPE "lower-vector-mask"
23
24namespace mlir {
25namespace vector {
26#define GEN_PASS_DEF_LOWERVECTORMASKPASS
27#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
28} // namespace vector
29} // namespace mlir
31using namespace mlir;
32using namespace mlir::vector;
33
34//===----------------------------------------------------------------------===//
35// populateVectorMaskOpLoweringPatterns
36//===----------------------------------------------------------------------===//
38namespace {
39/// Progressive lowering of CreateMaskOp.
40/// One:
41/// %x = vector.create_mask %a, ... : vector<dx...>
42/// is replaced by:
43/// %l = vector.create_mask ... : vector<...> ; one lower rank
44/// %0 = arith.cmpi "slt", %ci, %a |
45/// %1 = select %0, %l, %zeroes |
46/// %r = vector.insert %1, %pr [i] | d-times
47/// %x = ....
48/// until a one-dimensional vector is reached.
49class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
50public:
51 using Base::Base;
52
53 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
54 PatternRewriter &rewriter) const override {
55 auto dstType = cast<VectorType>(op.getResult().getType());
56 int64_t rank = dstType.getRank();
57 if (rank <= 1)
58 return rewriter.notifyMatchFailure(
59 op, "0-D and 1-D vectors are handled separately");
60
61 if (dstType.getScalableDims().front())
62 return rewriter.notifyMatchFailure(
63 op, "Cannot unroll leading scalable dim in dstType");
64
65 auto loc = op.getLoc();
66 int64_t dim = dstType.getDimSize(0);
67 Value idx = op.getOperand(0);
68
69 VectorType lowType = VectorType::Builder(dstType).dropDim(0);
70 Value trueVal = vector::CreateMaskOp::create(rewriter, loc, lowType,
71 op.getOperands().drop_front());
72 Value falseVal = arith::ConstantOp::create(rewriter, loc, lowType,
73 rewriter.getZeroAttr(lowType));
74 Value result = arith::ConstantOp::create(rewriter, loc, dstType,
75 rewriter.getZeroAttr(dstType));
76 for (int64_t d = 0; d < dim; d++) {
77 Value bnd =
78 arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(d));
79 Value val = arith::CmpIOp::create(rewriter, loc,
80 arith::CmpIPredicate::slt, bnd, idx);
81 Value sel =
82 arith::SelectOp::create(rewriter, loc, val, trueVal, falseVal);
83 result = vector::InsertOp::create(rewriter, loc, sel, result, d);
84 }
85 rewriter.replaceOp(op, result);
86 return success();
87 }
88};
89
90/// Progressive lowering of ConstantMaskOp.
91/// One:
92/// %x = vector.constant_mask [a,b]
93/// is replaced by:
94/// %z = zero-result
95/// %l = vector.constant_mask [b]
96/// %4 = vector.insert %l, %z[0]
97/// ..
98/// %x = vector.insert %l, %..[a-1]
99/// until a one-dimensional vector is reached. All these operations
100/// will be folded at LLVM IR level.
101class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
102public:
103 using Base::Base;
104
105 LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
106 PatternRewriter &rewriter) const override {
107 auto loc = op.getLoc();
108 auto dstType = op.getType();
109 auto dimSizes = op.getMaskDimSizes();
110 int64_t rank = dstType.getRank();
111
112 if (rank == 0) {
113 assert(dimSizes.size() == 1 &&
114 "Expected exactly one dim size for a 0-D vector");
115 bool value = dimSizes.front() == 1;
116 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
117 op, dstType,
118 DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
119 value));
120 return success();
121 }
122
123 int64_t trueDimSize = dimSizes.front();
124
125 if (rank == 1) {
126 if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
127 // Use constant splat for 'all set' or 'none set' dims.
128 // This produces correct code for scalable dimensions (it will lower to
129 // a constant splat).
130 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
131 op, DenseElementsAttr::get(dstType, trueDimSize != 0));
132 } else {
133 // Express constant 1-D case in explicit vector form:
134 // [T,..,T,F,..,F].
135 // Note: The verifier would reject this case for scalable vectors.
136 SmallVector<bool> values(dstType.getDimSize(0), false);
137 for (int64_t d = 0; d < trueDimSize; d++)
138 values[d] = true;
139 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
140 op, dstType, rewriter.getBoolVectorAttr(values));
141 }
142 return success();
143 }
144
145 if (dstType.getScalableDims().front())
146 return rewriter.notifyMatchFailure(
147 op, "Cannot unroll leading scalable dim in dstType");
148
149 VectorType lowType = VectorType::Builder(dstType).dropDim(0);
150 Value trueVal = vector::ConstantMaskOp::create(rewriter, loc, lowType,
151 dimSizes.drop_front());
152 Value result = arith::ConstantOp::create(rewriter, loc, dstType,
153 rewriter.getZeroAttr(dstType));
154 for (int64_t d = 0; d < trueDimSize; d++)
155 result = vector::InsertOp::create(rewriter, loc, trueVal, result, d);
156
157 rewriter.replaceOp(op, result);
158 return success();
159 }
160};
161} // namespace
162
165 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
166 patterns.getContext(), benefit);
167}
168
169//===----------------------------------------------------------------------===//
170// populateVectorMaskLoweringPatternsForSideEffectingOps
171//===----------------------------------------------------------------------===//
172
173namespace {
174
175/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
176/// matching:
177/// 1. It matches a `vector.mask` operation.
178/// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested
179/// in the matched `vector.mask` operation.
180///
181/// It is required that the replacement op in the pattern replaces the
182/// `vector.mask` operation and not the nested `MaskableOpInterface`. This
183/// approach allows having patterns that "stop" at every `vector.mask` operation
184/// and actually match the traits of its the nested `MaskableOpInterface`.
185template <class SourceOp>
186struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
187 using Base::Base;
188
189private:
190 LogicalResult matchAndRewrite(MaskOp maskOp,
191 PatternRewriter &rewriter) const final {
192 auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
193 if (!maskableOp)
194 return failure();
195 SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
196 if (!sourceOp)
197 return failure();
198
199 return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
200 }
201
202protected:
203 virtual LogicalResult
204 matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
205 PatternRewriter &rewriter) const = 0;
206};
207
208/// Lowers a masked `vector.transfer_read` operation.
209struct MaskedTransferReadOpPattern
210 : public MaskOpRewritePattern<TransferReadOp> {
211public:
212 using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
213
214 LogicalResult
215 matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
216 PatternRewriter &rewriter) const override {
217 // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
218 // expects a scalar. We could only lower one to the other for cases where
219 // the passthru is a broadcast of a scalar.
220 if (maskingOp.hasPassthru())
221 return rewriter.notifyMatchFailure(
222 maskingOp, "Can't lower passthru to vector.transfer_read");
223
224 // Replace the `vector.mask` operation.
225 rewriter.replaceOpWithNewOp<TransferReadOp>(
226 maskingOp.getOperation(), readOp.getVectorType(), readOp.getBase(),
227 readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
228 maskingOp.getMask(), readOp.getInBounds());
229 return success();
230 }
231};
232
233/// Lowers a masked `vector.transfer_write` operation.
234struct MaskedTransferWriteOpPattern
235 : public MaskOpRewritePattern<TransferWriteOp> {
236public:
237 using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
238
239 LogicalResult
240 matchAndRewriteMaskableOp(TransferWriteOp writeOp,
241 MaskingOpInterface maskingOp,
242 PatternRewriter &rewriter) const override {
243 Type resultType =
244 writeOp.getResult() ? writeOp.getResult().getType() : Type();
245
246 // Replace the `vector.mask` operation.
247 rewriter.replaceOpWithNewOp<TransferWriteOp>(
248 maskingOp.getOperation(), resultType, writeOp.getVector(),
249 writeOp.getBase(), writeOp.getIndices(), writeOp.getPermutationMap(),
250 maskingOp.getMask(), writeOp.getInBounds());
251 return success();
252 }
253};
254
255/// Lowers a masked `vector.gather` operation.
256struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
257public:
258 using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
259
260 LogicalResult
261 matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
262 PatternRewriter &rewriter) const override {
263 Value passthru = maskingOp.hasPassthru()
264 ? maskingOp.getPassthru()
265 : arith::ConstantOp::create(
266 rewriter, gatherOp.getLoc(),
267 rewriter.getZeroAttr(gatherOp.getVectorType()));
268
269 // Replace the `vector.mask` operation.
270 rewriter.replaceOpWithNewOp<GatherOp>(
271 maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
272 gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
273 passthru);
274 return success();
275 }
276};
277
/// Pass that greedily applies a set of `vector.mask` rewrite patterns to the
/// operation it runs on.
278struct LowerVectorMaskPass
279 : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
280 using Base::Base;
281
 /// Builds the pattern set and applies it greedily to the current operation;
 /// signals pass failure if the greedy application reports failure.
282 void runOnOperation() override {
283 Operation *op = getOperation();
284 MLIRContext *context = op->getContext();
285
286 RewritePatternSet loweringPatterns(context);
 // NOTE(review): this listing jumps from original line 286 to 288, so a
 // statement populating `loweringPatterns` may be missing here - confirm
 // against the upstream file.
 // Add the canonicalization patterns of `vector.mask` to the set.
288 MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
289
290 if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
291 signalPassFailure();
292 }
293
 /// Registers the vector dialect as a dependent dialect of this pass.
294 void getDependentDialects(DialectRegistry &registry) const override {
295 registry.insert<vector::VectorDialect>();
296 }
297};
298
299} // namespace
300
301/// Populates instances of `MaskOpRewritePattern` to lower masked operations
302/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
303/// not its nested `MaskableOpInterface`.
306 patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
307 MaskedGatherOpPattern>(patterns.getContext());
308}
309
311 return std::make_unique<LowerVectorMaskPass>();
312}
return success()
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
DenseIntElementsAttr getBoolVectorAttr(ArrayRef< bool > values)
Vector-typed DenseIntElementsAttr getters. values must not be empty.
Definition Builders.cpp:116
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
IntegerType getI1Type()
Definition Builders.cpp:53
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
std::unique_ptr< Pass > createLowerVectorMaskPass()
Creates an instance of the vector.mask lowering pass.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...