MLIR  19.0.0git
VectorEmulateMaskedLoadStore.cpp
Go to the documentation of this file.
1 //=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate the
10 // 'vector.maskedload' and 'vector.maskedstore' operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 
18 using namespace mlir;
19 
20 namespace {
21 
22 /// Convert vector.maskedload
23 ///
24 /// Before:
25 ///
26 /// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
27 ///
28 /// After:
29 ///
30 /// %ivalue = %pass_thru
31 /// %m = vector.extract %mask[0]
32 /// %result0 = scf.if %m {
33 /// %v = memref.load %base[%idx_0, %idx_1]
34 /// %combined = vector.insert %v, %ivalue[0]
35 /// scf.yield %combined
36 /// } else {
37 /// scf.yield %ivalue
38 /// }
39 /// %m = vector.extract %mask[1]
40 /// %result1 = scf.if %m {
41 /// %v = memref.load %base[%idx_0, %idx_1 + 1]
42 /// %combined = vector.insert %v, %result0[1]
43 /// scf.yield %combined
44 /// } else {
45 /// scf.yield %result0
46 /// }
47 /// ...
48 ///
49 struct VectorMaskedLoadOpConverter final
50  : OpRewritePattern<vector::MaskedLoadOp> {
52 
53  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
54  PatternRewriter &rewriter) const override {
55  VectorType maskVType = maskedLoadOp.getMaskVectorType();
56  if (maskVType.getShape().size() != 1)
57  return rewriter.notifyMatchFailure(
58  maskedLoadOp, "expected vector.maskedstore with 1-D mask");
59 
60  Location loc = maskedLoadOp.getLoc();
61  int64_t maskLength = maskVType.getShape()[0];
62 
63  Type indexType = rewriter.getIndexType();
64  Value mask = maskedLoadOp.getMask();
65  Value base = maskedLoadOp.getBase();
66  Value iValue = maskedLoadOp.getPassThru();
67  auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
68  Value one = rewriter.create<arith::ConstantOp>(
69  loc, indexType, IntegerAttr::get(indexType, 1));
70  for (int64_t i = 0; i < maskLength; ++i) {
71  auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
72 
73  auto ifOp = rewriter.create<scf::IfOp>(
74  loc, maskBit,
75  [&](OpBuilder &builder, Location loc) {
76  auto loadedValue =
77  builder.create<memref::LoadOp>(loc, base, indices);
78  auto combinedValue =
79  builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
80  builder.create<scf::YieldOp>(loc, combinedValue.getResult());
81  },
82  [&](OpBuilder &builder, Location loc) {
83  builder.create<scf::YieldOp>(loc, iValue);
84  });
85  iValue = ifOp.getResult(0);
86 
87  indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
88  }
89 
90  rewriter.replaceOp(maskedLoadOp, iValue);
91 
92  return success();
93  }
94 };
95 
96 /// Convert vector.maskedstore
97 ///
98 /// Before:
99 ///
100 /// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
101 ///
102 /// After:
103 ///
104 /// %m = vector.extract %mask[0]
105 /// scf.if %m {
106 /// %extracted = vector.extract %value[0]
107 /// memref.store %extracted, %base[%idx_0, %idx_1]
108 /// }
109 /// %m = vector.extract %mask[1]
110 /// scf.if %m {
111 /// %extracted = vector.extract %value[1]
112 /// memref.store %extracted, %base[%idx_0, %idx_1 + 1]
113 /// }
114 /// ...
115 ///
116 struct VectorMaskedStoreOpConverter final
117  : OpRewritePattern<vector::MaskedStoreOp> {
119 
120  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
121  PatternRewriter &rewriter) const override {
122  VectorType maskVType = maskedStoreOp.getMaskVectorType();
123  if (maskVType.getShape().size() != 1)
124  return rewriter.notifyMatchFailure(
125  maskedStoreOp, "expected vector.maskedstore with 1-D mask");
126 
127  Location loc = maskedStoreOp.getLoc();
128  int64_t maskLength = maskVType.getShape()[0];
129 
130  Type indexType = rewriter.getIndexType();
131  Value mask = maskedStoreOp.getMask();
132  Value base = maskedStoreOp.getBase();
133  Value value = maskedStoreOp.getValueToStore();
134  auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
135  Value one = rewriter.create<arith::ConstantOp>(
136  loc, indexType, IntegerAttr::get(indexType, 1));
137  for (int64_t i = 0; i < maskLength; ++i) {
138  auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
139 
140  auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
141  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
142  auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
143  rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
144 
145  rewriter.setInsertionPointAfter(ifOp);
146  indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
147  }
148 
149  rewriter.eraseOp(maskedStoreOp);
150 
151  return success();
152  }
153 };
154 
155 } // namespace
156 
158  RewritePatternSet &patterns, PatternBenefit benefit) {
159  patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
160  patterns.getContext(), benefit);
161 }
IndexType getIndexType()
Definition: Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362