MLIR  17.0.0git
LowerVectorMask.cpp
Go to the documentation of this file.
1 //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.mask' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
19 #include "mlir/IR/PatternMatch.h"
21 
22 #define DEBUG_TYPE "lower-vector-mask"
23 
24 namespace mlir {
25 namespace vector {
26 #define GEN_PASS_DEF_LOWERVECTORMASKPASS
27 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
28 } // namespace vector
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::vector;
33 
34 //===----------------------------------------------------------------------===//
35 // populateVectorMaskOpLoweringPatterns
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 /// Progressive lowering of CreateMaskOp.
40 /// One:
41 /// %x = vector.create_mask %a, ... : vector<dx...>
42 /// is replaced by:
43 /// %l = vector.create_mask ... : vector<...> ; one lower rank
44 /// %0 = arith.cmpi "slt", %ci, %a |
45 /// %1 = select %0, %l, %zeroes |
46 /// %r = vector.insert %1, %pr [i] | d-times
47 /// %x = ....
48 /// until a one-dimensional vector is reached.
49 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
50 public:
52 
53  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
54  PatternRewriter &rewriter) const override {
55  auto dstType = op.getResult().getType().cast<VectorType>();
56  int64_t rank = dstType.getRank();
57  if (rank <= 1)
58  return rewriter.notifyMatchFailure(
59  op, "0-D and 1-D vectors are handled separately");
60 
61  auto loc = op.getLoc();
62  auto eltType = dstType.getElementType();
63  int64_t dim = dstType.getDimSize(0);
64  Value idx = op.getOperand(0);
65 
66  VectorType lowType =
67  VectorType::get(dstType.getShape().drop_front(), eltType);
68  Value trueVal = rewriter.create<vector::CreateMaskOp>(
69  loc, lowType, op.getOperands().drop_front());
70  Value falseVal = rewriter.create<arith::ConstantOp>(
71  loc, lowType, rewriter.getZeroAttr(lowType));
72  Value result = rewriter.create<arith::ConstantOp>(
73  loc, dstType, rewriter.getZeroAttr(dstType));
74  for (int64_t d = 0; d < dim; d++) {
75  Value bnd =
76  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
77  Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
78  bnd, idx);
79  Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
80  auto pos = rewriter.getI64ArrayAttr(d);
81  result =
82  rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
83  }
84  rewriter.replaceOp(op, result);
85  return success();
86  }
87 };
88 
89 /// Progressive lowering of ConstantMaskOp.
90 /// One:
91 /// %x = vector.constant_mask [a,b]
92 /// is replaced by:
93 /// %z = zero-result
94 /// %l = vector.constant_mask [b]
95 /// %4 = vector.insert %l, %z[0]
96 /// ..
97 /// %x = vector.insert %l, %..[a-1]
98 /// until a one-dimensional vector is reached. All these operations
99 /// will be folded at LLVM IR level.
100 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
101 public:
103 
104  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
105  PatternRewriter &rewriter) const override {
106  auto loc = op.getLoc();
107  auto dstType = op.getType();
108  auto eltType = dstType.getElementType();
109  auto dimSizes = op.getMaskDimSizes();
110  int64_t rank = dstType.getRank();
111 
112  if (rank == 0) {
113  assert(dimSizes.size() == 1 &&
114  "Expected exactly one dim size for a 0-D vector");
115  bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
116  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
117  op, dstType,
119  VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
120  ArrayRef<bool>{value}));
121  return success();
122  }
123 
124  // Scalable constant masks can only be lowered for the "none set" case.
125  if (dstType.cast<VectorType>().isScalable()) {
126  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
127  op, DenseElementsAttr::get(dstType, false));
128  return success();
129  }
130 
131  int64_t trueDim = std::min(dstType.getDimSize(0),
132  dimSizes[0].cast<IntegerAttr>().getInt());
133 
134  if (rank == 1) {
135  // Express constant 1-D case in explicit vector form:
136  // [T,..,T,F,..,F].
137  SmallVector<bool> values(dstType.getDimSize(0));
138  for (int64_t d = 0; d < trueDim; d++)
139  values[d] = true;
140  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
141  op, dstType, rewriter.getBoolVectorAttr(values));
142  return success();
143  }
144 
145  VectorType lowType =
146  VectorType::get(dstType.getShape().drop_front(), eltType);
147  SmallVector<int64_t> newDimSizes;
148  for (int64_t r = 1; r < rank; r++)
149  newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
150  Value trueVal = rewriter.create<vector::ConstantMaskOp>(
151  loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
152  Value result = rewriter.create<arith::ConstantOp>(
153  loc, dstType, rewriter.getZeroAttr(dstType));
154  for (int64_t d = 0; d < trueDim; d++) {
155  auto pos = rewriter.getI64ArrayAttr(d);
156  result =
157  rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
158  }
159  rewriter.replaceOp(op, result);
160  return success();
161  }
162 };
163 } // namespace
164 
166  RewritePatternSet &patterns, PatternBenefit benefit) {
167  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
168  patterns.getContext(), benefit);
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // populateVectorMaskLoweringPatternsForSideEffectingOps
173 //===----------------------------------------------------------------------===//
174 
175 namespace {
176 
177 /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
178 /// matching:
179 /// 1. It matches a `vector.mask` operation.
180 /// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested
181 /// in the matched `vector.mask` operation.
182 ///
183 /// It is required that the replacement op in the pattern replaces the
184 /// `vector.mask` operation and not the nested `MaskableOpInterface`. This
185 /// approach allows having patterns that "stop" at every `vector.mask` operation
186 /// and actually match the traits of its the nested `MaskableOpInterface`.
187 template <class SourceOp>
188 struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
190 
191 private:
192  LogicalResult matchAndRewrite(MaskOp maskOp,
193  PatternRewriter &rewriter) const final {
194  MaskableOpInterface maskableOp = maskOp.getMaskableOp();
195  SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
196  if (!sourceOp)
197  return failure();
198 
199  return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
200  }
201 
202 protected:
203  virtual LogicalResult
204  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
205  PatternRewriter &rewriter) const = 0;
206 };
207 
208 /// Lowers a masked `vector.transfer_read` operation.
209 struct MaskedTransferReadOpPattern
210  : public MaskOpRewritePattern<TransferReadOp> {
211 public:
212  using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
213 
215  matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
216  PatternRewriter &rewriter) const override {
217  // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
218  // expects a scalar. We could only lower one to the other for cases where
219  // the passthru is a broadcast of a scalar.
220  if (maskingOp.hasPassthru())
221  return rewriter.notifyMatchFailure(
222  maskingOp, "Can't lower passthru to vector.transfer_read");
223 
224  // Replace the `vector.mask` operation.
225  rewriter.replaceOpWithNewOp<TransferReadOp>(
226  maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
227  readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
228  maskingOp.getMask(), readOp.getInBounds().value_or(ArrayAttr()));
229  return success();
230  }
231 };
232 
233 /// Lowers a masked `vector.transfer_write` operation.
234 struct MaskedTransferWriteOpPattern
235  : public MaskOpRewritePattern<TransferWriteOp> {
236 public:
237  using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
238 
240  matchAndRewriteMaskableOp(TransferWriteOp writeOp,
241  MaskingOpInterface maskingOp,
242  PatternRewriter &rewriter) const override {
243  Type resultType =
244  writeOp.getResult() ? writeOp.getResult().getType() : Type();
245 
246  // Replace the `vector.mask` operation.
247  rewriter.replaceOpWithNewOp<TransferWriteOp>(
248  maskingOp.getOperation(), resultType, writeOp.getVector(),
249  writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
250  maskingOp.getMask(), writeOp.getInBounds().value_or(ArrayAttr()));
251  return success();
252  }
253 };
254 
255 /// Lowers a masked `vector.gather` operation.
256 struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
257 public:
258  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
259 
261  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
262  PatternRewriter &rewriter) const override {
263  Value passthru = maskingOp.hasPassthru()
264  ? maskingOp.getPassthru()
265  : rewriter.create<arith::ConstantOp>(
266  gatherOp.getLoc(),
267  rewriter.getZeroAttr(gatherOp.getVectorType()));
268 
269  // Replace the `vector.mask` operation.
270  rewriter.replaceOpWithNewOp<GatherOp>(
271  maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
272  gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
273  passthru);
274  return success();
275  }
276 };
277 
278 struct LowerVectorMaskPass
279  : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
280  using Base::Base;
281 
282  void runOnOperation() override {
283  Operation *op = getOperation();
284  MLIRContext *context = op->getContext();
285 
286  RewritePatternSet loweringPatterns(context);
288 
289  if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
290  signalPassFailure();
291  }
292 
293  void getDependentDialects(DialectRegistry &registry) const override {
294  registry.insert<vector::VectorDialect>();
295  }
296 };
297 
298 } // namespace
299 
300 /// Populates instances of `MaskOpRewritePattern` to lower masked operations
301 /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
302 /// not its nested `MaskableOpInterface`.
304  RewritePatternSet &patterns) {
305  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
306  MaskedGatherOpPattern>(patterns.getContext());
307 }
308 
310  return std::make_unique<LowerVectorMaskPass>();
311 }
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:121
DenseIntElementsAttr getBoolVectorAttr(ArrayRef< bool > values)
Vector-typed DenseIntElementsAttr getters. values must not be empty.
Definition: Builders.cpp:129
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:318
IntegerType getI1Type()
Definition: Builders.cpp:70
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:274
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:200
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:668
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:597
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:482
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
std::unique_ptr< Pass > createLowerVectorMaskPass()
Creates an instance of the vector.mask lowering pass.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361