17 #define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
18 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
44 bool isSVEMaskType(VectorType type) {
45 return type.getRank() > 0 && type.getElementType().isInteger(1) &&
46 type.getScalableDims().back() && type.getShape().back() < 16 &&
47 llvm::isPowerOf2_32(type.getShape().back()) &&
48 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
51 VectorType widenScalableMaskTypeToSvbool(VectorType type) {
52 assert(isSVEMaskType(type));
58 template <
typename TOp,
typename TLegalizerCallback>
60 TLegalizerCallback callback) {
62 auto newOp = op.clone();
69 template <
typename TOp,
typename TLegalizerCallback>
70 void replaceOpWithUnrealizedConversion(
PatternRewriter &rewriter, TOp op,
71 TLegalizerCallback callback) {
72 replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
74 return rewriter.
create<UnrealizedConversionCastOp>(
84 static FailureOr<Value> getSVELegalizedMemref(
Value illegalMemref) {
88 auto unrealizedConversion =
89 llvm::cast<UnrealizedConversionCastOp>(definingOp);
90 return unrealizedConversion.getOperand(0);
96 struct RelaxScalableVectorAllocaAlignment
100 LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
102 auto memrefElementType = allocaOp.getType().getElementType();
103 auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
104 if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
108 unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
110 [&] { allocaOp.setAlignment(aligment); });
131 template <
typename AllocLikeOp>
135 LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
138 llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
140 if (!vectorType || !isSVEMaskType(vectorType))
146 replaceOpWithUnrealizedConversion(
147 rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
148 newAllocLikeOp.getResult().setType(
149 llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
150 {}, widenScalableMaskTypeToSvbool(vectorType))));
151 return newAllocLikeOp;
179 struct LegalizeSVEMaskTypeCastConversion
183 LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
185 auto resultType = typeCastOp.getResultMemRefType();
186 auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
188 if (!vectorType || !isSVEMaskType(vectorType))
191 auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref());
192 if (failed(legalMemref))
196 replaceOpWithUnrealizedConversion(
197 rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
198 newTypeCast.setOperand(*legalMemref);
199 newTypeCast.getResult().setType(
200 llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
201 {}, widenScalableMaskTypeToSvbool(vectorType))));
222 struct LegalizeSVEMaskStoreConversion
226 LogicalResult matchAndRewrite(memref::StoreOp storeOp,
228 auto loc = storeOp.getLoc();
230 Value valueToStore = storeOp.getValueToStore();
231 auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.
getType());
233 if (!vectorType || !isSVEMaskType(vectorType))
236 auto legalMemref = getSVELegalizedMemref(storeOp.getMemref());
237 if (failed(legalMemref))
240 auto legalMaskType = widenScalableMaskTypeToSvbool(
241 llvm::cast<VectorType>(valueToStore.
getType()));
242 auto convertToSvbool = rewriter.
create<arm_sve::ConvertToSvboolOp>(
243 loc, legalMaskType, valueToStore);
246 replaceOpWithLegalizedOp(rewriter, storeOp,
247 [&](memref::StoreOp newStoreOp) {
248 newStoreOp.setOperand(0, convertToSvbool);
249 newStoreOp.setOperand(1, *legalMemref);
270 struct LegalizeSVEMaskLoadConversion :
public OpRewritePattern<memref::LoadOp> {
273 LogicalResult matchAndRewrite(memref::LoadOp loadOp,
275 auto loc = loadOp.getLoc();
277 Value loadedMask = loadOp.getResult();
278 auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.
getType());
280 if (!vectorType || !isSVEMaskType(vectorType))
283 auto legalMemref = getSVELegalizedMemref(loadOp.getMemref());
284 if (failed(legalMemref))
287 auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
290 replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
291 newLoadOp.setMemRef(*legalMemref);
292 newLoadOp.getResult().setType(legalMaskType);
293 return rewriter.
create<arm_sve::ConvertFromSvboolOp>(
294 loc, loadedMask.
getType(), newLoadOp);
305 patterns.
add<RelaxScalableVectorAllocaAlignment,
306 LegalizeSVEMaskAllocation<memref::AllocaOp>,
307 LegalizeSVEMaskAllocation<memref::AllocOp>,
308 LegalizeSVEMaskTypeCastConversion,
309 LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
314 struct LegalizeVectorStorage
315 :
public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
317 void runOnOperation()
override {
321 std::move(patterns)))) {
325 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
326 [](UnrealizedConversionCastOp unrealizedConversion) {
338 return std::make_unique<LegalizeVectorStorage>();
static MLIRContext * getContext(OpFoldResult val)
constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__")
StringAttr getStringAttr(const Twine &bytes)
This class describes a specific conversion target.
NamedAttribute represents a combination of a name and an Attribute value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Operation is the basic unit of execution within MLIR.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
void populateLegalizeVectorStoragePatterns(RewritePatternSet &patterns)
Collect a set of patterns to legalize Arm SVE vector storage.
std::unique_ptr< Pass > createLegalizeVectorStoragePass()
Pass to legalize Arm SVE vector storage.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...