MLIR  19.0.0git
LowerVectorMask.cpp
Go to the documentation of this file.
1 //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.mask' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 #include "mlir/IR/PatternMatch.h"
21 
22 #define DEBUG_TYPE "lower-vector-mask"
23 
24 namespace mlir {
25 namespace vector {
26 #define GEN_PASS_DEF_LOWERVECTORMASKPASS
27 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
28 } // namespace vector
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::vector;
33 
34 //===----------------------------------------------------------------------===//
35 // populateVectorMaskOpLoweringPatterns
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 /// Progressive lowering of CreateMaskOp.
40 /// One:
41 /// %x = vector.create_mask %a, ... : vector<dx...>
42 /// is replaced by:
43 /// %l = vector.create_mask ... : vector<...> ; one lower rank
44 /// %0 = arith.cmpi "slt", %ci, %a |
45 /// %1 = select %0, %l, %zeroes |
46 /// %r = vector.insert %1, %pr [i] | d-times
47 /// %x = ....
48 /// until a one-dimensional vector is reached.
49 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
50 public:
52 
53  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
54  PatternRewriter &rewriter) const override {
55  auto dstType = cast<VectorType>(op.getResult().getType());
56  int64_t rank = dstType.getRank();
57  if (rank <= 1)
58  return rewriter.notifyMatchFailure(
59  op, "0-D and 1-D vectors are handled separately");
60 
61  if (dstType.getScalableDims().front())
62  return rewriter.notifyMatchFailure(
63  op, "Cannot unroll leading scalable dim in dstType");
64 
65  auto loc = op.getLoc();
66  int64_t dim = dstType.getDimSize(0);
67  Value idx = op.getOperand(0);
68 
69  VectorType lowType = VectorType::Builder(dstType).dropDim(0);
70  Value trueVal = rewriter.create<vector::CreateMaskOp>(
71  loc, lowType, op.getOperands().drop_front());
72  Value falseVal = rewriter.create<arith::ConstantOp>(
73  loc, lowType, rewriter.getZeroAttr(lowType));
74  Value result = rewriter.create<arith::ConstantOp>(
75  loc, dstType, rewriter.getZeroAttr(dstType));
76  for (int64_t d = 0; d < dim; d++) {
77  Value bnd =
78  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
79  Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
80  bnd, idx);
81  Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
82  result = rewriter.create<vector::InsertOp>(loc, sel, result, d);
83  }
84  rewriter.replaceOp(op, result);
85  return success();
86  }
87 };
88 
89 /// Progressive lowering of ConstantMaskOp.
90 /// One:
91 /// %x = vector.constant_mask [a,b]
92 /// is replaced by:
93 /// %z = zero-result
94 /// %l = vector.constant_mask [b]
95 /// %4 = vector.insert %l, %z[0]
96 /// ..
97 /// %x = vector.insert %l, %..[a-1]
98 /// until a one-dimensional vector is reached. All these operations
99 /// will be folded at LLVM IR level.
100 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
101 public:
103 
104  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
105  PatternRewriter &rewriter) const override {
106  auto loc = op.getLoc();
107  auto dstType = op.getType();
108  auto dimSizes = op.getMaskDimSizes();
109  int64_t rank = dstType.getRank();
110 
111  if (rank == 0) {
112  assert(dimSizes.size() == 1 &&
113  "Expected exactly one dim size for a 0-D vector");
114  bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
115  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
116  op, dstType,
118  value));
119  return success();
120  }
121 
122  int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
123 
124  if (rank == 1) {
125  if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
126  // Use constant splat for 'all set' or 'none set' dims.
127  // This produces correct code for scalable dimensions (it will lower to
128  // a constant splat).
129  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
130  op, DenseElementsAttr::get(dstType, trueDimSize != 0));
131  } else {
132  // Express constant 1-D case in explicit vector form:
133  // [T,..,T,F,..,F].
134  // Note: The verifier would reject this case for scalable vectors.
135  SmallVector<bool> values(dstType.getDimSize(0), false);
136  for (int64_t d = 0; d < trueDimSize; d++)
137  values[d] = true;
138  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
139  op, dstType, rewriter.getBoolVectorAttr(values));
140  }
141  return success();
142  }
143 
144  if (dstType.getScalableDims().front())
145  return rewriter.notifyMatchFailure(
146  op, "Cannot unroll leading scalable dim in dstType");
147 
148  VectorType lowType = VectorType::Builder(dstType).dropDim(0);
149  Value trueVal = rewriter.create<vector::ConstantMaskOp>(
150  loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
151  Value result = rewriter.create<arith::ConstantOp>(
152  loc, dstType, rewriter.getZeroAttr(dstType));
153  for (int64_t d = 0; d < trueDimSize; d++)
154  result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d);
155 
156  rewriter.replaceOp(op, result);
157  return success();
158  }
159 };
160 } // namespace
161 
163  RewritePatternSet &patterns, PatternBenefit benefit) {
164  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
165  patterns.getContext(), benefit);
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // populateVectorMaskLoweringPatternsForSideEffectingOps
170 //===----------------------------------------------------------------------===//
171 
172 namespace {
173 
174 /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
175 /// matching:
176 /// 1. It matches a `vector.mask` operation.
177 /// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested
178 /// in the matched `vector.mask` operation.
179 ///
180 /// It is required that the replacement op in the pattern replaces the
181 /// `vector.mask` operation and not the nested `MaskableOpInterface`. This
182 /// approach allows having patterns that "stop" at every `vector.mask` operation
183 /// and actually match the traits of its the nested `MaskableOpInterface`.
184 template <class SourceOp>
185 struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
187 
188 private:
189  LogicalResult matchAndRewrite(MaskOp maskOp,
190  PatternRewriter &rewriter) const final {
191  auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
192  if (!maskableOp)
193  return failure();
194  SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
195  if (!sourceOp)
196  return failure();
197 
198  return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
199  }
200 
201 protected:
202  virtual LogicalResult
203  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
204  PatternRewriter &rewriter) const = 0;
205 };
206 
207 /// Lowers a masked `vector.transfer_read` operation.
208 struct MaskedTransferReadOpPattern
209  : public MaskOpRewritePattern<TransferReadOp> {
210 public:
211  using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
212 
214  matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
215  PatternRewriter &rewriter) const override {
216  // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
217  // expects a scalar. We could only lower one to the other for cases where
218  // the passthru is a broadcast of a scalar.
219  if (maskingOp.hasPassthru())
220  return rewriter.notifyMatchFailure(
221  maskingOp, "Can't lower passthru to vector.transfer_read");
222 
223  // Replace the `vector.mask` operation.
224  rewriter.replaceOpWithNewOp<TransferReadOp>(
225  maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
226  readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
227  maskingOp.getMask(), readOp.getInBounds().value_or(ArrayAttr()));
228  return success();
229  }
230 };
231 
232 /// Lowers a masked `vector.transfer_write` operation.
233 struct MaskedTransferWriteOpPattern
234  : public MaskOpRewritePattern<TransferWriteOp> {
235 public:
236  using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
237 
239  matchAndRewriteMaskableOp(TransferWriteOp writeOp,
240  MaskingOpInterface maskingOp,
241  PatternRewriter &rewriter) const override {
242  Type resultType =
243  writeOp.getResult() ? writeOp.getResult().getType() : Type();
244 
245  // Replace the `vector.mask` operation.
246  rewriter.replaceOpWithNewOp<TransferWriteOp>(
247  maskingOp.getOperation(), resultType, writeOp.getVector(),
248  writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
249  maskingOp.getMask(), writeOp.getInBounds().value_or(ArrayAttr()));
250  return success();
251  }
252 };
253 
254 /// Lowers a masked `vector.gather` operation.
255 struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
256 public:
257  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
258 
260  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
261  PatternRewriter &rewriter) const override {
262  Value passthru = maskingOp.hasPassthru()
263  ? maskingOp.getPassthru()
264  : rewriter.create<arith::ConstantOp>(
265  gatherOp.getLoc(),
266  rewriter.getZeroAttr(gatherOp.getVectorType()));
267 
268  // Replace the `vector.mask` operation.
269  rewriter.replaceOpWithNewOp<GatherOp>(
270  maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
271  gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
272  passthru);
273  return success();
274  }
275 };
276 
277 struct LowerVectorMaskPass
278  : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
279  using Base::Base;
280 
281  void runOnOperation() override {
282  Operation *op = getOperation();
283  MLIRContext *context = op->getContext();
284 
285  RewritePatternSet loweringPatterns(context);
287  MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
288 
289  if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
290  signalPassFailure();
291  }
292 
293  void getDependentDialects(DialectRegistry &registry) const override {
294  registry.insert<vector::VectorDialect>();
295  }
296 };
297 
298 } // namespace
299 
300 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
301 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
302 /// not its nested `MaskableOpInterface`.
304  RewritePatternSet &patterns) {
305  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
306  MaskedGatherOpPattern>(patterns.getContext());
307 }
308 
310  return std::make_unique<LowerVectorMaskPass>();
311 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
DenseIntElementsAttr getBoolVectorAttr(ArrayRef< bool > values)
Vector-typed DenseIntElementsAttr getters. values must not be empty.
Definition: Builders.cpp:132
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IntegerType getI1Type()
Definition: Builders.cpp:73
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:330
std::unique_ptr< Pass > createLowerVectorMaskPass()
Creates an instance of the vector.mask lowering pass.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362