17#define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
18#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
// Returns true for vector types that represent an SVE predicate ("mask"):
// rank > 0, i1 element type, a trailing dimension that is scalable with a
// power-of-two static size < 16, and no other scalable dimensions.
// NOTE(review): extraction dropped the function's closing lines; the visible
// expression appears to be the complete predicate.
44bool isSVEMaskType(VectorType type) {
 45 return type.getRank() > 0 && type.getElementType().isInteger(1) &&
 46 type.getScalableDims().back() && type.getShape().back() < 16 &&
 47 llvm::isPowerOf2_32(type.getShape().back()) &&
 48 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
// Widens a legal SVE mask type (see isSVEMaskType) to the in-memory svbool
// representation. NOTE(review): the body past the precondition assert was
// lost in extraction; presumably it rebuilds the type with a trailing
// scalable dimension of 16 — confirm against upstream.
51VectorType widenScalableMaskTypeToSvbool(VectorType type) {
 52 assert(isSVEMaskType(type));
// Helper template: clones `op` and hands the clone to `callback` so it can be
// legalized in place before replacing the original.
// NOTE(review): the signature's first parameters and the rest of the body
// (insertion and replacement of the clone) were lost in extraction.
58template <
typename TOp,
typename TLegalizerCallback>
 60 TLegalizerCallback callback) {
 62 auto newOp = op.clone();
// Helper template: like replaceOpWithLegalizedOp, but wraps the legalized
// op's result in an UnrealizedConversionCastOp back to the original result
// type, so uses of the old (illegal) type keep type-checking.
// NOTE(review): trailing lines (the cast's operands and closing braces) were
// lost in extraction.
69template <
typename TOp,
typename TLegalizerCallback>
70void replaceOpWithUnrealizedConversion(
PatternRewriter &rewriter, TOp op,
 71 TLegalizerCallback callback) {
 72 replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
 74 return UnrealizedConversionCastOp::create(
 75 rewriter, op.getLoc(),
TypeRange{op.getResult().getType()},
// Looks through the UnrealizedConversionCastOp inserted by the patterns
// below and returns the legalized (svbool-typed) memref it wraps.
// NOTE(review): the lines computing `definingOp` and the failure path were
// lost in extraction; the unguarded llvm::cast relies on a check that is not
// visible here.
84static FailureOr<Value> getSVELegalizedMemref(
Value illegalMemref) {
 88 auto unrealizedConversion =
 89 llvm::cast<UnrealizedConversionCastOp>(definingOp);
 90 return unrealizedConversion.getOperand(0);
// Pattern: set an explicit alignment on memref.alloca ops whose element type
// is a scalable vector and which have no alignment yet — 2 bytes for i1
// (predicate) elements, 16 bytes otherwise. The update is done in place
// (visible via the setAlignment lambda, presumably inside modifyOpInPlace).
// NOTE(review): local variable "aligment" is misspelled ("alignment") —
// cosmetic only; cannot be renamed in a comment-only edit.
96struct RelaxScalableVectorAllocaAlignment
 100 LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
 102 auto memrefElementType = allocaOp.getType().getElementType();
 103 auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
 104 if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
 108 unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
 110 [&] { allocaOp.setAlignment(aligment); });
// Pattern (templated over memref.alloca / memref.alloc): rewrites an
// allocation whose element type is an SVE mask so that it allocates the
// widened svbool form instead, bridging back to the old type with an
// unrealized conversion cast (via replaceOpWithUnrealizedConversion).
// NOTE(review): the struct header, match-failure returns, and closing braces
// were lost in extraction.
131template <
typename AllocLikeOp>
 133 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
 135 LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
 136 PatternRewriter &rewriter)
const override {
 138 llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
 140 if (!vectorType || !isSVEMaskType(vectorType))
 146 replaceOpWithUnrealizedConversion(
 147 rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
 148 newAllocLikeOp.getResult().setType(
 149 llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
 150 {}, widenScalableMaskTypeToSvbool(vectorType))));
 151 return newAllocLikeOp;
// Pattern: rewrites a vector.type_cast whose result element type is an SVE
// mask to operate on the legalized (svbool) memref, widening the result
// memref type accordingly and bridging with an unrealized conversion cast.
// NOTE(review): match-failure returns, the failed(legalMemref) check, and
// closing braces were lost in extraction.
179struct LegalizeSVEMaskTypeCastConversion
 183 LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
 184 PatternRewriter &rewriter)
const override {
 185 auto resultType = typeCastOp.getResultMemRefType();
 186 auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
 188 if (!vectorType || !isSVEMaskType(vectorType))
 191 auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref());
 196 replaceOpWithUnrealizedConversion(
 197 rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
 198 newTypeCast.setOperand(*legalMemref);
 199 newTypeCast.getResult().setType(
 200 llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
 201 {}, widenScalableMaskTypeToSvbool(vectorType))));
// Pattern: rewrites a memref.store of an SVE mask value so it stores the
// svbool-widened form: converts the value with arm_sve.convert_to_svbool,
// then redirects the store's value operand (0) and memref operand (1) to the
// converted value and the legalized memref.
// NOTE(review): match-failure returns and closing braces were lost in
// extraction; some lines are split mid-expression (e.g. "valueToStore."
// followed by "getType()").
222struct LegalizeSVEMaskStoreConversion
 226 LogicalResult matchAndRewrite(memref::StoreOp storeOp,
 227 PatternRewriter &rewriter)
const override {
 228 auto loc = storeOp.getLoc();
 230 Value valueToStore = storeOp.getValueToStore();
 231 auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.
getType());
 233 if (!vectorType || !isSVEMaskType(vectorType))
 236 auto legalMemref = getSVELegalizedMemref(storeOp.getMemref());
 240 auto legalMaskType = widenScalableMaskTypeToSvbool(
 241 llvm::cast<VectorType>(valueToStore.
getType()));
 242 auto convertToSvbool = arm_sve::ConvertToSvboolOp::create(
 243 rewriter, loc, legalMaskType, valueToStore);
 246 replaceOpWithLegalizedOp(rewriter, storeOp,
 247 [&](memref::StoreOp newStoreOp) {
 248 newStoreOp.setOperand(0, convertToSvbool);
 249 newStoreOp.setOperand(1, *legalMemref);
// Pattern: rewrites a memref.load of an SVE mask so the load reads the
// svbool-widened form from the legalized memref, then narrows the loaded
// value back to the original mask type with arm_sve.convert_from_svbool.
// NOTE(review): match-failure returns and closing braces were lost in
// extraction.
270struct LegalizeSVEMaskLoadConversion :
public OpRewritePattern<memref::LoadOp> {
 273 LogicalResult matchAndRewrite(memref::LoadOp loadOp,
 274 PatternRewriter &rewriter)
const override {
 275 auto loc = loadOp.getLoc();
 277 Value loadedMask = loadOp.getResult();
 278 auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.
getType());
 280 if (!vectorType || !isSVEMaskType(vectorType))
 283 auto legalMemref = getSVELegalizedMemref(loadOp.getMemref());
 287 auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
 290 replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
 291 newLoadOp.setMemRef(*legalMemref);
 292 newLoadOp.getResult().setType(legalMaskType);
 293 return arm_sve::ConvertFromSvboolOp::create(
 294 rewriter, loc, loadedMask.
getType(), newLoadOp);
// Pattern: legalizes a vector.transfer_read of a multi-dimensional vector
// with exactly one (non-trailing) scalable dimension by collapsing the
// trailing dimensions — from the scalable one onward — of both the memref
// (via memref.collapse_shape) and the vector type, reading the collapsed
// shape, then restoring the original shape with vector.shape_cast.
// Preconditions visible below: unmasked, minor-identity permutation map,
// rank >= 2, contiguous trailing memref dims matching the vector dims,
// in-bounds on all dims being collapsed.
// NOTE(review): the notifyMatchFailure call sites, the llvm::any_of header
// around origInBounds, and the final rewriter.replaceOp were lost in
// extraction; only the visible logic is documented here.
321struct LegalizeTransferRead :
public OpRewritePattern<vector::TransferReadOp> {
 324 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
 325 PatternRewriter &rewriter)
const override {
// Bail on masked transfers — not supported by this rewrite.
 338 if (readOp.isMasked() || readOp.getMask())
 340 "masked transfers not-supported");
 351 if (!readOp.getPermutationMap().isMinorIdentity())
 362 VectorType origVT = readOp.getVectorType();
 363 ArrayRef<bool> origScalableDims = origVT.getScalableDims();
 364 const int64_t origVRank = origVT.getRank();
 365 if (origVRank < 2 || origVT.getNumScalableDims() != 1)
// Number of trailing dims to collapse: from the scalable dim to the end.
 371 const int64_t numCollapseDims = std::distance(
 372 llvm::find(origScalableDims,
true), origScalableDims.end());
 373 if (numCollapseDims < 2)
 375 "scalable dimension is trailing");
 379 auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
 380 if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
 382 readOp,
"non-contiguous memref dimensions to collapse");
// The dims being collapsed (past the scalable one) must agree between the
// memref and the vector, so the collapsed read covers the same elements.
 390 if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
 391 origVT.getShape().take_back(numCollapseDims - 1)))
 393 readOp,
"memref and vector dimensions do not match");
 395 SmallVector<bool> origInBounds = readOp.getInBoundsValues();
 397 ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
 398 [](
bool v) {
return v; }))
 400 readOp,
"out-of-bounds transfer from a dimension to collapse");
// Build the reassociation: leading dims stay 1:1, trailing dims merge
// into the last group.
 403 SmallVector<ReassociationIndices> reassoc;
 404 for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
 405 reassoc.push_back({i});
 406 for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
 408 reassoc.back().push_back(i);
 409 if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
 411 Value collapsedMem = memref::CollapseShapeOp::create(
 412 rewriter, readOp.getLoc(), readOp.getBase(), reassoc);
// Fold the collapsed trailing sizes into the scalable dim of the vector.
 415 SmallVector<int64_t> shape(origVT.getShape());
 416 for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
 417 shape[origVRank - numCollapseDims] *= shape[i];
 418 shape.pop_back_n(numCollapseDims - 1);
 420 VectorType::get(shape, origVT.getElementType(),
 421 origScalableDims.drop_back(numCollapseDims - 1));
 424 auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
 427 auto newReadOp = vector::TransferReadOp::create(
 428 rewriter, readOp.getLoc(), collapsedVT, collapsedMem,
indices,
 430 ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
 433 auto toOrigShape = vector::ShapeCastOp::create(rewriter, readOp.getLoc(),
// Registers all legalization patterns defined above into the pattern set
// (body of populateLegalizeVectorStoragePatterns; the enclosing function
// signature and the context argument were lost in extraction).
 446 .add<RelaxScalableVectorAllocaAlignment,
 447 LegalizeSVEMaskAllocation<memref::AllocaOp>,
 448 LegalizeSVEMaskAllocation<memref::AllocOp>,
 449 LegalizeSVEMaskTypeCastConversion, LegalizeSVEMaskStoreConversion,
 450 LegalizeSVEMaskLoadConversion, LegalizeTransferRead>(
// The pass driver: marks UnrealizedConversionCastOp dynamically legal (the
// visible lambda body was lost in extraction — presumably legal only when
// tagged by the legalizer) and runs a partial conversion over the operation.
// NOTE(review): the pattern population, error handling, and closing braces
// were lost in extraction; line 478 is the tail of the pass factory
// (createLegalizeVectorStoragePass) and the StringLiteral below is the
// legalizer tag constant, scraped out of place (its ';' is missing in this
// extraction).
455struct LegalizeVectorStorage
 458 void runOnOperation()
override {
 465 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
 466 [](UnrealizedConversionCastOp unrealizedConversion) {
 470 if (
failed(applyPartialConversion(getOperation(),
target, {})))
 478 return std::make_unique<LegalizeVectorStorage>();
constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__")
StringAttr getStringAttr(const Twine &bytes)
NamedAttribute represents a combination of a name and an Attribute value.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Operation is the basic unit of execution within MLIR.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing hooks for the pattern driver to observe IR mutations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements); the result types of the original op and the replacements must match, and the original op is erased.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, providing a callback that populates a diagnostic with the reason for the failure.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
void populateLegalizeVectorStoragePatterns(RewritePatternSet &patterns)
Collect a set of patterns to legalize Arm SVE vector storage.
std::unique_ptr< Pass > createLegalizeVectorStoragePass()
Pass to legalize Arm SVE vector storage.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of the operations they may generate.