MLIR 22.0.0git
VectorToArmSME.cpp
Go to the documentation of this file.
1//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "llvm/Support/Casting.h"
17
18using namespace mlir;
19
20namespace {
21
22/// Conversion pattern for vector.transfer_read.
23///
24/// ---
25///
26/// Example 1: op with identity permutation map to horizontal
27/// arm_sme.tile_load:
28///
29/// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
30///
31/// is converted to:
32///
33/// arm_sme.tile_load ...
34///
35/// ---
36///
37/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
38/// (in-flight transpose):
39///
40/// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
41///
42/// is converted to:
43///
44/// arm_sme.tile_load ... layout<vertical>
45struct TransferReadToArmSMELowering
46 : public OpRewritePattern<vector::TransferReadOp> {
47 using Base::Base;
48
49 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
50 PatternRewriter &rewriter) const final {
51 // The permutation map must have two results.
52 if (transferReadOp.getTransferRank() != 2)
53 return rewriter.notifyMatchFailure(transferReadOp,
54 "not a 2 result permutation map");
55
56 auto vectorType = transferReadOp.getVectorType();
57 if (!arm_sme::isValidSMETileVectorType(vectorType))
58 return rewriter.notifyMatchFailure(transferReadOp,
59 "not a valid vector type for SME");
60
61 if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType()))
62 return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
63
64 // Out-of-bounds dims are not supported.
65 if (transferReadOp.hasOutOfBoundsDim())
66 return rewriter.notifyMatchFailure(transferReadOp,
67 "not inbounds transfer read");
68
69 AffineMap map = transferReadOp.getPermutationMap();
70 if (!map.isPermutation())
71 return rewriter.notifyMatchFailure(transferReadOp,
72 "unsupported permutation map");
73
74 // Note: For 2D vector types the only non-identity permutation is a simple
75 // transpose [1, 0].
76 bool transposed = !map.isIdentity();
77 arm_sme::TileSliceLayout layout =
78 transposed ? arm_sme::TileSliceLayout::Vertical
79 : arm_sme::TileSliceLayout::Horizontal;
80
81 // Padding isn't optional for transfer_read, but is only used in the case
82 // of out-of-bounds accesses (not supported here) and/or masking. Mask is
83 // optional, if it's not present don't pass padding.
84 auto mask = transferReadOp.getMask();
85 auto padding = mask ? transferReadOp.getPadding() : nullptr;
86 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
87 transferReadOp, vectorType, transferReadOp.getBase(),
88 transferReadOp.getIndices(), padding, mask, layout);
89
90 return success();
91 }
92};
93
94/// Conversion pattern for vector.transfer_write.
95///
96/// ---
97///
98/// Example 1: op with identity permutation map to horizontal
99/// arm_sme.tile_store:
100///
101/// vector.transfer_write %vector, %source[%c0, %c0]
102/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
103///
104/// is converted to:
105///
106/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
107/// vector<[16]x[16]xi8>
108/// ---
109///
110/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
111/// (in-flight transpose):
112///
113/// vector.transfer_write %vector, %source[%c0, %c0]
114/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
115/// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
116///
117/// is converted to:
118///
119/// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
120/// : memref<?x?xi8>, vector<[16]x[16]xi8>
121struct TransferWriteToArmSMELowering
122 : public OpRewritePattern<vector::TransferWriteOp> {
123 using Base::Base;
124
125 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
126 PatternRewriter &rewriter) const final {
127 auto vType = writeOp.getVectorType();
129 return failure();
130
131 if (!llvm::isa<MemRefType>(writeOp.getBase().getType()))
132 return failure();
133
134 // Out-of-bounds dims are not supported.
135 if (writeOp.hasOutOfBoundsDim())
136 return rewriter.notifyMatchFailure(writeOp,
137 "not inbounds transfer write");
138
139 AffineMap map = writeOp.getPermutationMap();
140 if (!map.isPermutation())
141 return rewriter.notifyMatchFailure(writeOp,
142 "unsupported permutation map");
143
144 // Note: For 2D vector types the only non-identity permutation is a simple
145 // transpose [1, 0].
146 bool transposed = !map.isIdentity();
147 arm_sme::TileSliceLayout layout =
148 transposed ? arm_sme::TileSliceLayout::Vertical
149 : arm_sme::TileSliceLayout::Horizontal;
150
151 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
152 writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(),
153 writeOp.getMask(), layout);
154 return success();
155 }
156};
157
158/// Conversion pattern for vector.load.
159struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
160 using Base::Base;
161
162 LogicalResult matchAndRewrite(vector::LoadOp load,
163 PatternRewriter &rewriter) const override {
164 if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
165 return failure();
166
167 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
168 load, load.getVectorType(), load.getBase(), load.getIndices());
169
170 return success();
171 }
172};
173
174/// Conversion pattern for vector.store.
175struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
176 using Base::Base;
177
178 LogicalResult matchAndRewrite(vector::StoreOp store,
179 PatternRewriter &rewriter) const override {
180 if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
181 return failure();
182
183 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
184 store, store.getValueToStore(), store.getBase(), store.getIndices());
185
186 return success();
187 }
188};
189
190/// Conversion pattern for vector.broadcast.
191///
192/// Example:
193///
194/// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
195///
196/// is converted to:
197///
198/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
199/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
200/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
201/// {
202/// %tile_update = arm_sme.insert_tile_slice
203/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
204/// vector<[4]xi32> into vector<[4]x[4]xi32>
205/// scf.yield %tile_update : vector<[4]x[4]xi32>
206/// }
207///
208/// Supports scalar, 0-d vector, and 1-d vector broadcasts.
209struct BroadcastOpToArmSMELowering
210 : public OpRewritePattern<vector::BroadcastOp> {
211 using Base::Base;
212
213 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
214 PatternRewriter &rewriter) const final {
215 auto tileType = broadcastOp.getResultVectorType();
216 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
217 return failure();
218
219 auto loc = broadcastOp.getLoc();
220
221 auto srcType = broadcastOp.getSourceType();
222 auto srcVectorType = dyn_cast<VectorType>(srcType);
223
224 Value broadcastOp1D;
225 if (srcType.isIntOrFloat() ||
226 (srcVectorType && (srcVectorType.getRank() == 0))) {
227 // Broadcast scalar or 0-d vector to 1-d vector.
228 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
229 broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType,
230 broadcastOp.getSource());
231 } else if (srcVectorType && (srcVectorType.getRank() == 1))
232 // Value to broadcast is already a 1-d vector, nothing to do.
233 broadcastOp1D = broadcastOp.getSource();
234 else
235 return failure();
236
237 auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
238
239 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
240 Value currentTile) {
241 // Create 'arm_sme.insert_tile_slice' to broadcast the value
242 // to each tile slice.
243 auto nextTile = arm_sme::InsertTileSliceOp::create(
244 b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
245 return nextTile.getResult();
246 };
247
248 // Create a loop over ZA tile slices.
249 auto forOp =
250 createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
251
252 rewriter.replaceOp(broadcastOp, forOp.getResult(0));
253
254 return success();
255 }
256};
257
258/// Conversion pattern for vector.transpose.
259///
260/// Stores the input tile to memory and reloads vertically.
261///
262/// Example:
263///
264/// %transposed_src = vector.transpose %src, [1, 0]
265/// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
266///
267/// is converted to:
268///
269/// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
270/// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
271/// : memref<?x?xi32>, vector<[4]x[4]xi32>
272/// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
273/// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
274///
275/// NOTE: Transposing via memory is obviously expensive, the current intention
276/// is to avoid the transpose if possible, this is therefore intended as a
277/// fallback and to provide base support for Vector ops. If it turns out
278/// transposes can't be avoided then this should be replaced with a more optimal
279/// implementation, perhaps with tile <-> vector (MOVA) ops.
280struct TransposeOpToArmSMELowering
281 : public OpRewritePattern<vector::TransposeOp> {
282 using Base::Base;
283
284 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
285 PatternRewriter &rewriter) const final {
286 auto tileType = transposeOp.getResultVectorType();
287 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
288 return failure();
289
290 // Bail unless this is a true 2-D matrix transpose.
291 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
292 if (permutation[0] != 1 || permutation[1] != 0)
293 return failure();
294
295 auto loc = transposeOp.getLoc();
296 Value input = transposeOp.getVector();
297
298 if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
299 xferOp && xferOp->hasOneUse()) {
300 // Fold transpose into transfer_read to enable in-flight transpose when
301 // converting to arm_sme.tile_load.
302 rewriter.modifyOpInPlace(xferOp, [&]() {
303 xferOp->setAttr(xferOp.getPermutationMapAttrName(),
304 AffineMapAttr::get(AffineMap::getPermutationMap(
305 permutation, transposeOp.getContext())));
306 });
307 rewriter.replaceOp(transposeOp, xferOp);
308 return success();
309 }
310
311 // Allocate buffer to store input tile to.
312 Value vscale =
313 vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
314 Value minTileSlices = arith::ConstantOp::create(
315 rewriter, loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
316 Value c0 =
317 arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0));
318 Value numTileSlices =
319 arith::MulIOp::create(rewriter, loc, vscale, minTileSlices);
320 auto bufferType =
321 MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
322 tileType.getElementType());
323 auto buffer = memref::AllocaOp::create(
324 rewriter, loc, bufferType, ValueRange{numTileSlices, numTileSlices});
325
326 // Store input tile.
327 auto tileStoreOp = arm_sme::TileStoreOp::create(rewriter, loc, input,
328 buffer, ValueRange{c0, c0});
329
330 // Reload input tile vertically.
331 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
332 transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
333 arm_sme::TileSliceLayout::Vertical);
334
335 return success();
336 }
337};
338
339/// Conversion pattern for vector.outerproduct.
340///
341/// If the vector.outerproduct is masked (and the mask is from a
342/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
343/// operands.
344///
345/// Example:
346///
347/// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
348/// %result = vector.mask %mask {
349/// vector.outerproduct %vecA, %vecB
350/// : vector<[4]xf32>, vector<[4]xf32>
351/// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
352///
353/// is converted to:
354///
355/// %maskA = vector.create_mask %dimA : vector<[4]xi1>
356/// %maskB = vector.create_mask %dimB : vector<[4]xi1>
357/// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
358/// : vector<[4]xf32>, vector<[4]xf32>
359///
360/// Unmasked outerproducts can be directly replaced with the arm_sme op.
361///
362/// Example:
363///
364/// %result = vector.outerproduct %vecA, %vecB
365/// : vector<[4]xf32>, vector<[4]xf32>
366///
367/// is converted to:
368///
369/// %result = arm_sme.outerproduct %vecA, %vecB
370/// : vector<[4]xf32>, vector<[4]xf32>
371///
372struct VectorOuterProductToArmSMELowering
373 : public OpRewritePattern<vector::OuterProductOp> {
374
375 using Base::Base;
376
377 LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
378 PatternRewriter &rewriter) const override {
379
380 // We don't yet support lowering AXPY operations to SME. These could be
381 // lowered by masking out all but the first element of the LHS.
382 if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
383 return rewriter.notifyMatchFailure(outerProductOp,
384 "AXPY operations not supported");
385
387 outerProductOp.getResultVectorType()))
388 return rewriter.notifyMatchFailure(
389 outerProductOp, "outer product does not fit into SME tile");
390
391 auto kind = outerProductOp.getKind();
392 if (kind != vector::CombiningKind::ADD)
393 return rewriter.notifyMatchFailure(
394 outerProductOp,
395 "unsupported kind (lowering to SME only supports ADD at the moment)");
396
397 Value lhsMask = {};
398 Value rhsMask = {};
399 Operation *rootOp = outerProductOp;
400 auto loc = outerProductOp.getLoc();
401 if (outerProductOp.isMasked()) {
402 auto maskOp = outerProductOp.getMaskingOp();
403 rewriter.setInsertionPoint(maskOp);
404 rootOp = maskOp;
405 auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
406 if (failed(operandMasks))
407 return failure();
408 std::tie(lhsMask, rhsMask) = *operandMasks;
409 }
410
411 rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
412 rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
413 outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
414
415 return success();
416 }
417
418 static FailureOr<std::pair<Value, Value>>
419 decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
420 // Attempt to extract masks from vector.create_mask.
421 // TODO: Add support for other mask sources.
422 auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
423 if (!createMaskOp)
424 return failure();
425
426 auto maskType = createMaskOp.getVectorType();
427 Value lhsMaskDim = createMaskOp.getOperand(0);
428 Value rhsMaskDim = createMaskOp.getOperand(1);
429
430 VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
431 Value lhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
432 lhsMaskDim);
433 Value rhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
434 rhsMaskDim);
435
436 return std::make_pair(lhsMask, rhsMask);
437 }
438};
439
440/// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
441///
442/// Example:
443/// ```
444/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
445/// ```
446/// Becomes:
447/// ```
448/// %slice = arm_sme.extract_tile_slice %tile[%row]
449/// : vector<[4]xi32> from vector<[4]x[4]xi32>
450/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
451/// ```
452struct VectorExtractToArmSMELowering
453 : public OpRewritePattern<vector::ExtractOp> {
454 using Base::Base;
455
456 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
457 PatternRewriter &rewriter) const override {
458 VectorType sourceType = extractOp.getSourceVectorType();
459 if (!arm_sme::isValidSMETileVectorType(sourceType))
460 return failure();
461
462 auto loc = extractOp.getLoc();
463 auto position = extractOp.getMixedPosition();
464
465 Value sourceVector = extractOp.getSource();
466
467 // Extract entire vector. Should be handled by folder, but just to be safe.
468 if (position.empty()) {
469 rewriter.replaceOp(extractOp, sourceVector);
470 return success();
471 }
472
473 Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
474 auto extractTileSlice = arm_sme::ExtractTileSliceOp::create(
475 rewriter, loc, sourceVector, sliceIndex);
476
477 if (position.size() == 1) {
478 // Single index case: Extracts a 1D slice.
479 rewriter.replaceOp(extractOp, extractTileSlice);
480 return success();
481 }
482
483 // Two indices case: Extracts a single element.
484 assert(position.size() == 2);
485 rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice,
486 position[1]);
487
488 return success();
489 }
490};
491
492/// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
493/// `arm_sme.extract_tile_slice`.
494///
495/// Example:
496/// ```
497/// %new_tile = vector.insert %el, %tile[%row, %col]
498/// : i32 into vector<[4]x[4]xi32>
499/// ```
500/// Becomes:
501/// ```
502/// %slice = arm_sme.extract_tile_slice %tile[%row]
503/// : vector<[4]xi32> from vector<[4]x[4]xi32>
504/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
505/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
506/// : vector<[4]xi32> into vector<[4]x[4]xi32>
507/// ```
508struct VectorInsertToArmSMELowering
509 : public OpRewritePattern<vector::InsertOp> {
510 using Base::Base;
511
512 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
513 PatternRewriter &rewriter) const override {
514 VectorType resultType = insertOp.getResult().getType();
515
516 if (!arm_sme::isValidSMETileVectorType(resultType))
517 return failure();
518
519 auto loc = insertOp.getLoc();
520 auto position = insertOp.getMixedPosition();
521
522 Value source = insertOp.getValueToStore();
523
524 // Overwrite entire vector with value. Should be handled by folder, but
525 // just to be safe.
526 if (position.empty()) {
527 rewriter.replaceOp(insertOp, source);
528 return success();
529 }
530
531 Value tileSlice = source;
532 Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
533 if (position.size() == 2) {
534 // Two indices case: Insert single element into tile.
535 // We need to first extract the existing slice and update the element.
536 tileSlice = arm_sme::ExtractTileSliceOp::create(
537 rewriter, loc, insertOp.getDest(), sliceIndex);
538 tileSlice = vector::InsertOp::create(rewriter, loc, source, tileSlice,
539 position[1]);
540 }
541
542 // Insert the slice into the destination tile.
543 rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>(
544 insertOp, tileSlice, insertOp.getDest(), sliceIndex);
545 return success();
546 }
547};
548
549/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
550/// extracting them via `arm_sme.extract_tile_slice`, then printing with
551/// a 1D `vector.print`.
552///
553/// BEFORE:
554/// ```mlir
555/// vector.print %tile : vector<[4]x[4]xf32>
556/// ```
557/// AFTER:
558/// ```mlir
559/// %c0 = arith.constant 0 : index
560/// %c1 = arith.constant 1 : index
561/// %c4 = arith.constant 4 : index
562/// %vscale = vector.vscale
563/// %svl_s = arith.muli %c4, %vscale : index
564/// scf.for %i = %c0 to %svl_s step %c1 {
565/// %tile_slice = arm_sme.extract_tile_slice %tile[%i]
566/// : vector<[4]xf32> from vector<[4]x[4]xf32>
567/// vector.print %tile_slice : vector<[4]xf32>
568/// }
569/// ```
570struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
571 using Base::Base;
572
573 LogicalResult matchAndRewrite(vector::PrintOp printOp,
574 PatternRewriter &rewriter) const override {
575 if (!printOp.getSource())
576 return failure();
577
578 VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
579 if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
580 return failure();
581
582 auto loc = printOp.getLoc();
583
584 // Create a loop over the rows of the tile.
585 auto vscale = vector::VectorScaleOp::create(rewriter, loc);
586 auto minTileRows =
587 arith::ConstantIndexOp::create(rewriter, loc, vectorType.getDimSize(0));
588 auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
589 auto upperBound = arith::MulIOp::create(rewriter, loc, minTileRows, vscale);
590 auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
591 auto forOp =
592 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
593 {
594 // Loop body.
595 rewriter.setInsertionPointToStart(forOp.getBody());
596 // Extract the current row from the tile.
597 Value rowIndex = forOp.getInductionVar();
598 auto tileSlice = arm_sme::ExtractTileSliceOp::create(
599 rewriter, loc, printOp.getSource(), rowIndex);
600 // Print the row with a 1D vector.print.
601 vector::PrintOp::create(rewriter, loc, tileSlice,
602 printOp.getPunctuation());
603 }
604
605 rewriter.eraseOp(printOp);
606 return success();
607 }
608};
609
610/// Folds a ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
611///
612/// BEFORE:
613/// ```mlir
614/// %slice = arm_sme.extract_tile_slice %tile[%index]
615/// : vector<[4]xf32> from vector<[4]x[4]xf32>
616/// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
617/// : vector<[4]xf32>, memref<?x?xf32>
618/// ```
619/// AFTER:
620/// ```mlir
621/// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
622/// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
623/// ```
624struct FoldTransferWriteOfExtractTileSlice
625 : public OpRewritePattern<vector::TransferWriteOp> {
626 using Base::Base;
627
628 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
629 PatternRewriter &rewriter) const final {
630 if (!isa<MemRefType>(writeOp.getBase().getType()))
631 return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
632
633 if (writeOp.hasOutOfBoundsDim())
634 return rewriter.notifyMatchFailure(writeOp,
635 "not inbounds transfer write");
636
637 auto extractTileSlice =
638 writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
639 if (!extractTileSlice)
640 return rewriter.notifyMatchFailure(
641 writeOp, "vector to store not from ExtractTileSliceOp");
642
643 AffineMap map = writeOp.getPermutationMap();
644 if (!map.isMinorIdentity())
645 return rewriter.notifyMatchFailure(writeOp,
646 "unsupported permutation map");
647
648 Value mask = writeOp.getMask();
649 if (!mask) {
650 auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
651 mask = arith::ConstantOp::create(rewriter, writeOp.getLoc(), maskType,
652 DenseElementsAttr::get(maskType, true));
653 }
654
655 rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
656 writeOp, extractTileSlice.getTile(),
657 extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(),
658 writeOp.getIndices(), extractTileSlice.getLayout());
659 return success();
660 }
661};
662
663/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
664/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
665/// SVE 2.1), so this is currently the most logical place for this lowering.
666///
667/// Example:
668/// ```mlir
669/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
670/// %slice = vector.extract %mask[%index]
671/// : vector<[8]xi1> from vector<[4]x[8]xi1>
672/// ```
673/// Becomes:
674/// ```
675/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
676/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
677/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
678/// : vector<[8]xi1>, vector<[4]xi1>
679/// ```
680struct ExtractFromCreateMaskToPselLowering
681 : public OpRewritePattern<vector::ExtractOp> {
682 using Base::Base;
683
684 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
685 PatternRewriter &rewriter) const override {
686 if (extractOp.getNumIndices() != 1)
687 return rewriter.notifyMatchFailure(extractOp, "not single extract index");
688
689 auto resultType = extractOp.getResult().getType();
690 auto resultVectorType = dyn_cast<VectorType>(resultType);
691 if (!resultVectorType)
692 return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
693
694 auto createMaskOp =
695 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
696 if (!createMaskOp)
697 return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
698
699 auto maskType = createMaskOp.getVectorType();
700 if (maskType.getRank() != 2 || !maskType.allDimsScalable())
701 return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
702
703 auto isSVEPredicateSize = [](int64_t size) {
704 return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
705 };
706
707 auto rowsBaseSize = maskType.getDimSize(0);
708 auto colsBaseSize = maskType.getDimSize(1);
709 if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
710 return rewriter.notifyMatchFailure(
711 createMaskOp, "mask dimensions not SVE predicate-sized");
712
713 auto loc = extractOp.getLoc();
714 VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
715 VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
716
717 // Create the two 1-D masks at the location of the 2-D create_mask (which is
718 // usually outside a loop). This prevents the need for later hoisting.
719 rewriter.setInsertionPoint(createMaskOp);
720 auto rowMask = vector::CreateMaskOp::create(rewriter, loc, rowMaskType,
721 createMaskOp.getOperand(0));
722 auto colMask = vector::CreateMaskOp::create(rewriter, loc, colMaskType,
723 createMaskOp.getOperand(1));
724
725 rewriter.setInsertionPoint(extractOp);
726 auto position =
727 vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
728 rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
729 position[0]);
730 return success();
731 }
732};
733
734} // namespace
735
737 MLIRContext &ctx) {
738 patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
739 TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
740 VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
741 VectorOuterProductToArmSMELowering,
742 VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
743 VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
744 ExtractFromCreateMaskToPselLowering>(&ctx);
745}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition Unit.cpp:18
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerType getI1Type()
Definition Builders.cpp:53
IndexType getIndexType()
Definition Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition Utils.cpp:89
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition Utils.cpp:43
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx)
Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...