MLIR  22.0.0git
LowerVectorMask.cpp
Go to the documentation of this file.
1 //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.mask' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 #include "mlir/IR/PatternMatch.h"
21 
22 #define DEBUG_TYPE "lower-vector-mask"
23 
24 namespace mlir {
25 namespace vector {
26 #define GEN_PASS_DEF_LOWERVECTORMASKPASS
27 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
28 } // namespace vector
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::vector;
33 
34 //===----------------------------------------------------------------------===//
35 // populateVectorMaskOpLoweringPatterns
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 /// Progressive lowering of CreateMaskOp.
40 /// One:
41 /// %x = vector.create_mask %a, ... : vector<dx...>
42 /// is replaced by:
43 /// %l = vector.create_mask ... : vector<...> ; one lower rank
44 /// %0 = arith.cmpi "slt", %ci, %a |
45 /// %1 = select %0, %l, %zeroes |
46 /// %r = vector.insert %1, %pr [i] | d-times
47 /// %x = ....
48 /// until a one-dimensional vector is reached.
49 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
50 public:
52 
53  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
54  PatternRewriter &rewriter) const override {
55  auto dstType = cast<VectorType>(op.getResult().getType());
56  int64_t rank = dstType.getRank();
57  if (rank <= 1)
58  return rewriter.notifyMatchFailure(
59  op, "0-D and 1-D vectors are handled separately");
60 
61  if (dstType.getScalableDims().front())
62  return rewriter.notifyMatchFailure(
63  op, "Cannot unroll leading scalable dim in dstType");
64 
65  auto loc = op.getLoc();
66  int64_t dim = dstType.getDimSize(0);
67  Value idx = op.getOperand(0);
68 
69  VectorType lowType = VectorType::Builder(dstType).dropDim(0);
70  Value trueVal = vector::CreateMaskOp::create(rewriter, loc, lowType,
71  op.getOperands().drop_front());
72  Value falseVal = arith::ConstantOp::create(rewriter, loc, lowType,
73  rewriter.getZeroAttr(lowType));
74  Value result = arith::ConstantOp::create(rewriter, loc, dstType,
75  rewriter.getZeroAttr(dstType));
76  for (int64_t d = 0; d < dim; d++) {
77  Value bnd =
78  arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(d));
79  Value val = arith::CmpIOp::create(rewriter, loc,
80  arith::CmpIPredicate::slt, bnd, idx);
81  Value sel =
82  arith::SelectOp::create(rewriter, loc, val, trueVal, falseVal);
83  result = vector::InsertOp::create(rewriter, loc, sel, result, d);
84  }
85  rewriter.replaceOp(op, result);
86  return success();
87  }
88 };
89 
90 /// Progressive lowering of ConstantMaskOp.
91 /// One:
92 /// %x = vector.constant_mask [a,b]
93 /// is replaced by:
94 /// %z = zero-result
95 /// %l = vector.constant_mask [b]
96 /// %4 = vector.insert %l, %z[0]
97 /// ..
98 /// %x = vector.insert %l, %..[a-1]
99 /// until a one-dimensional vector is reached. All these operations
100 /// will be folded at LLVM IR level.
101 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
102 public:
104 
105  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
106  PatternRewriter &rewriter) const override {
107  auto loc = op.getLoc();
108  auto dstType = op.getType();
109  auto dimSizes = op.getMaskDimSizes();
110  int64_t rank = dstType.getRank();
111 
112  if (rank == 0) {
113  assert(dimSizes.size() == 1 &&
114  "Expected exactly one dim size for a 0-D vector");
115  bool value = dimSizes.front() == 1;
116  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
117  op, dstType,
119  value));
120  return success();
121  }
122 
123  int64_t trueDimSize = dimSizes.front();
124 
125  if (rank == 1) {
126  if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
127  // Use constant splat for 'all set' or 'none set' dims.
128  // This produces correct code for scalable dimensions (it will lower to
129  // a constant splat).
130  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
131  op, DenseElementsAttr::get(dstType, trueDimSize != 0));
132  } else {
133  // Express constant 1-D case in explicit vector form:
134  // [T,..,T,F,..,F].
135  // Note: The verifier would reject this case for scalable vectors.
136  SmallVector<bool> values(dstType.getDimSize(0), false);
137  for (int64_t d = 0; d < trueDimSize; d++)
138  values[d] = true;
139  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
140  op, dstType, rewriter.getBoolVectorAttr(values));
141  }
142  return success();
143  }
144 
145  if (dstType.getScalableDims().front())
146  return rewriter.notifyMatchFailure(
147  op, "Cannot unroll leading scalable dim in dstType");
148 
149  VectorType lowType = VectorType::Builder(dstType).dropDim(0);
150  Value trueVal = vector::ConstantMaskOp::create(rewriter, loc, lowType,
151  dimSizes.drop_front());
152  Value result = arith::ConstantOp::create(rewriter, loc, dstType,
153  rewriter.getZeroAttr(dstType));
154  for (int64_t d = 0; d < trueDimSize; d++)
155  result = vector::InsertOp::create(rewriter, loc, trueVal, result, d);
156 
157  rewriter.replaceOp(op, result);
158  return success();
159  }
160 };
161 } // namespace
162 
165  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
166  patterns.getContext(), benefit);
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // populateVectorMaskLoweringPatternsForSideEffectingOps
171 //===----------------------------------------------------------------------===//
172 
173 namespace {
174 
175 /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
176 /// matching:
177 /// 1. It matches a `vector.mask` operation.
178 /// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested
179 /// in the matched `vector.mask` operation.
180 ///
181 /// It is required that the replacement op in the pattern replaces the
182 /// `vector.mask` operation and not the nested `MaskableOpInterface`. This
183 /// approach allows having patterns that "stop" at every `vector.mask` operation
184 /// and actually match the traits of its the nested `MaskableOpInterface`.
185 template <class SourceOp>
186 struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
188 
189 private:
190  LogicalResult matchAndRewrite(MaskOp maskOp,
191  PatternRewriter &rewriter) const final {
192  auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
193  if (!maskableOp)
194  return failure();
195  SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
196  if (!sourceOp)
197  return failure();
198 
199  return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
200  }
201 
202 protected:
203  virtual LogicalResult
204  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
205  PatternRewriter &rewriter) const = 0;
206 };
207 
208 /// Lowers a masked `vector.transfer_read` operation.
209 struct MaskedTransferReadOpPattern
210  : public MaskOpRewritePattern<TransferReadOp> {
211 public:
212  using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
213 
214  LogicalResult
215  matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
216  PatternRewriter &rewriter) const override {
217  // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
218  // expects a scalar. We could only lower one to the other for cases where
219  // the passthru is a broadcast of a scalar.
220  if (maskingOp.hasPassthru())
221  return rewriter.notifyMatchFailure(
222  maskingOp, "Can't lower passthru to vector.transfer_read");
223 
224  // Replace the `vector.mask` operation.
225  rewriter.replaceOpWithNewOp<TransferReadOp>(
226  maskingOp.getOperation(), readOp.getVectorType(), readOp.getBase(),
227  readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
228  maskingOp.getMask(), readOp.getInBounds());
229  return success();
230  }
231 };
232 
233 /// Lowers a masked `vector.transfer_write` operation.
234 struct MaskedTransferWriteOpPattern
235  : public MaskOpRewritePattern<TransferWriteOp> {
236 public:
237  using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
238 
239  LogicalResult
240  matchAndRewriteMaskableOp(TransferWriteOp writeOp,
241  MaskingOpInterface maskingOp,
242  PatternRewriter &rewriter) const override {
243  Type resultType =
244  writeOp.getResult() ? writeOp.getResult().getType() : Type();
245 
246  // Replace the `vector.mask` operation.
247  rewriter.replaceOpWithNewOp<TransferWriteOp>(
248  maskingOp.getOperation(), resultType, writeOp.getVector(),
249  writeOp.getBase(), writeOp.getIndices(), writeOp.getPermutationMap(),
250  maskingOp.getMask(), writeOp.getInBounds());
251  return success();
252  }
253 };
254 
255 /// Lowers a masked `vector.gather` operation.
256 struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
257 public:
258  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
259 
260  LogicalResult
261  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
262  PatternRewriter &rewriter) const override {
263  Value passthru = maskingOp.hasPassthru()
264  ? maskingOp.getPassthru()
265  : arith::ConstantOp::create(
266  rewriter, gatherOp.getLoc(),
267  rewriter.getZeroAttr(gatherOp.getVectorType()));
268 
269  // Replace the `vector.mask` operation.
270  rewriter.replaceOpWithNewOp<GatherOp>(
271  maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
272  gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
273  passthru);
274  return success();
275  }
276 };
277 
278 struct LowerVectorMaskPass
279  : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
280  using Base::Base;
281 
282  void runOnOperation() override {
283  Operation *op = getOperation();
284  MLIRContext *context = op->getContext();
285 
286  RewritePatternSet loweringPatterns(context);
288  MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
289 
290  if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
291  signalPassFailure();
292  }
293 
294  void getDependentDialects(DialectRegistry &registry) const override {
295  registry.insert<vector::VectorDialect>();
296  }
297 };
298 
299 } // namespace
300 
301 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
302 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
303 /// not its nested `MaskableOpInterface`.
306  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
307  MaskedGatherOpPattern>(patterns.getContext());
308 }
309 
311  return std::make_unique<LowerVectorMaskPass>();
312 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
DenseIntElementsAttr getBoolVectorAttr(ArrayRef< bool > values)
Vector-typed DenseIntElementsAttr getters. values must not be empty.
Definition: Builders.cpp:111
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
IntegerType getI1Type()
Definition: Builders.cpp:52
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:286
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:311
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
std::unique_ptr< Pass > createLowerVectorMaskPass()
Creates an instance of the vector.mask lowering pass.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319