MLIR  22.0.0git
VectorEmulateMaskedLoadStore.cpp
Go to the documentation of this file.
1 //=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements target-independent rewrites and utilities to emulate the
// 'vector.maskedload' and 'vector.maskedstore' operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 
18 using namespace mlir;
19 
20 namespace {
21 
22 /// Convert vector.maskedload
23 ///
24 /// Before:
25 ///
26 /// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
27 ///
28 /// After:
29 ///
30 /// %ivalue = %pass_thru
31 /// %m = vector.extract %mask[0]
32 /// %result0 = scf.if %m {
33 /// %v = memref.load %base[%idx_0, %idx_1]
34 /// %combined = vector.insert %v, %ivalue[0]
35 /// scf.yield %combined
36 /// } else {
37 /// scf.yield %ivalue
38 /// }
39 /// %m = vector.extract %mask[1]
40 /// %result1 = scf.if %m {
41 /// %v = memref.load %base[%idx_0, %idx_1 + 1]
42 /// %combined = vector.insert %v, %result0[1]
43 /// scf.yield %combined
44 /// } else {
45 /// scf.yield %result0
46 /// }
47 /// ...
48 ///
49 struct VectorMaskedLoadOpConverter final
50  : OpRewritePattern<vector::MaskedLoadOp> {
52 
53  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
54  PatternRewriter &rewriter) const override {
55  VectorType maskVType = maskedLoadOp.getMaskVectorType();
56  if (maskVType.getShape().size() != 1)
57  return rewriter.notifyMatchFailure(
58  maskedLoadOp, "expected vector.maskedstore with 1-D mask");
59 
60  Location loc = maskedLoadOp.getLoc();
61  int64_t maskLength = maskVType.getShape()[0];
62 
63  Type indexType = rewriter.getIndexType();
64  Value mask = maskedLoadOp.getMask();
65  Value base = maskedLoadOp.getBase();
66  Value iValue = maskedLoadOp.getPassThru();
67  std::optional<uint64_t> alignment = maskedLoadOp.getAlignment();
68  auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
69  Value one = arith::ConstantOp::create(rewriter, loc, indexType,
70  IntegerAttr::get(indexType, 1));
71  for (int64_t i = 0; i < maskLength; ++i) {
72  auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
73 
74  auto ifOp = scf::IfOp::create(
75  rewriter, loc, maskBit,
76  [&](OpBuilder &builder, Location loc) {
77  auto loadedValue = memref::LoadOp::create(
78  builder, loc, base, indices, /*nontemporal=*/false,
79  alignment.value_or(0));
80  auto combinedValue =
81  vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
82  scf::YieldOp::create(builder, loc, combinedValue.getResult());
83  },
84  [&](OpBuilder &builder, Location loc) {
85  scf::YieldOp::create(builder, loc, iValue);
86  });
87  iValue = ifOp.getResult(0);
88 
89  indices.back() =
90  arith::AddIOp::create(rewriter, loc, indices.back(), one);
91  }
92 
93  rewriter.replaceOp(maskedLoadOp, iValue);
94 
95  return success();
96  }
97 };
98 
99 /// Convert vector.maskedstore
100 ///
101 /// Before:
102 ///
103 /// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
104 ///
105 /// After:
106 ///
107 /// %m = vector.extract %mask[0]
108 /// scf.if %m {
109 /// %extracted = vector.extract %value[0]
110 /// memref.store %extracted, %base[%idx_0, %idx_1]
111 /// }
112 /// %m = vector.extract %mask[1]
113 /// scf.if %m {
114 /// %extracted = vector.extract %value[1]
115 /// memref.store %extracted, %base[%idx_0, %idx_1 + 1]
116 /// }
117 /// ...
118 ///
119 struct VectorMaskedStoreOpConverter final
120  : OpRewritePattern<vector::MaskedStoreOp> {
122 
123  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
124  PatternRewriter &rewriter) const override {
125  VectorType maskVType = maskedStoreOp.getMaskVectorType();
126  if (maskVType.getShape().size() != 1)
127  return rewriter.notifyMatchFailure(
128  maskedStoreOp, "expected vector.maskedstore with 1-D mask");
129 
130  Location loc = maskedStoreOp.getLoc();
131  int64_t maskLength = maskVType.getShape()[0];
132 
133  Type indexType = rewriter.getIndexType();
134  Value mask = maskedStoreOp.getMask();
135  Value base = maskedStoreOp.getBase();
136  Value value = maskedStoreOp.getValueToStore();
137  bool nontemporal = false;
138  std::optional<uint64_t> alignment = maskedStoreOp.getAlignment();
139  auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
140  Value one = arith::ConstantOp::create(rewriter, loc, indexType,
141  IntegerAttr::get(indexType, 1));
142  for (int64_t i = 0; i < maskLength; ++i) {
143  auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
144 
145  auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
146  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
147  auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
148  memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
149  nontemporal, alignment.value_or(0));
150 
151  rewriter.setInsertionPointAfter(ifOp);
152  indices.back() =
153  arith::AddIOp::create(rewriter, loc, indices.back(), one);
154  }
155 
156  rewriter.eraseOp(maskedStoreOp);
157 
158  return success();
159  }
160 };
161 
162 } // namespace
163 
166  patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
167  patterns.getContext(), benefit);
168 }
IndexType getIndexType()
Definition: Builders.cpp:50
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319