MLIR  19.0.0git
VectorToArmSME.cpp
Go to the documentation of this file.
1 //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "llvm/Support/Casting.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 /// Conversion pattern for vector.transfer_read.
22 ///
23 /// ---
24 ///
25 /// Example 1: op with identity permutation map to horizontal
26 /// arm_sme.tile_load:
27 ///
28 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
29 ///
30 /// is converted to:
31 ///
32 /// arm_sme.tile_load ...
33 ///
34 /// ---
35 ///
36 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
37 /// (in-flight transpose):
38 ///
39 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
40 ///
41 /// is converted to:
42 ///
43 /// arm_sme.tile_load ... layout<vertical>
44 struct TransferReadToArmSMELowering
45  : public OpRewritePattern<vector::TransferReadOp> {
47 
48  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
49  PatternRewriter &rewriter) const final {
50  // The permutation map must have two results.
51  if (transferReadOp.getTransferRank() != 2)
52  return rewriter.notifyMatchFailure(transferReadOp,
53  "not a 2 result permutation map");
54 
55  auto vectorType = transferReadOp.getVectorType();
56  if (!arm_sme::isValidSMETileVectorType(vectorType))
57  return rewriter.notifyMatchFailure(transferReadOp,
58  "not a valid vector type for SME");
59 
60  if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
61  return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
62 
63  // Out-of-bounds dims are not supported.
64  if (transferReadOp.hasOutOfBoundsDim())
65  return rewriter.notifyMatchFailure(transferReadOp,
66  "not inbounds transfer read");
67 
68  arm_sme::TileSliceLayout layout;
69 
70  AffineExpr d0, d1;
71  bindDims(transferReadOp.getContext(), d0, d1);
72  AffineMap map = transferReadOp.getPermutationMap();
73  if (map.isIdentity())
74  layout = arm_sme::TileSliceLayout::Horizontal;
75  else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
76  transferReadOp.getContext()))
77  layout = arm_sme::TileSliceLayout::Vertical;
78  else
79  return rewriter.notifyMatchFailure(transferReadOp,
80  "unsupported permutation map");
81 
82  // Padding isn't optional for transfer_read, but is only used in the case
83  // of out-of-bounds accesses (not supported here) and/or masking. Mask is
84  // optional, if it's not present don't pass padding.
85  auto mask = transferReadOp.getMask();
86  auto padding = mask ? transferReadOp.getPadding() : nullptr;
87  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
88  transferReadOp, vectorType, transferReadOp.getSource(),
89  transferReadOp.getIndices(), padding, mask, layout);
90 
91  return success();
92  }
93 };
94 
95 /// Conversion pattern for vector.transfer_write.
96 ///
97 /// ---
98 ///
99 /// Example 1: op with identity permutation map to horizontal
100 /// arm_sme.tile_store:
101 ///
102 /// vector.transfer_write %vector, %source[%c0, %c0]
103 /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
104 ///
105 /// is converted to:
106 ///
107 /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
108 /// vector<[16]x[16]xi8>
109 /// ---
110 ///
111 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
112 /// (in-flight transpose):
113 ///
114 /// vector.transfer_write %vector, %source[%c0, %c0]
115 /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
116 /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
117 ///
118 /// is converted to:
119 ///
120 /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
121 /// : memref<?x?xi8>, vector<[16]x[16]xi8>
122 struct TransferWriteToArmSMELowering
123  : public OpRewritePattern<vector::TransferWriteOp> {
125 
126  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
127  PatternRewriter &rewriter) const final {
128  auto vType = writeOp.getVectorType();
130  return failure();
131 
132  if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
133  return failure();
134 
135  // Out-of-bounds dims are not supported.
136  if (writeOp.hasOutOfBoundsDim())
137  return rewriter.notifyMatchFailure(writeOp,
138  "not inbounds transfer write");
139 
140  AffineExpr d0, d1;
141  bindDims(writeOp.getContext(), d0, d1);
142  AffineMap map = writeOp.getPermutationMap();
143  bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
144  writeOp.getContext()));
145 
146  if (!map.isIdentity() && !isTranspose)
147  return rewriter.notifyMatchFailure(writeOp,
148  "unsupported permutation map");
149 
150  arm_sme::TileSliceLayout layout =
151  isTranspose ? arm_sme::TileSliceLayout::Vertical
152  : arm_sme::TileSliceLayout::Horizontal;
153 
154  rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
155  writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
156  writeOp.getMask(), layout);
157  return success();
158  }
159 };
160 
161 /// Conversion pattern for vector.load.
162 struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
164 
165  LogicalResult matchAndRewrite(vector::LoadOp load,
166  PatternRewriter &rewriter) const override {
167  if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
168  return failure();
169 
170  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
171  load, load.getVectorType(), load.getBase(), load.getIndices());
172 
173  return success();
174  }
175 };
176 
177 /// Conversion pattern for vector.store.
178 struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
180 
181  LogicalResult matchAndRewrite(vector::StoreOp store,
182  PatternRewriter &rewriter) const override {
183  if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
184  return failure();
185 
186  rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
187  store, store.getValueToStore(), store.getBase(), store.getIndices());
188 
189  return success();
190  }
191 };
192 
193 /// Conversion pattern for vector.broadcast.
194 ///
195 /// Example:
196 ///
197 /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
198 ///
199 /// is converted to:
200 ///
201 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
202 /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
203 /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
204 /// {
205 /// %tile_update = arm_sme.move_vector_to_tile_slice
206 /// %broadcast_to_1d, %iter_tile, %tile_slice_index :
207 /// vector<[4]xi32> into vector<[4]x[4]xi32>
208 /// scf.yield %tile_update : vector<[4]x[4]xi32>
209 /// }
210 ///
211 /// Supports scalar, 0-d vector, and 1-d vector broadcasts.
212 struct BroadcastOpToArmSMELowering
213  : public OpRewritePattern<vector::BroadcastOp> {
215 
216  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
217  PatternRewriter &rewriter) const final {
218  auto tileType = broadcastOp.getResultVectorType();
219  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
220  return failure();
221 
222  auto loc = broadcastOp.getLoc();
223 
224  auto srcType = broadcastOp.getSourceType();
225  auto srcVectorType = dyn_cast<VectorType>(srcType);
226 
227  Value broadcastOp1D;
228  if (srcType.isIntOrFloat() ||
229  (srcVectorType && (srcVectorType.getRank() == 0))) {
230  // Broadcast scalar or 0-d vector to 1-d vector.
231  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
232  broadcastOp1D = rewriter.create<vector::BroadcastOp>(
233  loc, tileSliceType, broadcastOp.getSource());
234  } else if (srcVectorType && (srcVectorType.getRank() == 1))
235  // Value to broadcast is already a 1-d vector, nothing to do.
236  broadcastOp1D = broadcastOp.getSource();
237  else
238  return failure();
239 
240  auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
241 
242  auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
243  Value currentTile) {
244  // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
245  // to each tile slice.
246  auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
247  loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
248  return nextTile.getResult();
249  };
250 
251  // Create a loop over ZA tile slices.
252  auto forOp =
253  createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
254 
255  rewriter.replaceOp(broadcastOp, forOp.getResult(0));
256 
257  return success();
258  }
259 };
260 
261 /// Conversion pattern for vector.splat.
262 ///
263 /// Example:
264 ///
265 /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
266 ///
267 /// is converted to:
268 ///
269 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
270 /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
271 /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
272 /// {
273 /// %tile_update = arm_sme.move_vector_to_tile_slice
274 /// %broadcast_to_1d, %iter_tile, %tile_slice_index :
275 /// vector<[4]xi32> into vector<[4]x[4]xi32>
276 /// scf.yield %tile_update : vector<[4]x[4]xi32>
277 /// }
278 ///
279 /// This is identical to vector.broadcast of a scalar.
280 struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
282 
283  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
284  PatternRewriter &rewriter) const final {
285  auto tileType = splatOp.getResult().getType();
286  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
287  return failure();
288 
289  auto loc = splatOp.getLoc();
290  auto srcType = splatOp.getOperand().getType();
291 
292  assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
293  // Avoid unused-variable warning when building without assertions.
294  (void)srcType;
295 
296  // First, broadcast the scalar to a 1-d vector.
297  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
298  Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
299  loc, tileSliceType, splatOp.getInput());
300 
301  auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
302 
303  auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
304  Value currentTile) {
305  auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
306  loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
307  return nextTile.getResult();
308  };
309 
310  // Next, create a loop over ZA tile slices and "move" the generated 1-d
311  // vector to each slice.
312  auto forOp =
313  createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
314 
315  rewriter.replaceOp(splatOp, forOp.getResult(0));
316 
317  return success();
318  }
319 };
320 
321 /// Conversion pattern for vector.transpose.
322 ///
323 /// Stores the input tile to memory and reloads vertically.
324 ///
325 /// Example:
326 ///
327 /// %transposed_src = vector.transpose %src, [1, 0]
328 /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
329 ///
330 /// is converted to:
331 ///
332 /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
333 /// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
334 /// : memref<?x?xi32>, vector<[4]x[4]xi32>
335 /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
336 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
337 ///
338 /// NOTE: Transposing via memory is obviously expensive, the current intention
339 /// is to avoid the transpose if possible, this is therefore intended as a
340 /// fallback and to provide base support for Vector ops. If it turns out
341 /// transposes can't be avoided then this should be replaced with a more optimal
342 /// implementation, perhaps with tile <-> vector (MOVA) ops.
343 struct TransposeOpToArmSMELowering
344  : public OpRewritePattern<vector::TransposeOp> {
346 
347  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
348  PatternRewriter &rewriter) const final {
349  auto tileType = transposeOp.getResultVectorType();
350  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
351  return failure();
352 
353  // Bail unless this is a true 2-D matrix transpose.
354  ArrayRef<int64_t> permutation = transposeOp.getPermutation();
355  if (permutation[0] != 1 || permutation[1] != 0)
356  return failure();
357 
358  auto loc = transposeOp.getLoc();
359 
360  // Allocate buffer to store input tile to.
361  Value vscale =
362  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
363  Value minTileSlices = rewriter.create<arith::ConstantOp>(
364  loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
365  Value c0 =
366  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
367  Value numTileSlices =
368  rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
369  auto bufferType =
370  MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
371  tileType.getElementType());
372  auto buffer = rewriter.create<memref::AllocaOp>(
373  loc, bufferType, ValueRange{numTileSlices, numTileSlices});
374 
375  Value input = transposeOp.getVector();
376 
377  // Store input tile.
378  auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
379  loc, input, buffer, ValueRange{c0, c0});
380 
381  // Reload input tile vertically.
382  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
383  transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
384  arm_sme::TileSliceLayout::Vertical);
385 
386  return success();
387  }
388 };
389 
390 /// Conversion pattern for vector.outerproduct.
391 ///
392 /// If the vector.outerproduct is masked (and the mask is from a
393 /// vector.create_mask), then the mask is decomposed into two 1-D masks for the
394 /// operands.
395 ///
396 /// Example:
397 ///
398 /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
399 /// %result = vector.mask %mask {
400 /// vector.outerproduct %vecA, %vecB
401 /// : vector<[4]xf32>, vector<[4]xf32>
402 /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
403 ///
404 /// is converted to:
405 ///
406 /// %maskA = vector.create_mask %dimA : vector<[4]xi1>
407 /// %maskB = vector.create_mask %dimB : vector<[4]xi1>
408 /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
409 /// : vector<[4]xf32>, vector<[4]xf32>
410 ///
411 /// Unmasked outerproducts can be directly replaced with the arm_sme op.
412 ///
413 /// Example:
414 ///
415 /// %result = vector.outerproduct %vecA, %vecB
416 /// : vector<[4]xf32>, vector<[4]xf32>
417 ///
418 /// is converted to:
419 ///
420 /// %result = arm_sme.outerproduct %vecA, %vecB
421 /// : vector<[4]xf32>, vector<[4]xf32>
422 ///
423 struct VectorOuterProductToArmSMELowering
424  : public OpRewritePattern<vector::OuterProductOp> {
425 
427 
428  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
429  PatternRewriter &rewriter) const override {
430 
431  // We don't yet support lowering AXPY operations to SME. These could be
432  // lowered by masking out all but the first element of the LHS.
433  if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
434  return rewriter.notifyMatchFailure(outerProductOp,
435  "AXPY operations not supported");
436 
438  outerProductOp.getResultVectorType()))
439  return rewriter.notifyMatchFailure(
440  outerProductOp, "outer product does not fit into SME tile");
441 
442  auto kind = outerProductOp.getKind();
443  if (kind != vector::CombiningKind::ADD)
444  return rewriter.notifyMatchFailure(
445  outerProductOp,
446  "unsupported kind (lowering to SME only supports ADD at the moment)");
447 
448  Value lhsMask = {};
449  Value rhsMask = {};
450  Operation *rootOp = outerProductOp;
451  auto loc = outerProductOp.getLoc();
452  if (outerProductOp.isMasked()) {
453  auto maskOp = outerProductOp.getMaskingOp();
454  rewriter.setInsertionPoint(maskOp);
455  rootOp = maskOp;
456  auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
457  if (failed(operandMasks))
458  return failure();
459  std::tie(lhsMask, rhsMask) = *operandMasks;
460  }
461 
462  rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
463  rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
464  outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
465 
466  return success();
467  }
468 
470  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
471  // Attempt to extract masks from vector.create_mask.
472  // TODO: Add support for other mask sources.
473  auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
474  if (!createMaskOp)
475  return failure();
476 
477  auto maskType = createMaskOp.getVectorType();
478  Value lhsMaskDim = createMaskOp.getOperand(0);
479  Value rhsMaskDim = createMaskOp.getOperand(1);
480 
481  VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
482  Value lhsMask =
483  rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
484  Value rhsMask =
485  rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
486 
487  return std::make_pair(lhsMask, rhsMask);
488  }
489 };
490 
491 /// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
492 ///
493 /// Example:
494 /// ```
495 /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
496 /// ```
497 /// Becomes:
498 /// ```
499 /// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
500 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
501 /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
502 /// ```
503 struct VectorExtractToArmSMELowering
504  : public OpRewritePattern<vector::ExtractOp> {
506 
507  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
508  PatternRewriter &rewriter) const override {
509  VectorType sourceType = extractOp.getSourceVectorType();
510  if (!arm_sme::isValidSMETileVectorType(sourceType))
511  return failure();
512 
513  auto loc = extractOp.getLoc();
514  auto position = extractOp.getMixedPosition();
515 
516  Value sourceVector = extractOp.getVector();
517 
518  // Extract entire vector. Should be handled by folder, but just to be safe.
519  if (position.empty()) {
520  rewriter.replaceOp(extractOp, sourceVector);
521  return success();
522  }
523 
524  Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
525  auto moveTileSliceToVector =
526  rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
527  sliceIndex);
528 
529  if (position.size() == 1) {
530  // Single index case: Extracts a 1D slice.
531  rewriter.replaceOp(extractOp, moveTileSliceToVector);
532  return success();
533  }
534 
535  // Two indices case: Extracts a single element.
536  assert(position.size() == 2);
537  rewriter.replaceOpWithNewOp<vector::ExtractOp>(
538  extractOp, moveTileSliceToVector, position[1]);
539 
540  return success();
541  }
542 };
543 
544 /// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
545 /// `arm_sme.move_tile_slice_to_vector`.
546 ///
547 /// Example:
548 /// ```
549 /// %new_tile = vector.insert %el, %tile[%row, %col]
550 /// : i32 into vector<[4]x[4]xi32>
551 /// ```
552 /// Becomes:
553 /// ```
554 /// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
555 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
556 /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
557 /// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
558 /// : vector<[4]xi32> into vector<[4]x[4]xi32>
559 /// ```
560 struct VectorInsertToArmSMELowering
561  : public OpRewritePattern<vector::InsertOp> {
563 
564  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
565  PatternRewriter &rewriter) const override {
566  VectorType resultType = insertOp.getResult().getType();
567 
568  if (!arm_sme::isValidSMETileVectorType(resultType))
569  return failure();
570 
571  auto loc = insertOp.getLoc();
572  auto position = insertOp.getMixedPosition();
573 
574  Value source = insertOp.getSource();
575 
576  // Overwrite entire vector with value. Should be handled by folder, but
577  // just to be safe.
578  if (position.empty()) {
579  rewriter.replaceOp(insertOp, source);
580  return success();
581  }
582 
583  Value tileSlice = source;
584  Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
585  if (position.size() == 2) {
586  // Two indices case: Insert single element into tile.
587  // We need to first extract the existing slice and update the element.
588  tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
589  loc, insertOp.getDest(), sliceIndex);
590  tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
591  position[1]);
592  }
593 
594  // Insert the slice into the destination tile.
595  rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
596  insertOp, tileSlice, insertOp.getDest(), sliceIndex);
597  return success();
598  }
599 };
600 
601 /// Lowers `vector.print` of a tile into a loop over the rows of the tile,
602 /// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
603 /// a 1D `vector.print`.
604 ///
605 /// BEFORE:
606 /// ```mlir
607 /// vector.print %tile : vector<[4]x[4]xf32>
608 /// ```
609 /// AFTER:
610 /// ```mlir
611 /// %c0 = arith.constant 0 : index
612 /// %c1 = arith.constant 1 : index
613 /// %c4 = arith.constant 4 : index
614 /// %vscale = vector.vscale
615 /// %svl_s = arith.muli %c4, %vscale : index
616 /// scf.for %i = %c0 to %svl_s step %c1 {
617 /// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
618 /// : vector<[4]xf32> from vector<[4]x[4]xf32>
619 /// vector.print %tile_slice : vector<[4]xf32>
620 /// }
621 /// ```
622 struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
624 
625  LogicalResult matchAndRewrite(vector::PrintOp printOp,
626  PatternRewriter &rewriter) const override {
627  if (!printOp.getSource())
628  return failure();
629 
630  VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
631  if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
632  return failure();
633 
634  auto loc = printOp.getLoc();
635 
636  // Create a loop over the rows of the tile.
637  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
638  auto minTileRows =
639  rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
640  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
641  auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
642  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
643  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
644  {
645  // Loop body.
646  rewriter.setInsertionPointToStart(forOp.getBody());
647  // Extract the current row from the tile.
648  Value rowIndex = forOp.getInductionVar();
649  auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
650  loc, printOp.getSource(), rowIndex);
651  // Print the row with a 1D vector.print.
652  rewriter.create<vector::PrintOp>(loc, tileSlice,
653  printOp.getPunctuation());
654  }
655 
656  rewriter.eraseOp(printOp);
657  return success();
658  }
659 };
660 
661 } // namespace
662 
664  MLIRContext &ctx) {
665  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
666  TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
667  TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
668  VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
669  VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
670  VectorPrintToArmSMELowering>(&ctx);
671 }
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition: Unit.cpp:19
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:378
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:329
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
IndexType getIndexType()
Definition: Builders.cpp:71
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:330
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition: Utils.cpp:75
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:29
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:308
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:349
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx)
Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM intrinsics.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358