MLIR 22.0.0git
LegalizeVectorStorage.cpp
Go to the documentation of this file.
1//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15
16namespace mlir::arm_sve {
17#define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
18#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
19} // namespace mlir::arm_sve
21using namespace mlir;
22using namespace mlir::arm_sve;
23
24// A tag to mark unrealized_conversions produced by this pass. This is used to
25// detect IR this pass failed to completely legalize, and report an error.
26// If everything was successfully legalized, no tagged ops will remain after
27// this pass.
28constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__");
30/// Definitions:
31///
32/// [1] svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE
33/// predicate registers. A full predicate is the smallest quantity that can be
34/// loaded/stored.
35///
36/// [2] SVE mask = hardware-sized SVE predicate mask, i.e. its trailing
37/// dimension matches the size of a legal SVE vector size (such as
38/// vector<[4]xi1>), but is too small to be stored to memory (i.e., smaller
39/// than an svbool).
41namespace {
42
43/// Checks if a vector type is a SVE mask [2].
44bool isSVEMaskType(VectorType type) {
45 return type.getRank() > 0 && type.getElementType().isInteger(1) &&
46 type.getScalableDims().back() && type.getShape().back() < 16 &&
47 llvm::isPowerOf2_32(type.getShape().back()) &&
48 !llvm::is_contained(type.getScalableDims().drop_back(), true);
49}
50
51VectorType widenScalableMaskTypeToSvbool(VectorType type) {
52 assert(isSVEMaskType(type));
53 return VectorType::Builder(type).setDim(type.getRank() - 1, 16);
54}
55
56/// A helper for cloning an op and replacing it will a new version, updated by a
57/// callback.
58template <typename TOp, typename TLegalizerCallback>
59void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
60 TLegalizerCallback callback) {
61 // Clone the previous op to preserve any properties/attributes.
62 auto newOp = op.clone();
63 rewriter.insert(newOp);
64 rewriter.replaceOp(op, callback(newOp));
65}
66
67/// A helper for cloning an op and replacing it with a new version, updated by a
68/// callback, and an unrealized conversion back to the type of the replaced op.
69template <typename TOp, typename TLegalizerCallback>
70void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
71 TLegalizerCallback callback) {
72 replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) {
73 // Mark our `unrealized_conversion_casts` with a pass label.
74 return UnrealizedConversionCastOp::create(
75 rewriter, op.getLoc(), TypeRange{op.getResult().getType()},
76 ValueRange{callback(newOp)},
78 rewriter.getUnitAttr()));
79 });
80}
81
82/// Extracts the widened SVE memref value (that's legal to store/load) from the
83/// `unrealized_conversion_cast`s added by this pass.
84static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) {
85 Operation *definingOp = illegalMemref.getDefiningOp();
86 if (!definingOp || !definingOp->hasAttr(kSVELegalizerTag))
87 return failure();
88 auto unrealizedConversion =
89 llvm::cast<UnrealizedConversionCastOp>(definingOp);
90 return unrealizedConversion.getOperand(0);
91}
92
93/// The default alignment of an alloca in LLVM may request overaligned sizes for
94/// SVE types, which will fail during stack frame allocation. This rewrite
95/// explicitly adds a reasonable alignment to allocas of scalable types.
96struct RelaxScalableVectorAllocaAlignment
97 : public OpRewritePattern<memref::AllocaOp> {
99
100 LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
101 PatternRewriter &rewriter) const override {
102 auto memrefElementType = allocaOp.getType().getElementType();
103 auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
104 if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
105 return failure();
106
107 // Set alignment based on the defaults for SVE vectors and predicates.
108 unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
109 rewriter.modifyOpInPlace(allocaOp,
110 [&] { allocaOp.setAlignment(aligment); });
111
112 return success();
113 }
114};
115
116/// Replaces allocations of SVE predicates smaller than an svbool [1] (_illegal_
117/// to load/store) with a wider allocation of svbool (_legal_ to load/store)
118/// followed by a tagged unrealized conversion to the original type.
119///
120/// Example
121/// ```
122/// %alloca = memref.alloca() : memref<vector<[4]xi1>>
123/// ```
124/// is rewritten into:
125/// ```
126/// %widened = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
127/// %alloca = builtin.unrealized_conversion_cast %widened
128/// : memref<vector<[16]xi1>> to memref<vector<[4]xi1>>
129/// {__arm_sve_legalize_vector_storage__}
130/// ```
131template <typename AllocLikeOp>
132struct LegalizeSVEMaskAllocation : public OpRewritePattern<AllocLikeOp> {
133 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
134
135 LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
136 PatternRewriter &rewriter) const override {
137 auto vectorType =
138 llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
139
140 if (!vectorType || !isSVEMaskType(vectorType))
141 return failure();
142
143 // Replace this alloc-like op of an SVE mask [2] with one of a (storable)
144 // svbool mask [1]. A temporary unrealized_conversion_cast is added to the
145 // old type to allow local rewrites.
146 replaceOpWithUnrealizedConversion(
147 rewriter, allocLikeOp, [&](AllocLikeOp newAllocLikeOp) {
148 newAllocLikeOp.getResult().setType(
149 llvm::cast<MemRefType>(newAllocLikeOp.getType().cloneWith(
150 {}, widenScalableMaskTypeToSvbool(vectorType))));
151 return newAllocLikeOp;
152 });
153
154 return success();
155 }
156};
157
158/// Replaces vector.type_casts of unrealized conversions to SVE predicate memref
159/// types that are _illegal_ to load/store from (!= svbool [1]), with type casts
160/// of memref types that are _legal_ to load/store, followed by unrealized
161/// conversions.
162///
163/// Example:
164/// ```
165/// %alloca = builtin.unrealized_conversion_cast %widened
166/// : memref<vector<[16]xi1>> to memref<vector<[8]xi1>>
167/// {__arm_sve_legalize_vector_storage__}
168/// %cast = vector.type_cast %alloca
169/// : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
170/// ```
171/// is rewritten into:
172/// ```
173/// %widened_cast = vector.type_cast %widened
174/// : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
175/// %cast = builtin.unrealized_conversion_cast %widened_cast
176/// : memref<3xvector<[16]xi1>> to memref<3xvector<[8]xi1>>
177/// {__arm_sve_legalize_vector_storage__}
178/// ```
179struct LegalizeSVEMaskTypeCastConversion
180 : public OpRewritePattern<vector::TypeCastOp> {
182
183 LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
184 PatternRewriter &rewriter) const override {
185 auto resultType = typeCastOp.getResultMemRefType();
186 auto vectorType = llvm::dyn_cast<VectorType>(resultType.getElementType());
187
188 if (!vectorType || !isSVEMaskType(vectorType))
189 return failure();
190
191 auto legalMemref = getSVELegalizedMemref(typeCastOp.getMemref());
192 if (failed(legalMemref))
193 return failure();
194
195 // Replace this vector.type_cast with one of a (storable) svbool mask [1].
196 replaceOpWithUnrealizedConversion(
197 rewriter, typeCastOp, [&](vector::TypeCastOp newTypeCast) {
198 newTypeCast.setOperand(*legalMemref);
199 newTypeCast.getResult().setType(
200 llvm::cast<MemRefType>(newTypeCast.getType().cloneWith(
201 {}, widenScalableMaskTypeToSvbool(vectorType))));
202 return newTypeCast;
203 });
204
205 return success();
206 }
207};
208
209/// Replaces stores to unrealized conversions to SVE predicate memref types that
210/// are _illegal_ to load/store from (!= svbool [1]), with
211/// `arm_sve.convert_to_svbool`s followed by (legal) wider stores.
212///
213/// Example:
214/// ```
215/// memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
216/// ```
217/// is rewritten into:
218/// ```
219/// %svbool = arm_sve.convert_to_svbool %mask : vector<[8]xi1>
220/// memref.store %svbool, %widened[] : memref<vector<[16]xi1>>
221/// ```
222struct LegalizeSVEMaskStoreConversion
223 : public OpRewritePattern<memref::StoreOp> {
225
226 LogicalResult matchAndRewrite(memref::StoreOp storeOp,
227 PatternRewriter &rewriter) const override {
228 auto loc = storeOp.getLoc();
229
230 Value valueToStore = storeOp.getValueToStore();
231 auto vectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
232
233 if (!vectorType || !isSVEMaskType(vectorType))
234 return failure();
235
236 auto legalMemref = getSVELegalizedMemref(storeOp.getMemref());
237 if (failed(legalMemref))
238 return failure();
239
240 auto legalMaskType = widenScalableMaskTypeToSvbool(
241 llvm::cast<VectorType>(valueToStore.getType()));
242 auto convertToSvbool = arm_sve::ConvertToSvboolOp::create(
243 rewriter, loc, legalMaskType, valueToStore);
244 // Replace this store with a conversion to a storable svbool mask [1],
245 // followed by a wider store.
246 replaceOpWithLegalizedOp(rewriter, storeOp,
247 [&](memref::StoreOp newStoreOp) {
248 newStoreOp.setOperand(0, convertToSvbool);
249 newStoreOp.setOperand(1, *legalMemref);
250 return newStoreOp;
251 });
252
253 return success();
254 }
255};
256
257/// Replaces loads from unrealized conversions to SVE predicate memref types
258/// that are _illegal_ to load/store from (!= svbool [1]), types with (legal)
259/// wider loads, followed by `arm_sve.convert_from_svbool`s.
260///
261/// Example:
262/// ```
263/// %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
264/// ```
265/// is rewritten into:
266/// ```
267/// %svbool = memref.load %widened[] : memref<vector<[16]xi1>>
268/// %reload = arm_sve.convert_from_svbool %reload : vector<[4]xi1>
269/// ```
270struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
272
273 LogicalResult matchAndRewrite(memref::LoadOp loadOp,
274 PatternRewriter &rewriter) const override {
275 auto loc = loadOp.getLoc();
276
277 Value loadedMask = loadOp.getResult();
278 auto vectorType = llvm::dyn_cast<VectorType>(loadedMask.getType());
279
280 if (!vectorType || !isSVEMaskType(vectorType))
281 return failure();
282
283 auto legalMemref = getSVELegalizedMemref(loadOp.getMemref());
284 if (failed(legalMemref))
285 return failure();
286
287 auto legalMaskType = widenScalableMaskTypeToSvbool(vectorType);
288 // Replace this load with a legal load of an svbool type, followed by a
289 // conversion back to the original type.
290 replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) {
291 newLoadOp.setMemRef(*legalMemref);
292 newLoadOp.getResult().setType(legalMaskType);
293 return arm_sve::ConvertFromSvboolOp::create(
294 rewriter, loc, loadedMask.getType(), newLoadOp);
295 });
296
297 return success();
298 }
299};
300
301/// Transforms a `transfer_read` operation so it reads vector of a type that
302/// can be mapped to an LLVM type ("LLVM-legal" type). This is done by
303/// collapsing trailing dimensions so we obtain a vector type with a single
304/// scalable dimension in the rightmost position.
305///
306/// Example:
307/// ```
308/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
309/// {in_bounds = [false, true, true, true]}
310/// : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
311/// ```
312/// is rewritten to
313/// ```
314/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
315/// : memref<?x?x2x8xi8> into memref<?x?xi8>
316/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
317/// {in_bounds = [false, true]}
318/// : memref<?x?xi8>, vector<2x[64]xi8>
319/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
320/// ```
321struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
323
324 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
325 PatternRewriter &rewriter) const override {
326
327 // Do not try to transform masked reads. For example, if we have a transfer
328 // to a `vector<[4]x4xi8>` we could have a mask like
329 // 1 1 1 0
330 // 1 1 1 0
331 // 1 1 1 0
332 // 0 0 0 0
333 // Flattening this mask would look like
334 // 1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
335 // and we have not yet figured out an efficient way to build such a mask,
336 // neither from the mask operand, nor from the original `vector.create_mask`
337 // operation (if visible at all).
338 if (readOp.isMasked() || readOp.getMask())
339 return rewriter.notifyMatchFailure(readOp,
340 "masked transfers not-supported");
341
342 // General permutation maps are not supported. The issue is with transpose,
343 // broadcast, and other forms of non-identify mapping in the minor
344 // dimensions which is impossible to represent after collapsing (at least
345 // because the resulting "collapsed" maps would have smaller number of
346 // dimension indices).
347 // TODO: We have not had yet the need for it, but some forms of permutation
348 // maps with identity in the minor dimensions voukld be supported, for
349 // example `(i, j, k, p) -> (j, i, k, p)` where we need to collapse only `k`
350 // and `p`.
351 if (!readOp.getPermutationMap().isMinorIdentity())
352 return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
353
354 // We handle transfers of vectors with rank >= 2 and a single scalable
355 // dimension. This transformation aims to transform an LLVM-illegal type
356 // into an LLVM-legal type and one dimensional vectors are already
357 // LLVM-legal, even if scalable. A value of a vector type with more than one
358 // scalable dimension is impossible to represent using a vector type with no
359 // scalable dimensions or a single one. For example a `vector<[4]x[4]xi8>`
360 // would have `4 * 4 * vscale * vscale` elements and this quantity is
361 // impossible to represent as `N` or `N * vscale` (where `N` is a constant).
362 VectorType origVT = readOp.getVectorType();
363 ArrayRef<bool> origScalableDims = origVT.getScalableDims();
364 const int64_t origVRank = origVT.getRank();
365 if (origVRank < 2 || origVT.getNumScalableDims() != 1)
366 return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
367
368 // Number of trailing dimensions to collapse, including the scalable
369 // dimension. Nothing to do if the single scalable dimension is already the
370 // last one.
371 const int64_t numCollapseDims = std::distance(
372 llvm::find(origScalableDims, true), origScalableDims.end());
373 if (numCollapseDims < 2)
374 return rewriter.notifyMatchFailure(readOp,
375 "scalable dimension is trailing");
376
377 // We want a simple memref (not a tensor) with contiguous elements for at
378 // least all the trailing dimensions up to and including the scalable one.
379 auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
380 if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
381 return rewriter.notifyMatchFailure(
382 readOp, "non-contiguous memref dimensions to collapse");
383
384 // The dimensions to collapse (excluding the scalable one) of the vector and
385 // the memref must match. A dynamic memref dimension is considered
386 // non-matching. The transfers from the dimensions to collapse must be
387 // in-bounds (it follows the corresponding indices would be zero). This
388 // guarantees that the operation transfers a contiguous block
389 // and no padding is necessary.
390 if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
391 origVT.getShape().take_back(numCollapseDims - 1)))
392 return rewriter.notifyMatchFailure(
393 readOp, "memref and vector dimensions do not match");
394
395 SmallVector<bool> origInBounds = readOp.getInBoundsValues();
396 if (!llvm::all_of(
397 ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
398 [](bool v) { return v; }))
399 return rewriter.notifyMatchFailure(
400 readOp, "out-of-bounds transfer from a dimension to collapse");
401
402 // Collapse the trailing dimensions of the memref.
403 SmallVector<ReassociationIndices> reassoc;
404 for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
405 reassoc.push_back({i});
406 for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
407 ++i)
408 reassoc.back().push_back(i);
409 if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
410 return failure();
411 Value collapsedMem = memref::CollapseShapeOp::create(
412 rewriter, readOp.getLoc(), readOp.getBase(), reassoc);
413
414 // Get a vector type with collapsed trailing dimensions.
415 SmallVector<int64_t> shape(origVT.getShape());
416 for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
417 shape[origVRank - numCollapseDims] *= shape[i];
418 shape.pop_back_n(numCollapseDims - 1);
419 auto collapsedVT =
420 VectorType::get(shape, origVT.getElementType(),
421 origScalableDims.drop_back(numCollapseDims - 1));
422
423 // Drop the extra (zero) indices.
424 auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
425
426 // Create the new `transfer_read`.
427 auto newReadOp = vector::TransferReadOp::create(
428 rewriter, readOp.getLoc(), collapsedVT, collapsedMem, indices,
429 readOp.getPadding(),
430 ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
431
432 // Cast back to the original vector type.
433 auto toOrigShape = vector::ShapeCastOp::create(rewriter, readOp.getLoc(),
434 origVT, newReadOp);
435
436 rewriter.replaceOp(readOp, toOrigShape);
437 return success();
438 }
439};
440
441} // namespace
442
446 .add<RelaxScalableVectorAllocaAlignment,
447 LegalizeSVEMaskAllocation<memref::AllocaOp>,
448 LegalizeSVEMaskAllocation<memref::AllocOp>,
449 LegalizeSVEMaskTypeCastConversion, LegalizeSVEMaskStoreConversion,
450 LegalizeSVEMaskLoadConversion, LegalizeTransferRead>(
451 patterns.getContext());
452}
453
454namespace {
455struct LegalizeVectorStorage
456 : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
457
458 void runOnOperation() override {
461 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
462 signalPassFailure();
463 }
465 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
466 [](UnrealizedConversionCastOp unrealizedConversion) {
467 return !unrealizedConversion->hasAttr(kSVELegalizerTag);
468 });
469 // This detects if we failed to completely legalize the IR.
470 if (failed(applyPartialConversion(getOperation(), target, {})))
471 signalPassFailure();
472 }
473};
474
475} // namespace
476
478 return std::make_unique<LegalizeVectorStorage>();
479}
return success()
constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__")
b getContext())
UnitAttr getUnitAttr()
Definition Builders.cpp:98
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:421
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
This is a builder type that keeps local references to arguments.
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
void populateLegalizeVectorStoragePatterns(RewritePatternSet &patterns)
Collect a set of patterns to legalize Arm SVE vector storage.
std::unique_ptr< Pass > createLegalizeVectorStoragePass()
Pass to legalize Arm SVE vector storage.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...