VectorToArmSME.cpp
1 //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
10 
11 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
12 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
13 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "llvm/Support/Casting.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 
22 /// Conversion pattern for vector.transfer_read.
23 ///
24 /// ---
25 ///
26 /// Example 1: op with identity permutation map to horizontal
27 /// arm_sme.tile_load:
28 ///
29 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
30 ///
31 /// is converted to:
32 ///
33 /// arm_sme.tile_load ...
34 ///
35 /// ---
36 ///
37 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
38 /// (in-flight transpose):
39 ///
40 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
41 ///
42 /// is converted to:
43 ///
44 /// arm_sme.tile_load ... layout<vertical>
45 struct TransferReadToArmSMELowering
46  : public OpRewritePattern<vector::TransferReadOp> {
47  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
48 
49  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
50  PatternRewriter &rewriter) const final {
51  // The permutation map must have two results.
52  if (transferReadOp.getTransferRank() != 2)
53  return rewriter.notifyMatchFailure(transferReadOp,
54  "not a 2 result permutation map");
55 
56  auto vectorType = transferReadOp.getVectorType();
57  if (!arm_sme::isValidSMETileVectorType(vectorType))
58  return rewriter.notifyMatchFailure(transferReadOp,
59  "not a valid vector type for SME");
60 
61  if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType()))
62  return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
63 
64  // Out-of-bounds dims are not supported.
65  if (transferReadOp.hasOutOfBoundsDim())
66  return rewriter.notifyMatchFailure(transferReadOp,
67  "not inbounds transfer read");
68 
69  AffineMap map = transferReadOp.getPermutationMap();
70  if (!map.isPermutation())
71  return rewriter.notifyMatchFailure(transferReadOp,
72  "unsupported permutation map");
73 
74  // Note: For 2D vector types the only non-identity permutation is a simple
75  // transpose [1, 0].
76  bool transposed = !map.isIdentity();
77  arm_sme::TileSliceLayout layout =
78  transposed ? arm_sme::TileSliceLayout::Vertical
79  : arm_sme::TileSliceLayout::Horizontal;
80 
81  // Padding isn't optional for transfer_read, but is only used in the case
82  // of out-of-bounds accesses (not supported here) and/or masking. The mask
83  // is optional; if it's not present, don't pass padding.
84  auto mask = transferReadOp.getMask();
85  auto padding = mask ? transferReadOp.getPadding() : nullptr;
86  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
87  transferReadOp, vectorType, transferReadOp.getBase(),
88  transferReadOp.getIndices(), padding, mask, layout);
89 
90  return success();
91  }
92 };
93 
94 /// Conversion pattern for vector.transfer_write.
95 ///
96 /// ---
97 ///
98 /// Example 1: op with identity permutation map to horizontal
99 /// arm_sme.tile_store:
100 ///
101 /// vector.transfer_write %vector, %source[%c0, %c0]
102 /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
103 ///
104 /// is converted to:
105 ///
106 /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
107 /// vector<[16]x[16]xi8>
108 /// ---
109 ///
110 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
111 /// (in-flight transpose):
112 ///
113 /// vector.transfer_write %vector, %source[%c0, %c0]
114 /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
115 /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
116 ///
117 /// is converted to:
118 ///
119 /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
120 /// : memref<?x?xi8>, vector<[16]x[16]xi8>
121 struct TransferWriteToArmSMELowering
122  : public OpRewritePattern<vector::TransferWriteOp> {
123  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
124 
125  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
126  PatternRewriter &rewriter) const final {
127  auto vType = writeOp.getVectorType();
128  if (!arm_sme::isValidSMETileVectorType(vType))
129  return failure();
130 
131  if (!llvm::isa<MemRefType>(writeOp.getBase().getType()))
132  return failure();
133 
134  // Out-of-bounds dims are not supported.
135  if (writeOp.hasOutOfBoundsDim())
136  return rewriter.notifyMatchFailure(writeOp,
137  "not inbounds transfer write");
138 
139  AffineMap map = writeOp.getPermutationMap();
140  if (!map.isPermutation())
141  return rewriter.notifyMatchFailure(writeOp,
142  "unsupported permutation map");
143 
144  // Note: For 2D vector types the only non-identity permutation is a simple
145  // transpose [1, 0].
146  bool transposed = !map.isIdentity();
147  arm_sme::TileSliceLayout layout =
148  transposed ? arm_sme::TileSliceLayout::Vertical
149  : arm_sme::TileSliceLayout::Horizontal;
150 
151  rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
152  writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(),
153  writeOp.getMask(), layout);
154  return success();
155  }
156 };
157 
158 /// Conversion pattern for vector.load.
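/// Example (an illustrative sketch; assumes a vector type accepted by
/// `arm_sme::isValidSMETileVectorType`, e.g. vector<[4]x[4]xi32>):
///
///   %tile = vector.load %mem[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %tile = arm_sme.tile_load %mem[%c0, %c0]
///     : memref<?x?xi32>, vector<[4]x[4]xi32>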
159 struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
160  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
161 
162  LogicalResult matchAndRewrite(vector::LoadOp load,
163  PatternRewriter &rewriter) const override {
164  if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
165  return failure();
166 
167  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
168  load, load.getVectorType(), load.getBase(), load.getIndices());
169 
170  return success();
171  }
172 };
173 
174 /// Conversion pattern for vector.store.
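/// Example (an illustrative sketch, mirroring the vector.load case above):
///
///   vector.store %tile, %mem[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
///
/// is converted to:
///
///   arm_sme.tile_store %tile, %mem[%c0, %c0]
///     : memref<?x?xi32>, vector<[4]x[4]xi32>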
175 struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
176  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
177 
178  LogicalResult matchAndRewrite(vector::StoreOp store,
179  PatternRewriter &rewriter) const override {
180  if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
181  return failure();
182 
183  rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
184  store, store.getValueToStore(), store.getBase(), store.getIndices());
185 
186  return success();
187  }
188 };
189 
190 /// Conversion pattern for vector.broadcast.
191 ///
192 /// Example:
193 ///
194 /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
195 ///
196 /// is converted to:
197 ///
198 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
199 /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
200 /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
201 /// {
202 /// %tile_update = arm_sme.insert_tile_slice
203 /// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
204 /// vector<[4]xi32> into vector<[4]x[4]xi32>
205 /// scf.yield %tile_update : vector<[4]x[4]xi32>
206 /// }
207 ///
208 /// Supports scalar, 0-d vector, and 1-d vector broadcasts.
209 struct BroadcastOpToArmSMELowering
210  : public OpRewritePattern<vector::BroadcastOp> {
211  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
212 
213  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
214  PatternRewriter &rewriter) const final {
215  auto tileType = broadcastOp.getResultVectorType();
216  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
217  return failure();
218 
219  auto loc = broadcastOp.getLoc();
220 
221  auto srcType = broadcastOp.getSourceType();
222  auto srcVectorType = dyn_cast<VectorType>(srcType);
223 
224  Value broadcastOp1D;
225  if (srcType.isIntOrFloat() ||
226  (srcVectorType && (srcVectorType.getRank() == 0))) {
227  // Broadcast scalar or 0-d vector to 1-d vector.
228  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
229  broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType,
230  broadcastOp.getSource());
231  } else if (srcVectorType && (srcVectorType.getRank() == 1))
232  // Value to broadcast is already a 1-d vector, nothing to do.
233  broadcastOp1D = broadcastOp.getSource();
234  else
235  return failure();
236 
237  auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
238 
239  auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
240  Value currentTile) {
241  // Create 'arm_sme.insert_tile_slice' to broadcast the value
242  // to each tile slice.
243  auto nextTile = arm_sme::InsertTileSliceOp::create(
244  b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
245  return nextTile.getResult();
246  };
247 
248  // Create a loop over ZA tile slices.
249  auto forOp =
250  createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
251 
252  rewriter.replaceOp(broadcastOp, forOp.getResult(0));
253 
254  return success();
255  }
256 };
257 
258 /// Conversion pattern for vector.transpose.
259 ///
260 /// Stores the input tile to memory and reloads vertically.
261 ///
262 /// Example:
263 ///
264 /// %transposed_src = vector.transpose %src, [1, 0]
265 /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
266 ///
267 /// is converted to:
268 ///
269 /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
270 /// arm_sme.tile_store %src, %alloca[%c0, %c0]
271 /// : memref<?x?xi32>, vector<[4]x[4]xi32>
272 /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
273 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
274 ///
275 /// NOTE: Transposing via memory is obviously expensive; the current intention
276 /// is to avoid the transpose if possible. This is therefore intended as a
277 /// fallback and to provide base support for Vector ops. If it turns out
278 /// transposes can't be avoided, then this should be replaced with a more optimal
279 /// implementation, perhaps with tile <-> vector (MOVA) ops.
280 struct TransposeOpToArmSMELowering
281  : public OpRewritePattern<vector::TransposeOp> {
282  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
283 
284  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
285  PatternRewriter &rewriter) const final {
286  auto tileType = transposeOp.getResultVectorType();
287  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
288  return failure();
289 
290  // Bail unless this is a true 2-D matrix transpose.
291  ArrayRef<int64_t> permutation = transposeOp.getPermutation();
292  if (permutation[0] != 1 || permutation[1] != 0)
293  return failure();
294 
295  auto loc = transposeOp.getLoc();
296  Value input = transposeOp.getVector();
297 
298  if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
299  xferOp && xferOp->hasOneUse()) {
300  // Fold transpose into transfer_read to enable in-flight transpose when
301  // converting to arm_sme.tile_load.
302  rewriter.modifyOpInPlace(xferOp, [&]() {
303  xferOp->setAttr(xferOp.getPermutationMapAttrName(),
304  AffineMapAttr::get(AffineMap::getPermutationMap(
305  permutation, transposeOp.getContext())));
306  });
307  rewriter.replaceOp(transposeOp, xferOp);
308  return success();
309  }
310 
311  // Allocate buffer to store input tile to.
312  Value vscale =
313  vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
314  Value minTileSlices = arith::ConstantOp::create(
315  rewriter, loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
316  Value c0 =
317  arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0));
318  Value numTileSlices =
319  arith::MulIOp::create(rewriter, loc, vscale, minTileSlices);
320  auto bufferType =
321  MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
322  tileType.getElementType());
323  auto buffer = memref::AllocaOp::create(
324  rewriter, loc, bufferType, ValueRange{numTileSlices, numTileSlices});
325 
326  // Store input tile.
327  auto tileStoreOp = arm_sme::TileStoreOp::create(rewriter, loc, input,
328  buffer, ValueRange{c0, c0});
329 
330  // Reload input tile vertically.
331  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
332  transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
333  arm_sme::TileSliceLayout::Vertical);
334 
335  return success();
336  }
337 };
338 
339 /// Conversion pattern for vector.outerproduct.
340 ///
341 /// If the vector.outerproduct is masked (and the mask is from a
342 /// vector.create_mask), then the mask is decomposed into two 1-D masks for the
343 /// operands.
344 ///
345 /// Example:
346 ///
347 /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
348 /// %result = vector.mask %mask {
349 /// vector.outerproduct %vecA, %vecB
350 /// : vector<[4]xf32>, vector<[4]xf32>
351 /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
352 ///
353 /// is converted to:
354 ///
355 /// %maskA = vector.create_mask %dimA : vector<[4]xi1>
356 /// %maskB = vector.create_mask %dimB : vector<[4]xi1>
357 /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
358 /// : vector<[4]xf32>, vector<[4]xf32>
359 ///
360 /// Unmasked outerproducts can be directly replaced with the arm_sme op.
361 ///
362 /// Example:
363 ///
364 /// %result = vector.outerproduct %vecA, %vecB
365 /// : vector<[4]xf32>, vector<[4]xf32>
366 ///
367 /// is converted to:
368 ///
369 /// %result = arm_sme.outerproduct %vecA, %vecB
370 /// : vector<[4]xf32>, vector<[4]xf32>
371 ///
372 struct VectorOuterProductToArmSMELowering
373  : public OpRewritePattern<vector::OuterProductOp> {
374 
375  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
376 
377  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
378  PatternRewriter &rewriter) const override {
379 
380  // We don't yet support lowering AXPY operations to SME. These could be
381  // lowered by masking out all but the first element of the LHS.
382  if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
383  return rewriter.notifyMatchFailure(outerProductOp,
384  "AXPY operations not supported");
385 
386  if (!arm_sme::isValidSMETileVectorType(
387  outerProductOp.getResultVectorType()))
388  return rewriter.notifyMatchFailure(
389  outerProductOp, "outer product does not fit into SME tile");
390 
391  auto kind = outerProductOp.getKind();
392  if (kind != vector::CombiningKind::ADD)
393  return rewriter.notifyMatchFailure(
394  outerProductOp,
395  "unsupported kind (lowering to SME only supports ADD at the moment)");
396 
397  Value lhsMask = {};
398  Value rhsMask = {};
399  Operation *rootOp = outerProductOp;
400  auto loc = outerProductOp.getLoc();
401  if (outerProductOp.isMasked()) {
402  auto maskOp = outerProductOp.getMaskingOp();
403  rewriter.setInsertionPoint(maskOp);
404  rootOp = maskOp;
405  auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
406  if (failed(operandMasks))
407  return failure();
408  std::tie(lhsMask, rhsMask) = *operandMasks;
409  }
410 
411  rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
412  rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
413  outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
414 
415  return success();
416  }
417 
418  static FailureOr<std::pair<Value, Value>>
419  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
420  // Attempt to extract masks from vector.create_mask.
421  // TODO: Add support for other mask sources.
422  auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
423  if (!createMaskOp)
424  return failure();
425 
426  auto maskType = createMaskOp.getVectorType();
427  Value lhsMaskDim = createMaskOp.getOperand(0);
428  Value rhsMaskDim = createMaskOp.getOperand(1);
429 
430  VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
431  Value lhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
432  lhsMaskDim);
433  Value rhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType,
434  rhsMaskDim);
435 
436  return std::make_pair(lhsMask, rhsMask);
437  }
438 };
439 
440 /// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
441 ///
442 /// Example:
443 /// ```
444 /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
445 /// ```
446 /// Becomes:
447 /// ```
448 /// %slice = arm_sme.extract_tile_slice %tile[%row]
449 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
450 /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
451 /// ```
452 struct VectorExtractToArmSMELowering
453  : public OpRewritePattern<vector::ExtractOp> {
454  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
455 
456  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
457  PatternRewriter &rewriter) const override {
458  VectorType sourceType = extractOp.getSourceVectorType();
459  if (!arm_sme::isValidSMETileVectorType(sourceType))
460  return failure();
461 
462  auto loc = extractOp.getLoc();
463  auto position = extractOp.getMixedPosition();
464 
465  Value sourceVector = extractOp.getVector();
466 
467  // Extract entire vector. Should be handled by folder, but just to be safe.
468  if (position.empty()) {
469  rewriter.replaceOp(extractOp, sourceVector);
470  return success();
471  }
472 
473  Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
474  auto extractTileSlice = arm_sme::ExtractTileSliceOp::create(
475  rewriter, loc, sourceVector, sliceIndex);
476 
477  if (position.size() == 1) {
478  // Single index case: Extracts a 1D slice.
479  rewriter.replaceOp(extractOp, extractTileSlice);
480  return success();
481  }
482 
483  // Two indices case: Extracts a single element.
484  assert(position.size() == 2);
485  rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice,
486  position[1]);
487 
488  return success();
489  }
490 };
491 
492 /// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
493 /// `arm_sme.extract_tile_slice`.
494 ///
495 /// Example:
496 /// ```
497 /// %new_tile = vector.insert %el, %tile[%row, %col]
498 /// : i32 into vector<[4]x[4]xi32>
499 /// ```
500 /// Becomes:
501 /// ```
502 /// %slice = arm_sme.extract_tile_slice %tile[%row]
503 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
504 /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
505 /// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
506 /// : vector<[4]xi32> into vector<[4]x[4]xi32>
507 /// ```
508 struct VectorInsertToArmSMELowering
509  : public OpRewritePattern<vector::InsertOp> {
510  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
511 
512  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
513  PatternRewriter &rewriter) const override {
514  VectorType resultType = insertOp.getResult().getType();
515 
516  if (!arm_sme::isValidSMETileVectorType(resultType))
517  return failure();
518 
519  auto loc = insertOp.getLoc();
520  auto position = insertOp.getMixedPosition();
521 
522  Value source = insertOp.getValueToStore();
523 
524  // Overwrite entire vector with value. Should be handled by folder, but
525  // just to be safe.
526  if (position.empty()) {
527  rewriter.replaceOp(insertOp, source);
528  return success();
529  }
530 
531  Value tileSlice = source;
532  Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
533  if (position.size() == 2) {
534  // Two indices case: Insert single element into tile.
535  // We need to first extract the existing slice and update the element.
536  tileSlice = arm_sme::ExtractTileSliceOp::create(
537  rewriter, loc, insertOp.getDest(), sliceIndex);
538  tileSlice = vector::InsertOp::create(rewriter, loc, source, tileSlice,
539  position[1]);
540  }
541 
542  // Insert the slice into the destination tile.
543  rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>(
544  insertOp, tileSlice, insertOp.getDest(), sliceIndex);
545  return success();
546  }
547 };
548 
549 /// Lowers `vector.print` of a tile into a loop over the rows of the tile,
550 /// extracting them via `arm_sme.extract_tile_slice`, then printing with
551 /// a 1D `vector.print`.
552 ///
553 /// BEFORE:
554 /// ```mlir
555 /// vector.print %tile : vector<[4]x[4]xf32>
556 /// ```
557 /// AFTER:
558 /// ```mlir
559 /// %c0 = arith.constant 0 : index
560 /// %c1 = arith.constant 1 : index
561 /// %c4 = arith.constant 4 : index
562 /// %vscale = vector.vscale
563 /// %svl_s = arith.muli %c4, %vscale : index
564 /// scf.for %i = %c0 to %svl_s step %c1 {
565 /// %tile_slice = arm_sme.extract_tile_slice %tile[%i]
566 /// : vector<[4]xf32> from vector<[4]x[4]xf32>
567 /// vector.print %tile_slice : vector<[4]xf32>
568 /// }
569 /// ```
570 struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
571  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
572 
573  LogicalResult matchAndRewrite(vector::PrintOp printOp,
574  PatternRewriter &rewriter) const override {
575  if (!printOp.getSource())
576  return failure();
577 
578  VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
579  if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
580  return failure();
581 
582  auto loc = printOp.getLoc();
583 
584  // Create a loop over the rows of the tile.
585  auto vscale = vector::VectorScaleOp::create(rewriter, loc);
586  auto minTileRows =
587  arith::ConstantIndexOp::create(rewriter, loc, vectorType.getDimSize(0));
588  auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
589  auto upperBound = arith::MulIOp::create(rewriter, loc, minTileRows, vscale);
590  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
591  auto forOp =
592  scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
593  {
594  // Loop body.
595  rewriter.setInsertionPointToStart(forOp.getBody());
596  // Extract the current row from the tile.
597  Value rowIndex = forOp.getInductionVar();
598  auto tileSlice = arm_sme::ExtractTileSliceOp::create(
599  rewriter, loc, printOp.getSource(), rowIndex);
600  // Print the row with a 1D vector.print.
601  vector::PrintOp::create(rewriter, loc, tileSlice,
602  printOp.getPunctuation());
603  }
604 
605  rewriter.eraseOp(printOp);
606  return success();
607  }
608 };
609 
610 /// Folds a ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
611 ///
612 /// BEFORE:
613 /// ```mlir
614 /// %slice = arm_sme.extract_tile_slice %tile[%index]
615 /// : vector<[4]xf32> from vector<[4]x[4]xf32>
616 /// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
617 /// : vector<[4]xf32>, memref<?x?xf32>
618 /// ```
619 /// AFTER:
620 /// ```mlir
621 /// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
622 /// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
623 /// ```
624 struct FoldTransferWriteOfExtractTileSlice
625  : public OpRewritePattern<vector::TransferWriteOp> {
626  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
627 
628  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
629  PatternRewriter &rewriter) const final {
630  if (!isa<MemRefType>(writeOp.getBase().getType()))
631  return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
632 
633  if (writeOp.hasOutOfBoundsDim())
634  return rewriter.notifyMatchFailure(writeOp,
635  "not inbounds transfer write");
636 
637  auto extractTileSlice =
638  writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
639  if (!extractTileSlice)
640  return rewriter.notifyMatchFailure(
641  writeOp, "vector to store not from ExtractTileSliceOp");
642 
643  AffineMap map = writeOp.getPermutationMap();
644  if (!map.isMinorIdentity())
645  return rewriter.notifyMatchFailure(writeOp,
646  "unsupported permutation map");
647 
648  Value mask = writeOp.getMask();
649  if (!mask) {
650  auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
651  mask = arith::ConstantOp::create(rewriter, writeOp.getLoc(), maskType,
652  DenseElementsAttr::get(maskType, true));
653  }
654 
655  rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
656  writeOp, extractTileSlice.getTile(),
657  extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(),
658  writeOp.getIndices(), extractTileSlice.getLayout());
659  return success();
660  }
661 };
662 
663 /// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
664 /// `arm_sve.psel`. Note: While psel is under ArmSVE, it requires SME (or
665 /// SVE 2.1), so this is currently the most logical place for this lowering.
666 ///
667 /// Example:
668 /// ```mlir
669 /// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
670 /// %slice = vector.extract %mask[%index]
671 /// : vector<[8]xi1> from vector<[4]x[8]xi1>
672 /// ```
673 /// Becomes:
674 /// ```
675 /// %mask_rows = vector.create_mask %a : vector<[4]xi1>
676 /// %mask_cols = vector.create_mask %b : vector<[8]xi1>
677 /// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
678 /// : vector<[8]xi1>, vector<[4]xi1>
679 /// ```
680 struct ExtractFromCreateMaskToPselLowering
681  : public OpRewritePattern<vector::ExtractOp> {
682  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
683 
684  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
685  PatternRewriter &rewriter) const override {
686  if (extractOp.getNumIndices() != 1)
687  return rewriter.notifyMatchFailure(extractOp, "not single extract index");
688 
689  auto resultType = extractOp.getResult().getType();
690  auto resultVectorType = dyn_cast<VectorType>(resultType);
691  if (!resultVectorType)
692  return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
693 
694  auto createMaskOp =
695  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
696  if (!createMaskOp)
697  return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
698 
699  auto maskType = createMaskOp.getVectorType();
700  if (maskType.getRank() != 2 || !maskType.allDimsScalable())
701  return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
702 
703  auto isSVEPredicateSize = [](int64_t size) {
704  return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
705  };
706 
707  auto rowsBaseSize = maskType.getDimSize(0);
708  auto colsBaseSize = maskType.getDimSize(1);
709  if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
710  return rewriter.notifyMatchFailure(
711  createMaskOp, "mask dimensions not SVE predicate-sized");
712 
713  auto loc = extractOp.getLoc();
714  VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
715  VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
716 
717  // Create the two 1-D masks at the location of the 2-D create_mask (which is
718  // usually outside a loop). This prevents the need for later hoisting.
719  rewriter.setInsertionPoint(createMaskOp);
720  auto rowMask = vector::CreateMaskOp::create(rewriter, loc, rowMaskType,
721  createMaskOp.getOperand(0));
722  auto colMask = vector::CreateMaskOp::create(rewriter, loc, colMaskType,
723  createMaskOp.getOperand(1));
724 
725  rewriter.setInsertionPoint(extractOp);
726  auto position =
727  vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
728  rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
729  position[0]);
730  return success();
731  }
732 };
733 
734 // Convert all `vector.splat` to `vector.broadcast`. There is a path from
735 // `vector.broadcast` to ArmSME via another pattern.
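// For example (an illustrative sketch; %src here is assumed to be an f32
// scalar matching the result element type):
//
//   %tile = vector.splat %src : vector<[4]x[4]xf32>
//
// becomes:
//
//   %tile = vector.broadcast %src : f32 to vector<[4]x[4]xf32>
//
// which BroadcastOpToArmSMELowering can then lower to SME tile operations.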
736 struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
737  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
738 
739  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
740  PatternRewriter &rewriter) const final {
741 
742  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
743  splatOp.getInput());
744  return success();
745  }
746 };
747 
748 } // namespace
749 
750 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
751  MLIRContext &ctx) {
752  patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
753  TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
754  TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
755  VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
756  VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
757  VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
758  ExtractFromCreateMaskToPselLowering>(&ctx);
759 }
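// Usage sketch (not part of this file, shown only for illustration): a
// conversion pass would typically collect these patterns and apply them with
// a pattern driver, e.g. the greedy rewriter from
// "mlir/Transforms/GreedyPatternRewriteDriver.h":
//
//   RewritePatternSet patterns(&getContext());
//   mlir::populateVectorToArmSMEPatterns(patterns, getContext());
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();
//
// The exact driver (greedy rewrite vs. dialect conversion) is up to the
// calling pass.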