MLIR  22.0.0git
VectorEmulateMaskedLoadStore.cpp
Go to the documentation of this file.
1 //=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate the
10 // 'vector.maskedload' and 'vector.maskedstore' operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 
18 using namespace mlir;
19 
20 namespace {
21 
22 /// Convert vector.maskedload
23 ///
24 /// Before:
25 ///
26 /// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
27 ///
28 /// After:
29 ///
30 /// %ivalue = %pass_thru
31 /// %m = vector.extract %mask[0]
32 /// %result0 = scf.if %m {
33 /// %v = memref.load %base[%idx_0, %idx_1]
34 /// %combined = vector.insert %v, %ivalue[0]
35 /// scf.yield %combined
36 /// } else {
37 /// scf.yield %ivalue
38 /// }
39 /// %m = vector.extract %mask[1]
40 /// %result1 = scf.if %m {
41 /// %v = memref.load %base[%idx_0, %idx_1 + 1]
42 /// %combined = vector.insert %v, %result0[1]
43 /// scf.yield %combined
44 /// } else {
45 /// scf.yield %result0
46 /// }
47 /// ...
48 ///
49 struct VectorMaskedLoadOpConverter final
50  : OpRewritePattern<vector::MaskedLoadOp> {
51  using Base::Base;
52 
53  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
54  PatternRewriter &rewriter) const override {
55  VectorType maskVType = maskedLoadOp.getMaskVectorType();
56  if (maskVType.getShape().size() != 1)
57  return rewriter.notifyMatchFailure(
58  maskedLoadOp, "expected vector.maskedstore with 1-D mask");
59 
60  Location loc = maskedLoadOp.getLoc();
61  int64_t maskLength = maskVType.getShape()[0];
62 
63  Type indexType = rewriter.getIndexType();
64  Value mask = maskedLoadOp.getMask();
65  Value base = maskedLoadOp.getBase();
66  Value iValue = maskedLoadOp.getPassThru();
67  auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
68  Value one = arith::ConstantOp::create(rewriter, loc, indexType,
69  IntegerAttr::get(indexType, 1));
70  for (int64_t i = 0; i < maskLength; ++i) {
71  auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
72 
73  auto ifOp = scf::IfOp::create(
74  rewriter, loc, maskBit,
75  [&](OpBuilder &builder, Location loc) {
76  auto loadedValue = memref::LoadOp::create(
77  builder, loc, base, indices, /*nontemporal=*/false,
78  llvm::MaybeAlign(maskedLoadOp.getAlignment().value_or(0)));
79  auto combinedValue =
80  vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
81  scf::YieldOp::create(builder, loc, combinedValue.getResult());
82  },
83  [&](OpBuilder &builder, Location loc) {
84  scf::YieldOp::create(builder, loc, iValue);
85  });
86  iValue = ifOp.getResult(0);
87 
88  indices.back() =
89  arith::AddIOp::create(rewriter, loc, indices.back(), one);
90  }
91 
92  rewriter.replaceOp(maskedLoadOp, iValue);
93 
94  return success();
95  }
96 };
97 
98 /// Convert vector.maskedstore
99 ///
100 /// Before:
101 ///
102 /// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
103 ///
104 /// After:
105 ///
106 /// %m = vector.extract %mask[0]
107 /// scf.if %m {
108 /// %extracted = vector.extract %value[0]
109 /// memref.store %extracted, %base[%idx_0, %idx_1]
110 /// }
111 /// %m = vector.extract %mask[1]
112 /// scf.if %m {
113 /// %extracted = vector.extract %value[1]
114 /// memref.store %extracted, %base[%idx_0, %idx_1 + 1]
115 /// }
116 /// ...
117 ///
118 struct VectorMaskedStoreOpConverter final
119  : OpRewritePattern<vector::MaskedStoreOp> {
120  using Base::Base;
121 
122  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
123  PatternRewriter &rewriter) const override {
124  VectorType maskVType = maskedStoreOp.getMaskVectorType();
125  if (maskVType.getShape().size() != 1)
126  return rewriter.notifyMatchFailure(
127  maskedStoreOp, "expected vector.maskedstore with 1-D mask");
128 
129  Location loc = maskedStoreOp.getLoc();
130  int64_t maskLength = maskVType.getShape()[0];
131 
132  Type indexType = rewriter.getIndexType();
133  Value mask = maskedStoreOp.getMask();
134  Value base = maskedStoreOp.getBase();
135  Value value = maskedStoreOp.getValueToStore();
136  bool nontemporal = false;
137  auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
138  Value one = arith::ConstantOp::create(rewriter, loc, indexType,
139  IntegerAttr::get(indexType, 1));
140  for (int64_t i = 0; i < maskLength; ++i) {
141  auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
142 
143  auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
144  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
145  auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
146  memref::StoreOp::create(
147  rewriter, loc, extractedValue, base, indices, nontemporal,
148  llvm::MaybeAlign(maskedStoreOp.getAlignment().value_or(0)));
149 
150  rewriter.setInsertionPointAfter(ifOp);
151  indices.back() =
152  arith::AddIOp::create(rewriter, loc, indices.back(), one);
153  }
154 
155  rewriter.eraseOp(maskedStoreOp);
156 
157  return success();
158  }
159 };
160 
161 } // namespace
162 
165  patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
166  patterns.getContext(), benefit);
167 }
IndexType getIndexType()
Definition: Builders.cpp:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314