MLIR 22.0.0git
VectorEmulateMaskedLoadStore.cpp
Go to the documentation of this file.
1//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to emulate the
10// 'vector.maskedload' and 'vector.maskedstore' operation.
11//
12//===----------------------------------------------------------------------===//
13
17
// Bring MLIR core names (PatternRewriter, Value, Location, VectorType, ...)
// into scope; the patterns in this file refer to them unqualified.
using namespace mlir;
19
20namespace {
21
22/// Convert vector.maskedload
23///
24/// Before:
25///
26/// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
27///
28/// After:
29///
30/// %ivalue = %pass_thru
31/// %m = vector.extract %mask[0]
32/// %result0 = scf.if %m {
33/// %v = memref.load %base[%idx_0, %idx_1]
34/// %combined = vector.insert %v, %ivalue[0]
35/// scf.yield %combined
36/// } else {
37/// scf.yield %ivalue
38/// }
39/// %m = vector.extract %mask[1]
40/// %result1 = scf.if %m {
41/// %v = memref.load %base[%idx_0, %idx_1 + 1]
42/// %combined = vector.insert %v, %result0[1]
43/// scf.yield %combined
44/// } else {
45/// scf.yield %result0
46/// }
47/// ...
48///
49struct VectorMaskedLoadOpConverter final
50 : OpRewritePattern<vector::MaskedLoadOp> {
51 using Base::Base;
52
53 LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
54 PatternRewriter &rewriter) const override {
55 VectorType maskVType = maskedLoadOp.getMaskVectorType();
56 if (maskVType.getShape().size() != 1)
57 return rewriter.notifyMatchFailure(
58 maskedLoadOp, "expected vector.maskedstore with 1-D mask");
59
60 Location loc = maskedLoadOp.getLoc();
61 int64_t maskLength = maskVType.getShape()[0];
62
63 Type indexType = rewriter.getIndexType();
64 Value mask = maskedLoadOp.getMask();
65 Value base = maskedLoadOp.getBase();
66 Value iValue = maskedLoadOp.getPassThru();
67 auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
68 Value one = arith::ConstantOp::create(rewriter, loc, indexType,
69 IntegerAttr::get(indexType, 1));
70 for (int64_t i = 0; i < maskLength; ++i) {
71 auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
72
73 auto ifOp = scf::IfOp::create(
74 rewriter, loc, maskBit,
75 [&](OpBuilder &builder, Location loc) {
76 auto loadedValue = memref::LoadOp::create(
77 builder, loc, base, indices, /*nontemporal=*/false,
78 llvm::MaybeAlign(maskedLoadOp.getAlignment().value_or(0)));
79 auto combinedValue =
80 vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
81 scf::YieldOp::create(builder, loc, combinedValue.getResult());
82 },
83 [&](OpBuilder &builder, Location loc) {
84 scf::YieldOp::create(builder, loc, iValue);
85 });
86 iValue = ifOp.getResult(0);
87
88 indices.back() =
89 arith::AddIOp::create(rewriter, loc, indices.back(), one);
90 }
91
92 rewriter.replaceOp(maskedLoadOp, iValue);
93
94 return success();
95 }
96};
97
98/// Convert vector.maskedstore
99///
100/// Before:
101///
102/// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
103///
104/// After:
105///
106/// %m = vector.extract %mask[0]
107/// scf.if %m {
108/// %extracted = vector.extract %value[0]
109/// memref.store %extracted, %base[%idx_0, %idx_1]
110/// }
111/// %m = vector.extract %mask[1]
112/// scf.if %m {
113/// %extracted = vector.extract %value[1]
114/// memref.store %extracted, %base[%idx_0, %idx_1 + 1]
115/// }
116/// ...
117///
118struct VectorMaskedStoreOpConverter final
119 : OpRewritePattern<vector::MaskedStoreOp> {
120 using Base::Base;
121
122 LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
123 PatternRewriter &rewriter) const override {
124 VectorType maskVType = maskedStoreOp.getMaskVectorType();
125 if (maskVType.getShape().size() != 1)
126 return rewriter.notifyMatchFailure(
127 maskedStoreOp, "expected vector.maskedstore with 1-D mask");
128
129 Location loc = maskedStoreOp.getLoc();
130 int64_t maskLength = maskVType.getShape()[0];
131
132 Type indexType = rewriter.getIndexType();
133 Value mask = maskedStoreOp.getMask();
134 Value base = maskedStoreOp.getBase();
135 Value value = maskedStoreOp.getValueToStore();
136 bool nontemporal = false;
137 auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
138 Value one = arith::ConstantOp::create(rewriter, loc, indexType,
139 IntegerAttr::get(indexType, 1));
140 for (int64_t i = 0; i < maskLength; ++i) {
141 auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
142
143 auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
144 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
145 auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
146 memref::StoreOp::create(
147 rewriter, loc, extractedValue, base, indices, nontemporal,
148 llvm::MaybeAlign(maskedStoreOp.getAlignment().value_or(0)));
149
150 rewriter.setInsertionPointAfter(ifOp);
151 indices.back() =
152 arith::AddIOp::create(rewriter, loc, indices.back(), one);
153 }
154
155 rewriter.eraseOp(maskedStoreOp);
156
157 return success();
158 }
159};
160
161} // namespace
162
165 patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
166 patterns.getContext(), benefit);
167}
return success()
IndexType getIndexType()
Definition Builders.cpp:51
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...