MLIR  19.0.0git
LegalizeVectorStorage.cpp
Go to the documentation of this file.
1 //===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 
16 namespace mlir::arm_sve {
17 #define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
18 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
19 } // namespace mlir::arm_sve
20 
21 using namespace mlir;
22 using namespace mlir::arm_sve;
23 
24 // A tag to mark unrealized_conversions produced by this pass. This is used to
25 // detect IR this pass failed to completely legalize, and report an error.
26 // If everything was successfully legalized, no tagged ops will remain after
27 // this pass.
28 constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__");
29 
30 /// Definitions:
31 ///
32 /// [1] svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE
33 /// predicate registers. A full predicate is the smallest quantity that can be
34 /// loaded/stored.
35 ///
36 /// [2] SVE mask = hardware-sized SVE predicate mask, i.e. its trailing
37 /// dimension matches the size of a legal SVE vector size (such as
38 /// vector<[4]xi1>), but is too small to be stored to memory (i.e. smaller than
39 /// a svbool).
40 
41 namespace {
42 
43 /// Checks if a vector type is a SVE mask [2].
44 bool isSVEMaskType(VectorType type) {
45  return type.getRank() > 0 && type.getElementType().isInteger(1) &&
46  type.getScalableDims().back() && type.getShape().back() < 16 &&
47  llvm::isPowerOf2_32(type.getShape().back()) &&
48  !llvm::is_contained(type.getScalableDims().drop_back(), true);
49 }
50 
51 VectorType widenScalableMaskTypeToSvbool(VectorType type) {
52  assert(isSVEMaskType(type));
53  return VectorType::Builder(type).setDim(type.getRank() - 1, 16);
54 }
55 
56 /// A helper for cloning an op and replacing it will a new version, updated by a
57 /// callback.
58 template <typename TOp, typename TLegalizerCallback>
59 void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
60  TLegalizerCallback callback) {
61  // Clone the previous op to preserve any properties/attributes.
62  auto newOp = op.clone();
63  rewriter.insert(newOp);
64  rewriter.replaceOp(op, callback(newOp));
65 }
66 
67 /// A helper for cloning an op and replacing it with a new version, updated by a
68 /// callback, and an unrealized conversion back to the type of the replaced op.
69 template <typename TOp, typename TLegalizerCallback>
70 void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
71  TLegalizerCallback callback) {
72  replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
73  // Mark our `unrealized_conversion_casts` with a pass label.
74  return rewriter.create<UnrealizedConversionCastOp>(
75  op.getLoc(), TypeRange{op.getResult().getType()},
76  ValueRange{callback(newOp)},
78  rewriter.getUnitAttr()));
79  });
80 }
81 
82 /// Extracts the widened SVE memref value (that's legal to store/load) from the
83 /// `unrealized_conversion_cast`s added by this pass.
84 static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) {
85  Operation *definingOp = illegalMemref.getDefiningOp();
86  if (!definingOp || !definingOp->hasAttr(kSVELegalizerTag))
87  return failure();
88  auto unrealizedConversion =
89  llvm::cast<UnrealizedConversionCastOp>(definingOp);
90  return unrealizedConversion.getOperand(0);
91 }
92 
93 /// The default alignment of an alloca in LLVM may request overaligned sizes for
94 /// SVE types, which will fail during stack frame allocation. This rewrite
95 /// explicitly adds a reasonable alignment to allocas of scalable types.
96 struct RelaxScalableVectorAllocaAlignment
97  : public OpRewritePattern<memref::AllocaOp> {
99 
100  LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
101  PatternRewriter &rewriter) const override {
102  auto memrefElementType = allocaOp.getType().getElementType();
103  auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
104  if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
105  return failure();
106 
107  // Set alignment based on the defaults for SVE vectors and predicates.
108  unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
109  rewriter.modifyOpInPlace(allocaOp,
110  [&] { allocaOp.setAlignment(aligment); });
111 
112  return success();
113  }
114 };
115 
116 /// Replaces allocations of SVE predicates smaller than an svbool [1] (_illegal_
117 /// to load/store) with a wider allocation of svbool (_legal_ to load/store)
118 /// followed by a tagged unrealized conversion to the original type.
119 ///
120 /// Example
121 /// ```
122 /// %alloca = memref.alloca() : memref<vector<[4]xi1>>
123 /// ```
124 /// is rewritten into:
125 /// ```
126 /// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
127 /// %alloca = builtin.unrealized_conversion_cast %widened
128 /// : memref<vector<[16]xi1>> to memref<vector<[4]xi1>>
129 /// {__arm_sve_legalize_vector_storage__}
130 /// ```
131 template <typename AllocLikeOp>
132 struct LegalizeSVEMaskAllocation : public OpRewritePattern<AllocLikeOp> {
134 
135  LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
136  PatternRewriter &rewriter) const override {
137  auto vectorType =
138  llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
139 
140  if (!vectorType || !isSVEMaskType(vectorType))
141  return failure();
142 
143  // Replace this alloc-like op of an SVE mask [2] with one of a (storable)
144  // svbool mask [1]. A temporary unrealized_conversion_cast is added to the
145  // old type to allow local rewrites.
146  replaceOpWithUnrealizedConversion(
147  rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
148  newAllocLikeOp.getResult().setType(
149  llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
150  {}, widenScalableMaskTypeToSvbool(vectorType))));
151  return newAllocLikeOp;
152  });
153 
154  return success();
155  }
156 };
157 
158 /// Replaces vector.type_casts of unrealized conversions to SVE predicate memref
159 /// types that are _illegal_ to load/store from (!= svbool [1]), with type casts
160 /// of memref types that are _legal_ to load/store, followed by unrealized
161 /// conversions.
162 ///
163 /// Example:
164 /// ```
165 /// %alloca = builtin.unrealized_conversion_cast %widened
166 /// : memref<vector<[16]xi1>> to memref<vector<[8]xi1>>
167 /// {__arm_sve_legalize_vector_storage__}
168 /// %cast = vector.type_cast %alloca
169 /// : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
170 /// ```
171 /// is rewritten into:
172 /// ```
173 /// %widened_cast = vector.type_cast %widened
174 /// : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
175 /// %cast = builtin.unrealized_conversion_cast %widened_cast
176 /// : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>>
177 /// {__arm_sve_legalize_vector_storage__}
178 /// ```
179 struct LegalizeSVEMaskTypeCastConversion
180  : public OpRewritePattern<vector::TypeCastOp> {
182 
183  LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
184  PatternRewriter &rewriter) const override {
185  auto resultType = typeCastOp.getResultMemRefType();
186  auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
187 
188  if (!vectorType || !isSVEMaskType(vectorType))
189  return failure();
190 
191  auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref());
192  if (failed(legalMemref))
193  return failure();
194 
195  // Replace this vector.type_cast with one of a (storable) svbool mask [1].
196  replaceOpWithUnrealizedConversion(
197  rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
198  newTypeCast.setOperand(*legalMemref);
199  newTypeCast.getResult().setType(
200  llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
201  {}, widenScalableMaskTypeToSvbool(vectorType))));
202  return newTypeCast;
203  });
204 
205  return success();
206  }
207 };
208 
209 /// Replaces stores to unrealized conversions to SVE predicate memref types that
210 /// are _illegal_ to load/store from (!= svbool [1]), with
211 /// `arm_sve.convert_to_svbool`s followed by (legal) wider stores.
212 ///
213 /// Example:
214 /// ```
215 /// memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
216 /// ```
217 /// is rewritten into:
218 /// ```
219 /// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1>
220 /// memref.store %svbool, %widened[] : memref<vector<[16]xi1>>
221 /// ```
222 struct LegalizeSVEMaskStoreConversion
223  : public OpRewritePattern<memref::StoreOp> {
225 
226  LogicalResult matchAndRewrite(memref::StoreOp storeOp,
227  PatternRewriter &rewriter) const override {
228  auto loc = storeOp.getLoc();
229 
230  Value valueToStore = storeOp.getValueToStore();
231  auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
232 
233  if (!vectorType || !isSVEMaskType(vectorType))
234  return failure();
235 
236  auto legalMemref = getSVELegalizedMemref(storeOp.getMemref());
237  if (failed(legalMemref))
238  return failure();
239 
240  auto legalMaskType = widenScalableMaskTypeToSvbool(
241  llvm::cast<VectorType>(valueToStore.getType()));
242  auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>(
243  loc, legalMaskType, valueToStore);
244  // Replace this store with a conversion to a storable svbool mask [1],
245  // followed by a wider store.
246  replaceOpWithLegalizedOp(rewriter, storeOp,
247  [&](memref::StoreOp newStoreOp) {
248  newStoreOp.setOperand(0, convertToSvbool);
249  newStoreOp.setOperand(1, *legalMemref);
250  return newStoreOp;
251  });
252 
253  return success();
254  }
255 };
256 
257 /// Replaces loads from unrealized conversions to SVE predicate memref types
258 /// that are _illegal_ to load/store from (!= svbool [1]), types with (legal)
259 /// wider loads, followed by `arm_sve.convert_from_svbool`s.
260 ///
261 /// Example:
262 /// ```
263 /// %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
264 /// ```
265 /// is rewritten into:
266 /// ```
267 /// %svbool = memref.load %widened[] : memref<vector<[16]xi1>>
268 /// %reload = arm_sve.convert_from_svbool %reload : vector<[4]xi1>
269 /// ```
270 struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
272 
273  LogicalResult matchAndRewrite(memref::LoadOp loadOp,
274  PatternRewriter &rewriter) const override {
275  auto loc = loadOp.getLoc();
276 
277  Value loadedMask = loadOp.getResult();
278  auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.getType());
279 
280  if (!vectorType || !isSVEMaskType(vectorType))
281  return failure();
282 
283  auto legalMemref = getSVELegalizedMemref(loadOp.getMemref());
284  if (failed(legalMemref))
285  return failure();
286 
287  auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
288  // Replace this load with a legal load of an svbool type, followed by a
289  // conversion back to the original type.
290  replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
291  newLoadOp.setMemRef(*legalMemref);
292  newLoadOp.getResult().setType(legalMaskType);
293  return rewriter.create<arm_sve::ConvertFromSvboolOp>(
294  loc, loadedMask.getType(), newLoadOp);
295  });
296 
297  return success();
298  }
299 };
300 
301 } // namespace
302 
304  RewritePatternSet &patterns) {
305  patterns.add<RelaxScalableVectorAllocaAlignment,
306  LegalizeSVEMaskAllocation<memref::AllocaOp>,
307  LegalizeSVEMaskAllocation<memref::AllocOp>,
308  LegalizeSVEMaskTypeCastConversion,
309  LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
310  patterns.getContext());
311 }
312 
313 namespace {
314 struct LegalizeVectorStorage
315  : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
316 
317  void runOnOperation() override {
318  RewritePatternSet patterns(&getContext());
320  if (failed(applyPatternsAndFoldGreedily(getOperation(),
321  std::move(patterns)))) {
322  signalPassFailure();
323  }
324  ConversionTarget target(getContext());
325  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
326  [](UnrealizedConversionCastOp unrealizedConversion) {
327  return !unrealizedConversion->hasAttr(kSVELegalizerTag);
328  });
329  // This detects if we failed to completely legalize the IR.
330  if (failed(applyPartialConversion(getOperation(), target, {})))
331  signalPassFailure();
332  }
333 };
334 
335 } // namespace
336 
338  return std::make_unique<LegalizeVectorStorage>();
339 }
static MLIRContext * getContext(OpFoldResult val)
constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__")
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
This class describes a specific conversion target.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:428
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Definition: BuiltinTypes.h:339
void populateLegalizeVectorStoragePatterns(RewritePatternSet &patterns)
Collect a set of patterns to legalize Arm SVE vector storage.
std::unique_ptr< Pass > createLegalizeVectorStoragePass()
Pass to legalize Arm SVE vector storage.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362