//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example, for an [8]x[8] vector with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type.
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (which must be of
/// equal length to `scalableOffsets`). For example, in the 2D case this would
/// return:
///   { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  auto vscale = vector::VectorScaleOp::create(builder, loc);
  return llvm::map_to_vector(
      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
        auto [index, base] = pair;
        auto offset = arith::MulIOp::create(
            builder, loc, arith::ConstantIndexOp::create(builder, loc, base),
            vscale);
        return arith::AddIOp::create(builder, loc, index, offset);
      });
}

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  return addConstantScalableOffset(builder, loc, indices,
                                   {smeTile.row, smeTile.col});
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) {
  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract the position where this
  // sub-tile starts from the mask operands to get the parameters for this
  // sub-tile.
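  // E.g. (illustrative) for a 2D create_mask with operands [%a, %b] and a
  // sub-tile starting at position (4, 4) (in units of vscale), the sub-tile
  // mask operands become [%a - 4 * vscale, %b - 4 * vscale].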
  auto smeTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto smeTileCreateMask = vector::CreateMaskOp::create(
      builder, loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
  return smeTileCreateMask.getResult();
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  return llvm::map_range(
      StaticTileOffsetRange(
          type.getShape(),
          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
        if (transposeIndices)
          std::swap(row, col);
        return SMESubTile{row, col, smeTileType};
      });
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
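/// For example (illustrative; for f32 the SME tile is vector<[4]x[4]xf32>, so
/// the minimum tile slice length is 4): a vector<[8]x[8]xf32> contains
/// (8 * 8) / (4 * 4) = 4 SME tiles.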
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  auto elementType = type.getElementType();
  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
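///
/// Illustrative sketch (assuming the f32 SME tile type vector<[4]x[4]xf32>):
/// an `arith.constant dense<1.0> : vector<[8]x[8]xf32>` is replaced by a
/// single `arith.constant dense<1.0> : vector<[4]x[4]xf32>`, which is then
/// reused for each of the four SME sub-tiles of the original vector.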
struct LegalizeArithConstantOpsByDecomposition
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    auto tileSplat = arith::ConstantOp::create(
        rewriter, constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    SmallVector<Value> repl(tileCount, tileSplat);
    rewriter.replaceOpWithMultiple(constantOp, {repl});

    return success();
  }
};

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
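///
/// Illustrative sketch (assuming f32 SME tiles of vector<[4]x[4]xf32>): a
/// `vector.outerproduct` producing vector<[8]x[8]xf32> is decomposed into
/// four [4]x[4] outer products, where the sub-tile at position (row, col)
/// takes its LHS slice from offset `row` of the original LHS and its RHS
/// slice from offset `col` of the original RHS (via vector.scalable.extract).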
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::OuterProductOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
      rewriter.setInsertionPoint(rootOp);
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto lhs = vector::ScalableExtractOp::create(
          rewriter, loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = vector::ScalableExtractOp::create(
          rewriter, loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = vector::OuterProductOp::create(
          rewriter, loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      auto *maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
    return success();
  }
};

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::MaskOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
            maskOp.getMaskableOp())) {
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
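///
/// Illustrative sketch (assuming f32 SME tiles of vector<[4]x[4]xf32>): a
/// vector.transfer_read of vector<[8]x[8]xf32> from %src[%a, %b] becomes four
/// reads of vector<[4]x[4]xf32>, with the indices of each read offset by the
/// sub-tile's (row, col) position scaled by vscale.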
struct LegalizeTransferReadOpsByDecomposition
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> resultSMETiles;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeRead = vector::TransferReadOp::create(
          rewriter, loc, smeTileType, readOp.getBase(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr());
      resultSMETiles.push_back(smeRead);
    }

    rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
    return success();
  }
};

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
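///
/// Illustrative sketch: the write counterpart of the read decomposition
/// above; a vector.transfer_write of vector<[8]x[8]xf32> becomes four
/// tile-sized writes at sub-tile offsets, chained through the destination
/// value when the write has tensor semantics.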
struct LegalizeTransferWriteOpsByDecomposition
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = writeOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto inputSMETiles = adaptor.getValueToStore();

    Value destTensorOrMemref = writeOp.getBase();
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = vector::TransferWriteOp::create(
          rewriter, loc, inputSMETiles[index], destTensorOrMemref,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition since at this level we know each tile write
/// is disjoint, but that information is lost after decomposition (without
/// analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx]      ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                |
///   %upper_slice = vector.extract %upper_tile[%slice_idx]      |- Store upper
///     : vector<[8]xi16> from vector<[8]x[8]xi16>               |  tile
///   vector.transfer_write %upper_slice,                        |
///     %dest[%slice_idx + %y, %x], %upper_slice_mask            |
///     : vector<[8]xi16>, memref<?x?xi16>                      ─┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx] |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]      |- Store lower
///     : vector<[8]xi16> from vector<[8]x[8]xi16>               |  tile
///   vector.transfer_write %lower_slice,                        |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask      |
///     : vector<[8]xi16>, memref<?x?xi16>                      ─┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
    auto storeLoop =
        scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getValueToStore();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          arith::AddIOp::create(rewriter, loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = arith::AddIOp::create(rewriter, loc, sliceIndex,
                                            writeOp.getIndices()[0]);
      auto storeCol = arith::AddIOp::create(rewriter, loc, tileCol,
                                            writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = vector::ExtractOp::create(rewriter, loc, mask,
                                              OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = vector::ScalableExtractOp::create(
              rewriter, loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          vector::ExtractOp::create(rewriter, loc, tile, tileSliceIndex);
      vector::TransferWriteOp::create(
          rewriter, loc, slice, writeOp.getBase(),
          ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
///  %subMask = vector.extract %mask[2]
///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
///  ```
///
///  AFTER:
///  ```mlir
///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
///  ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    auto numScalable = extractedMaskType.getNumScalableDims();
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto extractionInTrueRegion = arith::CmpIOp::create(
        rewriter, loc, rewriter.getI1Type(), arith::CmpIPredicate::slt,
        extractionIndex, frontMaskDim);
    auto newMaskFrontDim =
        arith::SelectOp::create(rewriter, loc, extractionInTrueRegion,
                                createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// A vector type where no fixed dimension comes after a scalable dimension.
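/// E.g. vector<2x[4]xf32> is legal (the fixed dim precedes the scalable dim),
/// whereas vector<[4]x2xf32> is not (a fixed dim trails a scalable dim).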
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is a metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %illegalRead = vector.transfer_read %memref[%a, %b]
///                  : memref<?x?xf32>, vector<[8]x4xf32>
///  %legalType = vector.transpose %illegalRead, [1, 0]
///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %legalType = vector.transfer_read %transpose[%c0, %c0]
///                  : memref<?x?xf32>, vector<4x[8]xf32>
///  ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  static Value getExtensionSource(Operation *op) {
    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
      return op->getOperand(0);
    return {};
  }

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(transposeOp,
                                         kMatchFailureNotIllegalToLegal);

    // Look through extend for transfer_read.
    Value maybeRead = transposeOp.getVector();
    auto *transposeSourceOp = maybeRead.getDefiningOp();
    Operation *extendOp = nullptr;
    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
      maybeRead = extendSource;
      extendOp = transposeSourceOp;
    }

    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
    if (!illegalRead)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "expected source to be (possibly extended) transfer_read");

    if (!illegalRead.getPermutationMap().isIdentity())
      return rewriter.notifyMatchFailure(
          illegalRead, "expected read to have identity permutation map");

    auto loc = transposeOp.getLoc();
    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);

    // Create a subview that matches the size of the illegal read vector type.
    auto readType = illegalRead.getVectorType();
    auto readSizes = llvm::map_to_vector(
        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
        [&](auto dim) -> Value {
          auto [size, isScalable] = dim;
          auto dimSize = arith::ConstantIndexOp::create(rewriter, loc, size);
          if (!isScalable)
            return dimSize;
          auto vscale = vector::VectorScaleOp::create(rewriter, loc);
          return arith::MulIOp::create(rewriter, loc, vscale, dimSize);
        });
    SmallVector<Value> strides(readType.getRank(), Value(one));
    auto readSubview =
        memref::SubViewOp::create(rewriter, loc, illegalRead.getBase(),
                                  illegalRead.getIndices(), readSizes, strides);

    // Apply the transpose to all values/attributes of the transfer_read:
    // - The mask
    Value mask = illegalRead.getMask();
    if (mask) {
      // Note: The transpose for the mask should fold into the
      // vector.create_mask/constant_mask op, which will then become legal.
      mask = vector::TransposeOp::create(rewriter, loc, mask,
                                         transposeOp.getPermutation());
    }
    // - The source memref
    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
        transposeOp.getPermutation(), getContext());
    auto transposedSubview = memref::TransposeOp::create(
        rewriter, loc, readSubview, AffineMapAttr::get(transposeMap));
    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
    // - The `in_bounds` attribute
    if (inBoundsAttr) {
      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                            inBoundsAttr.end());
      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
    }

    VectorType legalReadType = resultType.clone(readType.getElementType());
    // Note: The indices are all zero as the subview is already offset.
    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
    auto legalRead = vector::TransferReadOp::create(
        rewriter, loc, legalReadType, transposedSubview, readIndices,
        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
        inBoundsAttr);

    // Replace the transpose with the new read, extending the result if
    // necessary.
    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
      if (extendOp)
        return rewriter.create(loc, extendOp->getName().getIdentifier(),
                               Value(legalRead), resultType);
      return legalRead;
    }());

    return success();
  }
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %transpose = vector.transpose %vec, [1, 0]
///     : vector<2x[4]xf32> to vector<[4]x2xf32>
///  vector.transfer_write %transpose, %dest[%y, %x]
///     : vector<[4]x2xf32>, memref<?x?xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
///  %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
///  %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
///  %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
///  %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
///  %c4_vscale = arith.muli %vscale, %c4 : index
///  %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
///  vector.transfer_write %4, %dest[%y, %x], %mask
///    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///    : vector<[4]x[4]xf32>, memref<?x?xf32>
///  ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!isSupportedMaskOp(writeOp.getMask()))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();

    if (resultType.getRank() != 2)
      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");

    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(
          transposeOp, "not illegal/unsupported SVE transpose");

    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

    if (sourceType.getDimSize(0) <= 1 ||
        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto transposeMap = AffineMapAttr::get(
        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

    // Note: We need to use `get_tile` as there's no vector-level `undef`.
    Value undefTile = arm_sme::GetTileOp::create(rewriter, loc, smeTileType);
    Value destTensorOrMemref = writeOp.getBase();
    auto numSlicesPerTile =
        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
    auto numSlices =
        arith::ConstantIndexOp::create(rewriter, loc, numSlicesPerTile);
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
      // of slices from the source type into the SME tile. Without checking
      // vscale (and emitting multiple implementations) we can't make use of
      // the rows of the tile after 1*vscale rows.
      Value tile = undefTile;
      for (int d = 0; d < numSlicesPerTile; ++d) {
        Value vector =
            vector::ExtractOp::create(rewriter, loc, transposeOp.getVector(),
                                      rewriter.getIndexAttr(d + smeTile.row));
        if (vector.getType() != smeSliceType) {
          vector = vector::ScalableExtractOp::create(
              rewriter, loc, smeSliceType, vector, smeTile.col);
        }
        tile = vector::InsertOp::create(rewriter, loc, vector, tile, d);
      }

      // 2. Transpose the tile position.
      auto transposedRow = createVscaleMultiple(smeTile.col);
      auto transposedCol =
          arith::ConstantIndexOp::create(rewriter, loc, smeTile.row);

      // 3. Compute mask for tile store.
      Value maskRows;
      Value maskCols;
      if (auto mask = writeOp.getMask()) {
        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
        maskRows = arith::SubIOp::create(
            rewriter, loc, createMask.getOperand(0), transposedRow);
        maskCols = arith::SubIOp::create(
            rewriter, loc, createMask.getOperand(1), transposedCol);
        maskCols = index::MinSOp::create(rewriter, loc, maskCols, numSlices);
      } else {
        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
        maskCols = numSlices;
      }
      auto subMask = vector::CreateMaskOp::create(
          rewriter, loc, smeTileType.clone(rewriter.getI1Type()),
          ValueRange{maskRows, maskCols});

      // 4. Emit a transposed tile write.
      auto writeIndices = writeOp.getIndices();
      Value destRow =
          arith::AddIOp::create(rewriter, loc, transposedRow, writeIndices[0]);
      Value destCol =
          arith::AddIOp::create(rewriter, loc, transposedCol, writeIndices[1]);
      auto smeWrite = vector::TransferWriteOp::create(
          rewriter, loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
          transposeMap, subMask, writeOp.getInBounds());

      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Lower `vector.transfer_read` of a scalable column to `scf::for`
///
/// Lowers a "read" of a scalable column from a MemRef for which there is no
/// hardware operation that we could use to a loop over the rows that loads
/// one element at a time.
///
/// BEFORE:
/// ```
/// %res = vector.transfer_read %mem[%a, %b] (...)
///   : memref<?x?xf32>, vector<[4]x1xf32>
/// ```
///
/// AFTER:
/// ```
/// %cst = arith.constant (...) : vector<[4]xf32>
/// %vscale = vector.vscale
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %scf = scf.for %arg3 = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst)
///   -> (vector<[4]xf32>) {
///
///   %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32>
///   %vec = vector.insert %load, %arg4 [%arg3] : f32 into vector<[4]xf32>
///   scf.yield %vec : vector<[4]xf32>
/// }
/// %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
/// ```
///
/// TODO: This transformation isn't specific to SME - move it to the SVE
/// dialect.
/// TODO: Check the in_bounds attribute and generate vector.maskedload if
/// required.
struct LowerColumnTransferReadToLoops
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // NOTE: This is a fairly low-level transformation, so we shouldn't be
    // adding support for Tensors without good rationale.
    if (readOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          readOp, "Tensor semantics are unsupported (either bufferize or "
                  "extend this pattern)");

    auto resType = readOp.getVectorType();

    if (resType.getRank() != 2)
      return rewriter.notifyMatchFailure(readOp,
                                         "Only 2D vectors are supported!");

    if (resType.getShape()[1] != 1)
      return rewriter.notifyMatchFailure(
          readOp, "The trailing output dim is != 1 (not supported ATM)");

    if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
      return rewriter.notifyMatchFailure(
          readOp, "Expected the leading dim to be scalable and the trailing "
                  "dim to be fixed.");

    // Create new result type - similar to the original vector with the
    // trailing unit dim collapsed.
    int64_t numRows = resType.getShape()[0];
    VectorType newResType = VectorType::get(numRows, resType.getElementType(),
                                            /*scalableDims=*/{true});

    // Create a loop over all rows and load one element at a time.
    auto loc = readOp.getLoc();
    auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);
    auto upperBound = createVscaleMultiple(numRows);
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value init = arith::ConstantOp::create(
        rewriter, loc, newResType, DenseElementsAttr::get(newResType, 0.0f));

    scf::ForOp loadLoop;
    {
      OpBuilder::InsertionGuard g(rewriter);
      loadLoop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
                                    ValueRange{init});
      rewriter.setInsertionPointToStart(loadLoop.getBody());

      auto tileSliceIndex = loadLoop.getInductionVar();

      auto idx0 = arith::AddIOp::create(rewriter, loc, tileSliceIndex,
                                        readOp.getIndices()[0]);
      auto idx1 = readOp.getIndices()[1];

      Value scalar = memref::LoadOp::create(rewriter, loc, readOp.getBase(),
                                            SmallVector<Value>({idx0, idx1}));

      Operation *updateInit = vector::InsertOp::create(
          rewriter, loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);

      scf::YieldOp::create(rewriter, loc, updateInit->getResult(0));
    }

    // The read operation has been "legalized", but since the original result
    // type was a 2D vector, we need to cast before returning the result. This
    // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a
    // no-op).
    auto sc = vector::ShapeCastOp::create(
        rewriter, loc, readOp.getResult().getType(), loadLoop.getResult(0));

    rewriter.replaceOp(readOp, sc);

    return success();
  }
};

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    auto *context = &getContext();
    TypeConverter converter;
    RewritePatternSet patterns(context);
    converter.addConversion([](Type type) { return type; });
    converter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto smeTileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types = SmallVector<Type>(smeTileCount, smeTileType);
          return success();
        });
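    // E.g. (illustrative) this 1:N conversion maps vector<[8]x[8]xf32> to
    // four vector<[4]x[4]xf32> values, one per SME tile it contains.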

    // Apply preprocessing patterns.
    RewritePatternSet rewritePatterns(context);
    rewritePatterns
        .add<FoldExtractFromVectorOfSMELikeCreateMasks,
             LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
             LowerIllegalTransposeStoreViaZA>(context);
    if (failed(
            applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
      return signalPassFailure();

    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
                                                            /*benefit=*/1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversions(converter, patterns);

    ConversionTarget target(getContext());
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return converter.isLegal(op); });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}