49 struct VectorMaskedLoadOpConverter final
53 LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
55 VectorType maskVType = maskedLoadOp.getMaskVectorType();
56 if (maskVType.getShape().size() != 1)
58 maskedLoadOp,
"expected vector.maskedstore with 1-D mask");
60 Location loc = maskedLoadOp.getLoc();
61 int64_t maskLength = maskVType.getShape()[0];
64 Value mask = maskedLoadOp.getMask();
65 Value base = maskedLoadOp.getBase();
66 Value iValue = maskedLoadOp.getPassThru();
67 std::optional<uint64_t> alignment = maskedLoadOp.getAlignment();
68 auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
69 Value one = arith::ConstantOp::create(rewriter, loc, indexType,
71 for (int64_t i = 0; i < maskLength; ++i) {
72 auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
74 auto ifOp = scf::IfOp::create(
75 rewriter, loc, maskBit,
77 auto loadedValue = memref::LoadOp::create(
78 builder, loc, base, indices,
false,
79 alignment.value_or(0));
81 vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
82 scf::YieldOp::create(builder, loc, combinedValue.getResult());
85 scf::YieldOp::create(builder, loc, iValue);
87 iValue = ifOp.getResult(0);
90 arith::AddIOp::create(rewriter, loc, indices.back(), one);
119 struct VectorMaskedStoreOpConverter final
123 LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
125 VectorType maskVType = maskedStoreOp.getMaskVectorType();
126 if (maskVType.getShape().size() != 1)
128 maskedStoreOp,
"expected vector.maskedstore with 1-D mask");
130 Location loc = maskedStoreOp.getLoc();
131 int64_t maskLength = maskVType.getShape()[0];
134 Value mask = maskedStoreOp.getMask();
135 Value base = maskedStoreOp.getBase();
136 Value value = maskedStoreOp.getValueToStore();
137 bool nontemporal =
false;
138 std::optional<uint64_t> alignment = maskedStoreOp.getAlignment();
139 auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
140 Value one = arith::ConstantOp::create(rewriter, loc, indexType,
142 for (int64_t i = 0; i < maskLength; ++i) {
143 auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
145 auto ifOp = scf::IfOp::create(rewriter, loc, maskBit,
false);
147 auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
148 memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
149 nontemporal, alignment.value_or(0));
153 arith::AddIOp::create(rewriter, loc, indices.back(), one);
156 rewriter.
eraseOp(maskedStoreOp);
166 patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...