MLIR  20.0.0git
VectorToArmSME.cpp
Go to the documentation of this file.
1 //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "llvm/Support/Casting.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 
22 /// Conversion pattern for vector.transfer_read.
23 ///
24 /// ---
25 ///
26 /// Example 1: op with identity permutation map to horizontal
27 /// arm_sme.tile_load:
28 ///
29 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
30 ///
31 /// is converted to:
32 ///
33 /// arm_sme.tile_load ...
34 ///
35 /// ---
36 ///
37 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
38 /// (in-flight transpose):
39 ///
40 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
41 ///
42 /// is converted to:
43 ///
44 /// arm_sme.tile_load ... layout<vertical>
45 struct TransferReadToArmSMELowering
46  : public OpRewritePattern<vector::TransferReadOp> {
48 
49  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
50  PatternRewriter &rewriter) const final {
51  // The permutation map must have two results.
52  if (transferReadOp.getTransferRank() != 2)
53  return rewriter.notifyMatchFailure(transferReadOp,
54  "not a 2 result permutation map");
55 
56  auto vectorType = transferReadOp.getVectorType();
57  if (!arm_sme::isValidSMETileVectorType(vectorType))
58  return rewriter.notifyMatchFailure(transferReadOp,
59  "not a valid vector type for SME");
60 
61  if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
62  return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
63 
64  // Out-of-bounds dims are not supported.
65  if (transferReadOp.hasOutOfBoundsDim())
66  return rewriter.notifyMatchFailure(transferReadOp,
67  "not inbounds transfer read");
68 
69  AffineMap map = transferReadOp.getPermutationMap();
70  if (!map.isPermutation())
71  return rewriter.notifyMatchFailure(transferReadOp,
72  "unsupported permutation map");
73 
74  // Note: For 2D vector types the only non-identity permutation is a simple
75  // transpose [1, 0].
76  bool transposed = !map.isIdentity();
77  arm_sme::TileSliceLayout layout =
78  transposed ? arm_sme::TileSliceLayout::Vertical
79  : arm_sme::TileSliceLayout::Horizontal;
80 
81  // Padding isn't optional for transfer_read, but is only used in the case
82  // of out-of-bounds accesses (not supported here) and/or masking. Mask is
83  // optional, if it's not present don't pass padding.
84  auto mask = transferReadOp.getMask();
85  auto padding = mask ? transferReadOp.getPadding() : nullptr;
86  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
87  transferReadOp, vectorType, transferReadOp.getSource(),
88  transferReadOp.getIndices(), padding, mask, layout);
89 
90  return success();
91  }
92 };
93 
94 /// Conversion pattern for vector.transfer_write.
95 ///
96 /// ---
97 ///
98 /// Example 1: op with identity permutation map to horizontal
99 /// arm_sme.tile_store:
100 ///
101 /// vector.transfer_write %vector, %source[%c0, %c0]
102 /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
103 ///
104 /// is converted to:
105 ///
106 /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
107 /// vector<[16]x[16]xi8>
108 /// ---
109 ///
110 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
111 /// (in-flight transpose):
112 ///
113 /// vector.transfer_write %vector, %source[%c0, %c0]
114 /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
115 /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
116 ///
117 /// is converted to:
118 ///
119 /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
120 /// : memref<?x?xi8>, vector<[16]x[16]xi8>
121 struct TransferWriteToArmSMELowering
122  : public OpRewritePattern<vector::TransferWriteOp> {
124 
125  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
126  PatternRewriter &rewriter) const final {
127  auto vType = writeOp.getVectorType();
129  return failure();
130 
131  if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
132  return failure();
133 
134  // Out-of-bounds dims are not supported.
135  if (writeOp.hasOutOfBoundsDim())
136  return rewriter.notifyMatchFailure(writeOp,
137  "not inbounds transfer write");
138 
139  AffineMap map = writeOp.getPermutationMap();
140  if (!map.isPermutation())
141  return rewriter.notifyMatchFailure(writeOp,
142  "unsupported permutation map");
143 
144  // Note: For 2D vector types the only non-identity permutation is a simple
145  // transpose [1, 0].
146  bool transposed = !map.isIdentity();
147  arm_sme::TileSliceLayout layout =
148  transposed ? arm_sme::TileSliceLayout::Vertical
149  : arm_sme::TileSliceLayout::Horizontal;
150 
151  rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
152  writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
153  writeOp.getMask(), layout);
154  return success();
155  }
156 };
157 
158 /// Conversion pattern for vector.load.
159 struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
161 
162  LogicalResult matchAndRewrite(vector::LoadOp load,
163  PatternRewriter &rewriter) const override {
164  if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
165  return failure();
166 
167  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
168  load, load.getVectorType(), load.getBase(), load.getIndices());
169 
170  return success();
171  }
172 };
173 
174 /// Conversion pattern for vector.store.
175 struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
177 
178  LogicalResult matchAndRewrite(vector::StoreOp store,
179  PatternRewriter &rewriter) const override {
180  if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
181  return failure();
182 
183  rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
184  store, store.getValueToStore(), store.getBase(), store.getIndices());
185 
186  return success();
187  }
188 };
189 
190 /// Conversion pattern for vector.broadcast.
191 ///
192 /// Example:
193 ///
194 /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
195 ///
196 /// is converted to:
197 ///
198 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
199 /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
200 /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
201 /// {
202 /// %tile_update = arm_sme.insert_tile_slice
203 /// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
204 /// vector<[4]xi32> into vector<[4]x[4]xi32>
205 /// scf.yield %tile_update : vector<[4]x[4]xi32>
206 /// }
207 ///
208 /// Supports scalar, 0-d vector, and 1-d vector broadcasts.
209 struct BroadcastOpToArmSMELowering
210  : public OpRewritePattern<vector::BroadcastOp> {
212 
213  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
214  PatternRewriter &rewriter) const final {
215  auto tileType = broadcastOp.getResultVectorType();
216  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
217  return failure();
218 
219  auto loc = broadcastOp.getLoc();
220 
221  auto srcType = broadcastOp.getSourceType();
222  auto srcVectorType = dyn_cast<VectorType>(srcType);
223 
224  Value broadcastOp1D;
225  if (srcType.isIntOrFloat() ||
226  (srcVectorType && (srcVectorType.getRank() == 0))) {
227  // Broadcast scalar or 0-d vector to 1-d vector.
228  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
229  broadcastOp1D = rewriter.create<vector::BroadcastOp>(
230  loc, tileSliceType, broadcastOp.getSource());
231  } else if (srcVectorType && (srcVectorType.getRank() == 1))
232  // Value to broadcast is already a 1-d vector, nothing to do.
233  broadcastOp1D = broadcastOp.getSource();
234  else
235  return failure();
236 
237  auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
238 
239  auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
240  Value currentTile) {
241  // Create 'arm_sme.insert_tile_slice' to broadcast the value
242  // to each tile slice.
243  auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
244  loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
245  return nextTile.getResult();
246  };
247 
248  // Create a loop over ZA tile slices.
249  auto forOp =
250  createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
251 
252  rewriter.replaceOp(broadcastOp, forOp.getResult(0));
253 
254  return success();
255  }
256 };
257 
258 /// Conversion pattern for vector.splat.
259 ///
260 /// Example:
261 ///
262 /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
263 ///
264 /// is converted to:
265 ///
266 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
267 /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
268 /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
269 /// {
270 /// %tile_update = arm_sme.insert_tile_slice
271 /// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
272 /// vector<[4]xi32> into vector<[4]x[4]xi32>
273 /// scf.yield %tile_update : vector<[4]x[4]xi32>
274 /// }
275 ///
276 /// This is identical to vector.broadcast of a scalar.
277 struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
279 
280  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
281  PatternRewriter &rewriter) const final {
282  auto tileType = splatOp.getResult().getType();
283  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
284  return failure();
285 
286  auto loc = splatOp.getLoc();
287  auto srcType = splatOp.getOperand().getType();
288 
289  assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
290  // Avoid unused-variable warning when building without assertions.
291  (void)srcType;
292 
293  // First, broadcast the scalar to a 1-d vector.
294  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
295  Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
296  loc, tileSliceType, splatOp.getInput());
297 
298  auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
299 
300  auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
301  Value currentTile) {
302  auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
303  loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
304  return nextTile.getResult();
305  };
306 
307  // Next, create a loop over ZA tile slices and "move" the generated 1-d
308  // vector to each slice.
309  auto forOp =
310  createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
311 
312  rewriter.replaceOp(splatOp, forOp.getResult(0));
313 
314  return success();
315  }
316 };
317 
318 /// Conversion pattern for vector.transpose.
319 ///
320 /// Stores the input tile to memory and reloads vertically.
321 ///
322 /// Example:
323 ///
324 /// %transposed_src = vector.transpose %src, [1, 0]
325 /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
326 ///
327 /// is converted to:
328 ///
329 /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
330 /// arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
331 /// : memref<?x?xi32>, vector<[4]x[4]xi32>
332 /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
333 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
334 ///
335 /// NOTE: Transposing via memory is obviously expensive, the current intention
336 /// is to avoid the transpose if possible, this is therefore intended as a
337 /// fallback and to provide base support for Vector ops. If it turns out
338 /// transposes can't be avoided then this should be replaced with a more optimal
339 /// implementation, perhaps with tile <-> vector (MOVA) ops.
340 struct TransposeOpToArmSMELowering
341  : public OpRewritePattern<vector::TransposeOp> {
343 
344  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
345  PatternRewriter &rewriter) const final {
346  auto tileType = transposeOp.getResultVectorType();
347  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
348  return failure();
349 
350  // Bail unless this is a true 2-D matrix transpose.
351  ArrayRef<int64_t> permutation = transposeOp.getPermutation();
352  if (permutation[0] != 1 || permutation[1] != 0)
353  return failure();
354 
355  auto loc = transposeOp.getLoc();
356  Value input = transposeOp.getVector();
357 
358  if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
359  xferOp && xferOp->hasOneUse()) {
360  // Fold transpose into transfer_read to enable in-flight transpose when
361  // converting to arm_sme.tile_load.
362  rewriter.modifyOpInPlace(xferOp, [&]() {
363  xferOp->setAttr(xferOp.getPermutationMapAttrName(),
365  permutation, transposeOp.getContext())));
366  });
367  rewriter.replaceOp(transposeOp, xferOp);
368  return success();
369  }
370 
371  // Allocate buffer to store input tile to.
372  Value vscale =
373  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
374  Value minTileSlices = rewriter.create<arith::ConstantOp>(
375  loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
376  Value c0 =
377  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
378  Value numTileSlices =
379  rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
380  auto bufferType =
381  MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
382  tileType.getElementType());
383  auto buffer = rewriter.create<memref::AllocaOp>(
384  loc, bufferType, ValueRange{numTileSlices, numTileSlices});
385 
386  // Store input tile.
387  auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
388  loc, input, buffer, ValueRange{c0, c0});
389 
390  // Reload input tile vertically.
391  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
392  transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
393  arm_sme::TileSliceLayout::Vertical);
394 
395  return success();
396  }
397 };
398 
399 /// Conversion pattern for vector.outerproduct.
400 ///
401 /// If the vector.outerproduct is masked (and the mask is from a
402 /// vector.create_mask), then the mask is decomposed into two 1-D masks for the
403 /// operands.
404 ///
405 /// Example:
406 ///
407 /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
408 /// %result = vector.mask %mask {
409 /// vector.outerproduct %vecA, %vecB
410 /// : vector<[4]xf32>, vector<[4]xf32>
411 /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
412 ///
413 /// is converted to:
414 ///
415 /// %maskA = vector.create_mask %dimA : vector<[4]xi1>
416 /// %maskB = vector.create_mask %dimB : vector<[4]xi1>
417 /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
418 /// : vector<[4]xf32>, vector<[4]xf32>
419 ///
420 /// Unmasked outerproducts can be directly replaced with the arm_sme op.
421 ///
422 /// Example:
423 ///
424 /// %result = vector.outerproduct %vecA, %vecB
425 /// : vector<[4]xf32>, vector<[4]xf32>
426 ///
427 /// is converted to:
428 ///
429 /// %result = arm_sme.outerproduct %vecA, %vecB
430 /// : vector<[4]xf32>, vector<[4]xf32>
431 ///
432 struct VectorOuterProductToArmSMELowering
433  : public OpRewritePattern<vector::OuterProductOp> {
434 
436 
437  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
438  PatternRewriter &rewriter) const override {
439 
440  // We don't yet support lowering AXPY operations to SME. These could be
441  // lowered by masking out all but the first element of the LHS.
442  if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
443  return rewriter.notifyMatchFailure(outerProductOp,
444  "AXPY operations not supported");
445 
447  outerProductOp.getResultVectorType()))
448  return rewriter.notifyMatchFailure(
449  outerProductOp, "outer product does not fit into SME tile");
450 
451  auto kind = outerProductOp.getKind();
452  if (kind != vector::CombiningKind::ADD)
453  return rewriter.notifyMatchFailure(
454  outerProductOp,
455  "unsupported kind (lowering to SME only supports ADD at the moment)");
456 
457  Value lhsMask = {};
458  Value rhsMask = {};
459  Operation *rootOp = outerProductOp;
460  auto loc = outerProductOp.getLoc();
461  if (outerProductOp.isMasked()) {
462  auto maskOp = outerProductOp.getMaskingOp();
463  rewriter.setInsertionPoint(maskOp);
464  rootOp = maskOp;
465  auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
466  if (failed(operandMasks))
467  return failure();
468  std::tie(lhsMask, rhsMask) = *operandMasks;
469  }
470 
471  rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
472  rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
473  outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
474 
475  return success();
476  }
477 
478  static FailureOr<std::pair<Value, Value>>
479  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
480  // Attempt to extract masks from vector.create_mask.
481  // TODO: Add support for other mask sources.
482  auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
483  if (!createMaskOp)
484  return failure();
485 
486  auto maskType = createMaskOp.getVectorType();
487  Value lhsMaskDim = createMaskOp.getOperand(0);
488  Value rhsMaskDim = createMaskOp.getOperand(1);
489 
490  VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
491  Value lhsMask =
492  rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
493  Value rhsMask =
494  rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
495 
496  return std::make_pair(lhsMask, rhsMask);
497  }
498 };
499 
500 /// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
501 ///
502 /// Example:
503 /// ```
504 /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
505 /// ```
506 /// Becomes:
507 /// ```
508 /// %slice = arm_sme.extract_tile_slice %tile[%row]
509 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
510 /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
511 /// ```
512 struct VectorExtractToArmSMELowering
513  : public OpRewritePattern<vector::ExtractOp> {
515 
516  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
517  PatternRewriter &rewriter) const override {
518  VectorType sourceType = extractOp.getSourceVectorType();
519  if (!arm_sme::isValidSMETileVectorType(sourceType))
520  return failure();
521 
522  auto loc = extractOp.getLoc();
523  auto position = extractOp.getMixedPosition();
524 
525  Value sourceVector = extractOp.getVector();
526 
527  // Extract entire vector. Should be handled by folder, but just to be safe.
528  if (position.empty()) {
529  rewriter.replaceOp(extractOp, sourceVector);
530  return success();
531  }
532 
533  Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
534  auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
535  loc, sourceVector, sliceIndex);
536 
537  if (position.size() == 1) {
538  // Single index case: Extracts a 1D slice.
539  rewriter.replaceOp(extractOp, extractTileSlice);
540  return success();
541  }
542 
543  // Two indices case: Extracts a single element.
544  assert(position.size() == 2);
545  rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice,
546  position[1]);
547 
548  return success();
549  }
550 };
551 
552 /// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
553 /// `arm_sme.extract_tile_slice`.
554 ///
555 /// Example:
556 /// ```
557 /// %new_tile = vector.insert %el, %tile[%row, %col]
558 /// : i32 into vector<[4]x[4]xi32>
559 /// ```
560 /// Becomes:
561 /// ```
562 /// %slice = arm_sme.extract_tile_slice %tile[%row]
563 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
564 /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
565 /// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
566 /// : vector<[4]xi32> into vector<[4]x[4]xi32>
567 /// ```
568 struct VectorInsertToArmSMELowering
569  : public OpRewritePattern<vector::InsertOp> {
571 
572  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
573  PatternRewriter &rewriter) const override {
574  VectorType resultType = insertOp.getResult().getType();
575 
576  if (!arm_sme::isValidSMETileVectorType(resultType))
577  return failure();
578 
579  auto loc = insertOp.getLoc();
580  auto position = insertOp.getMixedPosition();
581 
582  Value source = insertOp.getSource();
583 
584  // Overwrite entire vector with value. Should be handled by folder, but
585  // just to be safe.
586  if (position.empty()) {
587  rewriter.replaceOp(insertOp, source);
588  return success();
589  }
590 
591  Value tileSlice = source;
592  Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
593  if (position.size() == 2) {
594  // Two indices case: Insert single element into tile.
595  // We need to first extract the existing slice and update the element.
596  tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
597  loc, insertOp.getDest(), sliceIndex);
598  tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
599  position[1]);
600  }
601 
602  // Insert the slice into the destination tile.
603  rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>(
604  insertOp, tileSlice, insertOp.getDest(), sliceIndex);
605  return success();
606  }
607 };
608 
609 /// Lowers `vector.print` of a tile into a loop over the rows of the tile,
610 /// extracting them via `arm_sme.extract_tile_slice`, then printing with
611 /// a 1D `vector.print`.
612 ///
613 /// BEFORE:
614 /// ```mlir
615 /// vector.print %tile : vector<[4]x[4]xf32>
616 /// ```
617 /// AFTER:
618 /// ```mlir
619 /// %c0 = arith.constant 0 : index
620 /// %c1 = arith.constant 1 : index
621 /// %c4 = arith.constant 4 : index
622 /// %vscale = vector.vscale
623 /// %svl_s = arith.muli %c4, %vscale : index
624 /// scf.for %i = %c0 to %svl_s step %c1 {
625 /// %tile_slice = arm_sme.extract_tile_slice %tile[%i]
626 /// : vector<[4]xf32> from vector<[4]x[4]xf32>
627 /// vector.print %tile_slice : vector<[4]xf32>
628 /// }
629 /// ```
630 struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
632 
633  LogicalResult matchAndRewrite(vector::PrintOp printOp,
634  PatternRewriter &rewriter) const override {
635  if (!printOp.getSource())
636  return failure();
637 
638  VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
639  if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
640  return failure();
641 
642  auto loc = printOp.getLoc();
643 
644  // Create a loop over the rows of the tile.
645  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
646  auto minTileRows =
647  rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
648  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
649  auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
650  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
651  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
652  {
653  // Loop body.
654  rewriter.setInsertionPointToStart(forOp.getBody());
655  // Extract the current row from the tile.
656  Value rowIndex = forOp.getInductionVar();
657  auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
658  loc, printOp.getSource(), rowIndex);
659  // Print the row with a 1D vector.print.
660  rewriter.create<vector::PrintOp>(loc, tileSlice,
661  printOp.getPunctuation());
662  }
663 
664  rewriter.eraseOp(printOp);
665  return success();
666  }
667 };
668 
669 /// Folds an ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
670 ///
671 /// BEFORE:
672 /// ```mlir
673 /// %slice = arm_sme.extract_tile_slice %tile[%index]
674 /// : vector<[4]xf32> from vector<[4]x[4]xf32>
675 /// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
676 /// : vector<[4]xf32>, memref<?x?xf32>
677 /// ```
678 /// AFTER:
679 /// ```mlir
680 /// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
681 /// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
682 /// ```
683 struct FoldTransferWriteOfExtractTileSlice
684  : public OpRewritePattern<vector::TransferWriteOp> {
686 
687  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
688  PatternRewriter &rewriter) const final {
689  if (!isa<MemRefType>(writeOp.getSource().getType()))
690  return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
691 
692  if (writeOp.hasOutOfBoundsDim())
693  return rewriter.notifyMatchFailure(writeOp,
694  "not inbounds transfer write");
695 
696  auto extractTileSlice =
697  writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
698  if (!extractTileSlice)
699  return rewriter.notifyMatchFailure(
700  writeOp, "vector to store not from ExtractTileSliceOp");
701 
702  AffineMap map = writeOp.getPermutationMap();
703  if (!map.isMinorIdentity())
704  return rewriter.notifyMatchFailure(writeOp,
705  "unsupported permutation map");
706 
707  Value mask = writeOp.getMask();
708  if (!mask) {
709  auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
710  mask = rewriter.create<arith::ConstantOp>(
711  writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
712  }
713 
714  rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
715  writeOp, extractTileSlice.getTile(),
716  extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(),
717  writeOp.getIndices(), extractTileSlice.getLayout());
718  return success();
719  }
720 };
721 
722 /// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
723 /// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
724 /// SVE 2.1), so this is currently the most logical place for this lowering.
725 ///
726 /// Example:
727 /// ```mlir
728 /// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
729 /// %slice = vector.extract %mask[%index]
730 /// : vector<[8]xi1> from vector<[4]x[8]xi1>
731 /// ```
732 /// Becomes:
733 /// ```
734 /// %mask_rows = vector.create_mask %a : vector<[4]xi1>
735 /// %mask_cols = vector.create_mask %b : vector<[8]xi1>
736 /// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
737 /// : vector<[8]xi1>, vector<[4]xi1>
738 /// ```
739 struct ExtractFromCreateMaskToPselLowering
740  : public OpRewritePattern<vector::ExtractOp> {
742 
743  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
744  PatternRewriter &rewriter) const override {
745  if (extractOp.getNumIndices() != 1)
746  return rewriter.notifyMatchFailure(extractOp, "not single extract index");
747 
748  auto resultType = extractOp.getResult().getType();
749  auto resultVectorType = dyn_cast<VectorType>(resultType);
750  if (!resultVectorType)
751  return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
752 
753  auto createMaskOp =
754  extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
755  if (!createMaskOp)
756  return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
757 
758  auto maskType = createMaskOp.getVectorType();
759  if (maskType.getRank() != 2 || !maskType.allDimsScalable())
760  return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
761 
762  auto isSVEPredicateSize = [](int64_t size) {
763  return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
764  };
765 
766  auto rowsBaseSize = maskType.getDimSize(0);
767  auto colsBaseSize = maskType.getDimSize(1);
768  if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
769  return rewriter.notifyMatchFailure(
770  createMaskOp, "mask dimensions not SVE predicate-sized");
771 
772  auto loc = extractOp.getLoc();
773  VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
774  VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
775 
776  // Create the two 1-D masks at the location of the 2-D create_mask (which is
777  // usually outside a loop). This prevents the need for later hoisting.
778  rewriter.setInsertionPoint(createMaskOp);
779  auto rowMask = rewriter.create<vector::CreateMaskOp>(
780  loc, rowMaskType, createMaskOp.getOperand(0));
781  auto colMask = rewriter.create<vector::CreateMaskOp>(
782  loc, colMaskType, createMaskOp.getOperand(1));
783 
784  rewriter.setInsertionPoint(extractOp);
785  auto position =
786  vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
787  rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
788  position[0]);
789  return success();
790  }
791 };
792 
793 } // namespace
794 
796  MLIRContext &ctx) {
797  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
798  TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
799  TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
800  VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
801  VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
802  VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
803  ExtractFromCreateMaskToPselLowering>(&ctx);
804 }
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition: Unit.cpp:19
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:155
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:345
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:317
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:342
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition: Utils.cpp:75
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:29
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:338
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx)
Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358