49struct VectorMaskedLoadOpConverter final
53 LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
54 PatternRewriter &rewriter)
const override {
55 VectorType maskVType = maskedLoadOp.getMaskVectorType();
56 if (maskVType.getShape().size() != 1)
58 maskedLoadOp,
"expected vector.maskedstore with 1-D mask");
60 Location loc = maskedLoadOp.getLoc();
61 int64_t maskLength = maskVType.getShape()[0];
64 Value mask = maskedLoadOp.getMask();
65 Value base = maskedLoadOp.getBase();
66 Value iValue = maskedLoadOp.getPassThru();
67 auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
68 Value one = arith::ConstantOp::create(rewriter, loc, indexType,
69 IntegerAttr::get(indexType, 1));
70 for (int64_t i = 0; i < maskLength; ++i) {
71 auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
73 auto ifOp = scf::IfOp::create(
74 rewriter, loc, maskBit,
75 [&](OpBuilder &builder, Location loc) {
76 auto loadedValue = memref::LoadOp::create(
77 builder, loc, base,
indices,
false,
78 llvm::MaybeAlign(maskedLoadOp.getAlignment().value_or(0)));
80 vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
81 scf::YieldOp::create(builder, loc, combinedValue.getResult());
83 [&](OpBuilder &builder, Location loc) {
84 scf::YieldOp::create(builder, loc, iValue);
86 iValue = ifOp.getResult(0);
89 arith::AddIOp::create(rewriter, loc,
indices.back(), one);
118struct VectorMaskedStoreOpConverter final
122 LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
123 PatternRewriter &rewriter)
const override {
124 VectorType maskVType = maskedStoreOp.getMaskVectorType();
125 if (maskVType.getShape().size() != 1)
127 maskedStoreOp,
"expected vector.maskedstore with 1-D mask");
129 Location loc = maskedStoreOp.getLoc();
130 int64_t maskLength = maskVType.getShape()[0];
133 Value mask = maskedStoreOp.getMask();
134 Value base = maskedStoreOp.getBase();
135 Value value = maskedStoreOp.getValueToStore();
136 bool nontemporal =
false;
137 auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
138 Value one = arith::ConstantOp::create(rewriter, loc, indexType,
139 IntegerAttr::get(indexType, 1));
140 for (int64_t i = 0; i < maskLength; ++i) {
141 auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
143 auto ifOp = scf::IfOp::create(rewriter, loc, maskBit,
false);
145 auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
146 memref::StoreOp::create(
147 rewriter, loc, extractedValue, base,
indices, nontemporal,
148 llvm::MaybeAlign(maskedStoreOp.getAlignment().value_or(0)));
152 arith::AddIOp::create(rewriter, loc,
indices.back(), one);
155 rewriter.
eraseOp(maskedStoreOp);
165 patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...