17 #define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
18 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
44 bool isSVEMaskType(VectorType type) {
45 return type.getRank() > 0 && type.getElementType().isInteger(1) &&
46 type.getScalableDims().back() && type.getShape().back() < 16 &&
47 llvm::isPowerOf2_32(type.getShape().back()) &&
48 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
51 VectorType widenScalableMaskTypeToSvbool(VectorType type) {
52 assert(isSVEMaskType(type));
58 template <
typename TOp,
typename TLegalizerCallback>
60 TLegalizerCallback callback) {
62 auto newOp = op.clone();
69 template <
typename TOp,
typename TLegalizerCallback>
70 void replaceOpWithUnrealizedConversion(
PatternRewriter &rewriter, TOp op,
71 TLegalizerCallback callback) {
72 replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
74 return UnrealizedConversionCastOp::create(
75 rewriter, op.getLoc(),
TypeRange{op.getResult().getType()},
84 static FailureOr<Value> getSVELegalizedMemref(
Value illegalMemref) {
88 auto unrealizedConversion =
89 llvm::cast<UnrealizedConversionCastOp>(definingOp);
90 return unrealizedConversion.getOperand(0);
96 struct RelaxScalableVectorAllocaAlignment
100 LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
102 auto memrefElementType = allocaOp.getType().getElementType();
103 auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
104 if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
108 unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
110 [&] { allocaOp.setAlignment(aligment); });
131 template <
typename AllocLikeOp>
135 LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
138 llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
140 if (!vectorType || !isSVEMaskType(vectorType))
146 replaceOpWithUnrealizedConversion(
147 rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
148 newAllocLikeOp.getResult().setType(
149 llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
150 {}, widenScalableMaskTypeToSvbool(vectorType))));
151 return newAllocLikeOp;
179 struct LegalizeSVEMaskTypeCastConversion
183 LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
185 auto resultType = typeCastOp.getResultMemRefType();
186 auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
188 if (!vectorType || !isSVEMaskType(vectorType))
191 auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref());
196 replaceOpWithUnrealizedConversion(
197 rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
198 newTypeCast.setOperand(*legalMemref);
199 newTypeCast.getResult().setType(
200 llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
201 {}, widenScalableMaskTypeToSvbool(vectorType))));
222 struct LegalizeSVEMaskStoreConversion
226 LogicalResult matchAndRewrite(memref::StoreOp storeOp,
228 auto loc = storeOp.getLoc();
230 Value valueToStore = storeOp.getValueToStore();
231 auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.
getType());
233 if (!vectorType || !isSVEMaskType(vectorType))
236 auto legalMemref = getSVELegalizedMemref(storeOp.getMemref());
240 auto legalMaskType = widenScalableMaskTypeToSvbool(
241 llvm::cast<VectorType>(valueToStore.
getType()));
242 auto convertToSvbool = arm_sve::ConvertToSvboolOp::create(
243 rewriter, loc, legalMaskType, valueToStore);
246 replaceOpWithLegalizedOp(rewriter, storeOp,
247 [&](memref::StoreOp newStoreOp) {
248 newStoreOp.setOperand(0, convertToSvbool);
249 newStoreOp.setOperand(1, *legalMemref);
270 struct LegalizeSVEMaskLoadConversion :
public OpRewritePattern<memref::LoadOp> {
273 LogicalResult matchAndRewrite(memref::LoadOp loadOp,
275 auto loc = loadOp.getLoc();
277 Value loadedMask = loadOp.getResult();
278 auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.
getType());
280 if (!vectorType || !isSVEMaskType(vectorType))
283 auto legalMemref = getSVELegalizedMemref(loadOp.getMemref());
287 auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
290 replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
291 newLoadOp.setMemRef(*legalMemref);
292 newLoadOp.getResult().setType(legalMaskType);
293 return arm_sve::ConvertFromSvboolOp::create(
294 rewriter, loc, loadedMask.
getType(), newLoadOp);
321 struct LegalizeTransferRead :
public OpRewritePattern<vector::TransferReadOp> {
324 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
338 if (readOp.isMasked() || readOp.getMask())
340 "masked transfers not-supported");
351 if (!readOp.getPermutationMap().isMinorIdentity())
362 VectorType origVT = readOp.getVectorType();
364 const int64_t origVRank = origVT.getRank();
365 if (origVRank < 2 || origVT.getNumScalableDims() != 1)
371 const int64_t numCollapseDims = std::distance(
372 llvm::find(origScalableDims,
true), origScalableDims.end());
373 if (numCollapseDims < 2)
375 "scalable dimension is trailing");
379 auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
380 if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
382 readOp,
"non-contiguous memref dimensions to collapse");
390 if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
391 origVT.getShape().take_back(numCollapseDims - 1)))
393 readOp,
"memref and vector dimensions do not match");
398 [](
bool v) {
return v; }))
400 readOp,
"out-of-bounds transfer from a dimension to collapse");
404 for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
405 reassoc.push_back({i});
406 for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
408 reassoc.back().push_back(i);
409 if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
411 Value collapsedMem = memref::CollapseShapeOp::create(
412 rewriter, readOp.getLoc(), readOp.getBase(), reassoc);
416 for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
417 shape[origVRank - numCollapseDims] *= shape[i];
418 shape.pop_back_n(numCollapseDims - 1);
421 origScalableDims.drop_back(numCollapseDims - 1));
424 auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
427 auto newReadOp = vector::TransferReadOp::create(
428 rewriter, readOp.getLoc(), collapsedVT, collapsedMem, indices,
433 auto toOrigShape = vector::ShapeCastOp::create(rewriter, readOp.getLoc(),
446 .add<RelaxScalableVectorAllocaAlignment,
447 LegalizeSVEMaskAllocation<memref::AllocaOp>,
448 LegalizeSVEMaskAllocation<memref::AllocOp>,
449 LegalizeSVEMaskTypeCastConversion, LegalizeSVEMaskStoreConversion,
450 LegalizeSVEMaskLoadConversion, LegalizeTransferRead>(
455 struct LegalizeVectorStorage
456 :
public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
458 void runOnOperation()
override {
465 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
466 [](UnrealizedConversionCastOp unrealizedConversion) {
478 return std::make_unique<LegalizeVectorStorage>();
static MLIRContext * getContext(OpFoldResult val)
constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__")
StringAttr getStringAttr(const Twine &bytes)
This class describes a specific conversion target.
NamedAttribute represents a combination of a name and an Attribute value.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Operation is the basic unit of execution within MLIR.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
void populateLegalizeVectorStoragePatterns(RewritePatternSet &patterns)
Collect a set of patterns to legalize Arm SVE vector storage.
std::unique_ptr< Pass > createLegalizeVectorStoragePass()
Pass to legalize Arm SVE vector storage.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...