17 #define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
18 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
44 bool isSVEMaskType(VectorType type) {
45 return type.getRank() > 0 && type.getElementType().isInteger(1) &&
46 type.getScalableDims().back() && type.getShape().back() < 16 &&
47 llvm::isPowerOf2_32(type.getShape().back()) &&
48 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
51 VectorType widenScalableMaskTypeToSvbool(VectorType type) {
52 assert(isSVEMaskType(type));
58 template <
typename TOp,
typename TLegalizerCallback>
60 TLegalizerCallback callback) {
62 auto newOp = op.
clone();
69 template <
typename TOp,
typename TLegalizerCallback>
70 void replaceOpWithUnrealizedConversion(
PatternRewriter &rewriter, TOp op,
71 TLegalizerCallback callback) {
72 replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
74 return rewriter.
create<UnrealizedConversionCastOp>(
88 auto unrealizedConversion =
89 llvm::cast<UnrealizedConversionCastOp>(definingOp);
90 return unrealizedConversion.getOperand(0);
96 struct RelaxScalableVectorAllocaAlignment
102 auto memrefElementType = allocaOp.getType().getElementType();
103 auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
104 if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
108 unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
110 [&] { allocaOp.setAlignment(aligment); });
131 template <
typename AllocLikeOp>
138 llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
140 if (!vectorType || !isSVEMaskType(vectorType))
146 replaceOpWithUnrealizedConversion(
147 rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
148 newAllocLikeOp.getResult().setType(
149 llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
150 {}, widenScalableMaskTypeToSvbool(vectorType))));
151 return newAllocLikeOp;
179 struct LegalizeSVEMaskTypeCastConversion
185 auto resultType = typeCastOp.getResultMemRefType();
186 auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
188 if (!vectorType || !isSVEMaskType(vectorType))
191 auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref());
196 replaceOpWithUnrealizedConversion(
197 rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
198 newTypeCast.setOperand(*legalMemref);
199 newTypeCast.getResult().setType(
200 llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
201 {}, widenScalableMaskTypeToSvbool(vectorType))));
222 struct LegalizeSVEMaskStoreConversion
228 auto loc = storeOp.getLoc();
230 Value valueToStore = storeOp.getValueToStore();
231 auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.
getType());
233 if (!vectorType || !isSVEMaskType(vectorType))
236 auto legalMemref = getSVELegalizedMemref(storeOp.getMemref());
240 auto legalMaskType = widenScalableMaskTypeToSvbool(
241 llvm::cast<VectorType>(valueToStore.
getType()));
242 auto convertToSvbool = rewriter.
create<arm_sve::ConvertToSvboolOp>(
243 loc, legalMaskType, valueToStore);
246 replaceOpWithLegalizedOp(rewriter, storeOp,
247 [&](memref::StoreOp newStoreOp) {
248 newStoreOp.setOperand(0, convertToSvbool);
249 newStoreOp.setOperand(1, *legalMemref);
270 struct LegalizeSVEMaskLoadConversion :
public OpRewritePattern<memref::LoadOp> {
275 auto loc = loadOp.getLoc();
277 Value loadedMask = loadOp.getResult();
278 auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.
getType());
280 if (!vectorType || !isSVEMaskType(vectorType))
283 auto legalMemref = getSVELegalizedMemref(loadOp.getMemref());
287 auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
290 replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
291 newLoadOp.setMemRef(*legalMemref);
292 newLoadOp.getResult().setType(legalMaskType);
293 return rewriter.
create<arm_sve::ConvertFromSvboolOp>(
294 loc, loadedMask.
getType(), newLoadOp);
305 patterns.
add<RelaxScalableVectorAllocaAlignment,
306 LegalizeSVEMaskAllocation<memref::AllocaOp>,
307 LegalizeSVEMaskAllocation<memref::AllocOp>,
308 LegalizeSVEMaskTypeCastConversion,
309 LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
314 struct LegalizeVectorStorage
315 :
public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
317 void runOnOperation()
override {
321 std::move(patterns)))) {
325 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
326 [](UnrealizedConversionCastOp unrealizedConversion) {
338 return std::make_unique<LegalizeVectorStorage>();
static MLIRContext * getContext(OpFoldResult val)
constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__")
StringAttr getStringAttr(const Twine &bytes)
This class describes a specific conversion target.
This class provides support for representing a failure result, or a valid value of type T.
NamedAttribute represents a combination of a name and an Attribute value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Operation is the basic unit of execution within MLIR.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
void populateLegalizeVectorStoragePatterns(RewritePatternSet &patterns)
Collect a set of patterns to legalize Arm SVE vector storage.
std::unique_ptr< Pass > createLegalizeVectorStoragePass()
Pass to legalize Arm SVE vector storage.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...