MLIR  18.0.0git
VectorToArmSME.cpp
Go to the documentation of this file.
1 //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "llvm/Support/Casting.h"
16 
17 using namespace mlir;
18 
19 /// Returns true if 'val' is a splat of zero, false otherwise.
20 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
21  if (llvm::isa<FloatType>(elemType))
22  return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
23  if (llvm::isa<IntegerType>(elemType))
24  return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
25  return false;
26 }
27 
28 /// Generates a for loop over ZA tile slices where the induction variable is
29 /// the tile slice index. Sets the IR Builder insertion point as the loop body.
30 /// Callers of this method are responsible for restoring it if needed.
31 static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
32  Type eltType) {
33  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
34  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
36  auto vscale =
37  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
38  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
39  auto numTileSlices =
40  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
41  auto forOp =
42  rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
43  rewriter.setInsertionPointToStart(forOp.getBody());
44  return forOp;
45 }
46 
47 /// Returns a tile of the given vector type.
48 static arm_sme::CastTileToVector
50  VectorType type) {
51  unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth();
52 
53  // Create 'arm_sme.get_tile' op.
54  auto tileId = rewriter.create<arm_sme::GetTileID>(
55  loc, rewriter.getIntegerType(tileElementWidth));
56 
57  // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type.
58  return rewriter.create<arm_sme::CastTileToVector>(loc, type, tileId);
59 }
60 
61 namespace {
62 
63 /// Conversion pattern for vector.transfer_read.
64 ///
65 /// ---
66 ///
67 /// Example 1: op with identity permutation map to horizontal
68 /// arm_sme.tile_load:
69 ///
70 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
71 ///
72 /// is converted to:
73 ///
74 /// arm_sme.tile_load ...
75 ///
76 /// ---
77 ///
78 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
79 /// (in-flight transpose):
80 ///
81 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
82 ///
83 /// is converted to:
84 ///
85 /// arm_sme.tile_load ... layout<vertical>
86 struct TransferReadToArmSMELowering
87  : public OpRewritePattern<vector::TransferReadOp> {
89 
90  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
91  PatternRewriter &rewriter) const final {
92  // The permutation map must have two results.
93  if (transferReadOp.getTransferRank() != 2)
94  return rewriter.notifyMatchFailure(transferReadOp,
95  "not a 2 result permutation map");
96 
97  auto vectorType = transferReadOp.getVectorType();
98  if (!arm_sme::isValidSMETileVectorType(vectorType))
99  return rewriter.notifyMatchFailure(transferReadOp,
100  "not a valid vector type for SME");
101 
102  if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
103  return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
104 
105  // Out-of-bounds dims are not supported.
106  if (transferReadOp.hasOutOfBoundsDim())
107  return rewriter.notifyMatchFailure(transferReadOp,
108  "not inbounds transfer read");
109 
110  arm_sme::TileSliceLayout layout;
111 
112  AffineExpr d0, d1;
113  bindDims(transferReadOp.getContext(), d0, d1);
114  AffineMap map = transferReadOp.getPermutationMap();
115  if (map.isIdentity())
116  layout = arm_sme::TileSliceLayout::Horizontal;
117  else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
118  transferReadOp.getContext()))
119  layout = arm_sme::TileSliceLayout::Vertical;
120  else
121  return rewriter.notifyMatchFailure(transferReadOp,
122  "unsupported permutation map");
123 
124  // Padding isn't optional for transfer_read, but is only used in the case
125  // of out-of-bounds accesses (not supported here) and/or masking. Mask is
126  // optional, if it's not present don't pass padding.
127  auto mask = transferReadOp.getMask();
128  auto padding = mask ? transferReadOp.getPadding() : nullptr;
129  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
130  transferReadOp, vectorType, transferReadOp.getSource(),
131  transferReadOp.getIndices(), padding, mask, layout);
132 
133  return success();
134  }
135 };
136 
137 /// Conversion pattern for vector.transfer_write.
138 ///
139 /// ---
140 ///
141 /// Example 1: op with identity permutation map to horizontal
142 /// arm_sme.tile_store:
143 ///
144 /// vector.transfer_write %vector, %source[%c0, %c0]
145 /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
146 ///
147 /// is converted to:
148 ///
149 /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
150 /// vector<[16]x[16]xi8>
151 /// ---
152 ///
153 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
154 /// (in-flight transpose):
155 ///
156 /// vector.transfer_write %vector, %source[%c0, %c0]
157 /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
158 /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
159 ///
160 /// is converted to:
161 ///
162 /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
163 /// : memref<?x?xi8>, vector<[16]x[16]xi8>
164 struct TransferWriteToArmSMELowering
165  : public OpRewritePattern<vector::TransferWriteOp> {
167 
168  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
169  PatternRewriter &rewriter) const final {
170  auto vType = writeOp.getVectorType();
172  return failure();
173 
174  if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
175  return failure();
176 
177  // Out-of-bounds dims are not supported.
178  if (writeOp.hasOutOfBoundsDim())
179  return rewriter.notifyMatchFailure(writeOp,
180  "not inbounds transfer write");
181 
182  AffineExpr d0, d1;
183  bindDims(writeOp.getContext(), d0, d1);
184  AffineMap map = writeOp.getPermutationMap();
185  bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
186  writeOp.getContext()));
187 
188  if (!map.isIdentity() && !isTranspose)
189  return rewriter.notifyMatchFailure(writeOp,
190  "unsupported permutation map");
191 
192  arm_sme::TileSliceLayout layout =
193  isTranspose ? arm_sme::TileSliceLayout::Vertical
194  : arm_sme::TileSliceLayout::Horizontal;
195 
196  rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
197  writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
198  writeOp.getMask(), layout);
199  return success();
200  }
201 };
202 
203 /// Conversion pattern for vector.load.
204 struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
206 
207  LogicalResult matchAndRewrite(vector::LoadOp load,
208  PatternRewriter &rewriter) const override {
209  if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
210  return failure();
211 
212  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
213  load, load.getVectorType(), load.getBase(), load.getIndices());
214 
215  return success();
216  }
217 };
218 
219 /// Conversion pattern for vector.store.
220 struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
222 
223  LogicalResult matchAndRewrite(vector::StoreOp store,
224  PatternRewriter &rewriter) const override {
225  if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
226  return failure();
227 
228  rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
229  store, store.getValueToStore(), store.getBase(), store.getIndices());
230 
231  return success();
232  }
233 };
234 
235 /// Conversion pattern for dense arith.constant.
236 struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
238 
239  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
240  PatternRewriter &rewriter) const final {
241  auto tileType = dyn_cast<VectorType>(constantOp.getType());
242  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
243  return failure();
244 
245  auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
246  if (!denseAttr || !denseAttr.isSplat())
247  return failure();
248 
249  auto tileElementType = tileType.getElementType();
250 
251  // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
252  if (isSplatZero(tileElementType, denseAttr)) {
253  rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
254  return success();
255  }
256 
257  // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
258  // ops that broadcast the constant to each tile slice.
259  OpBuilder::InsertionGuard g(rewriter);
260  auto loc = constantOp.getLoc();
261 
262  // Unpack 1-d vector type from 2-d vector type.
263  auto tileSliceType =
264  VectorType::get(tileType.getShape().drop_front(), tileElementType,
265  /*scalableDims=*/{true});
266  auto denseAttr1D = DenseElementsAttr::get(
267  tileSliceType, denseAttr.getSplatValue<Attribute>());
268  auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
269 
270  arm_sme::CastTileToVector tile =
271  getSMETileAndCastToVector(rewriter, loc, tileType);
272 
273  auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
274  auto tileSliceIndex = forOp.getInductionVar();
275 
276  // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice.
277  rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
278  loc, tileType, constantOp1D, tile, tileSliceIndex);
279 
280  rewriter.replaceOp(constantOp, tile);
281 
282  return success();
283  }
284 };
285 
286 /// Conversion pattern for vector.broadcast.
287 ///
288 /// Example:
289 ///
290 /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
291 ///
292 /// is converted to:
293 ///
294 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
295 /// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
296 /// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
297 /// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
298 /// }
299 ///
300 /// Supports scalar, 0-d vector, and 1-d vector broadcasts.
301 struct BroadcastOpToArmSMELowering
302  : public OpRewritePattern<vector::BroadcastOp> {
304 
305  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
306  PatternRewriter &rewriter) const final {
307  auto tileType = broadcastOp.getResultVectorType();
308  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
309  return failure();
310 
311  OpBuilder::InsertionGuard g(rewriter);
312  auto loc = broadcastOp.getLoc();
313 
314  auto srcType = broadcastOp.getSourceType();
315  auto srcVectorType = dyn_cast<VectorType>(srcType);
316  auto tileElementType = tileType.getElementType();
317 
318  Value broadcastOp1D;
319  if (srcType.isIntOrFloat() ||
320  (srcVectorType && (srcVectorType.getRank() == 0))) {
321  // Broadcast scalar or 0-d vector to 1-d vector.
322  auto tileSliceType =
323  VectorType::get(tileType.getShape().drop_front(), tileElementType,
324  /*scalableDims=*/{true});
325  broadcastOp1D = rewriter.create<vector::BroadcastOp>(
326  loc, tileSliceType, broadcastOp.getSource());
327  } else if (srcVectorType && (srcVectorType.getRank() == 1))
328  // Value to broadcast is already a 1-d vector, nothing to do.
329  broadcastOp1D = broadcastOp.getSource();
330  else
331  return failure();
332 
333  arm_sme::CastTileToVector tile =
334  getSMETileAndCastToVector(rewriter, loc, tileType);
335 
336  // Create a loop over ZA tile slices.
337  auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
338  auto tileSliceIndex = forOp.getInductionVar();
339 
340  // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each
341  // tile slice.
342  rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
343  loc, tileType, broadcastOp1D, tile, tileSliceIndex);
344 
345  rewriter.replaceOp(broadcastOp, tile);
346 
347  return success();
348  }
349 };
350 
351 /// Conversion pattern for vector.splat.
352 ///
353 /// Example:
354 ///
355 /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
356 ///
357 /// is converted to:
358 ///
359 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
360 /// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
361 /// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
362 /// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
363 /// }
364 ///
365 /// This is identical to vector.broadcast of a scalar.
366 struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
368 
369  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
370  PatternRewriter &rewriter) const final {
371  auto tileType = splatOp.getResult().getType();
372  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
373  return failure();
374 
375  OpBuilder::InsertionGuard g(rewriter);
376  auto loc = splatOp.getLoc();
377 
378  auto srcType = splatOp.getOperand().getType();
379  auto tileElementType = tileType.getElementType();
380 
381  assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
382  // Avoid unused-variable warning when building without assertions.
383  (void)srcType;
384 
385  // First, broadcast the scalar to a 1-d vector.
386  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
387  Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
388  loc, tileSliceType, splatOp.getInput());
389 
390  arm_sme::CastTileToVector tile =
391  getSMETileAndCastToVector(rewriter, loc, tileType);
392 
393  // Next, create a loop over ZA tile slices and "move" the generated 1-d
394  // vector to each slice.
395  auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
396  auto tileSliceIndex = forOp.getInductionVar();
397 
398  rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
399  loc, tileType, broadcastOp1D, tile, tileSliceIndex);
400 
401  rewriter.replaceOp(splatOp, tile);
402 
403  return success();
404  }
405 };
406 
407 /// Conversion pattern for vector.transpose.
408 ///
409 /// Stores the input tile to memory and reloads vertically.
410 ///
411 /// Example:
412 ///
413 /// %transposed_src = vector.transpose %src, [1, 0]
414 /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
415 ///
416 /// is converted to:
417 ///
418 /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
419 /// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
420 /// : memref<?x?xi32>, vector<[4]x[4]xi32>
421 /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
422 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
423 ///
424 /// NOTE: Tranposing via memory is obviously expensive, the current intention
425 /// is to avoid the transpose if possible, this is therefore intended as a
426 /// fallback and to provide base support for Vector ops. If it turns out
427 /// transposes can't be avoided then this should be replaced with a more optimal
428 /// implementation, perhaps with tile <-> vector (MOVA) ops.
429 struct TransposeOpToArmSMELowering
430  : public OpRewritePattern<vector::TransposeOp> {
432 
433  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
434  PatternRewriter &rewriter) const final {
435  auto tileType = transposeOp.getResultVectorType();
436  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
437  return failure();
438 
439  // Bail unless this is a true 2-D matrix transpose.
440  ArrayRef<int64_t> permutation = transposeOp.getPermutation();
441  if (permutation[0] != 1 || permutation[1] != 0)
442  return failure();
443 
444  OpBuilder::InsertionGuard g(rewriter);
445  auto loc = transposeOp.getLoc();
446 
447  // Allocate buffer to store input tile to.
448  Value vscale =
449  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
450  Value minTileSlices = rewriter.create<arith::ConstantOp>(
451  loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
452  Value c0 =
453  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
454  Value numTileSlices =
455  rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
456  auto bufferType =
457  MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
458  tileType.getElementType());
459  auto buffer = rewriter.create<memref::AllocaOp>(
460  loc, bufferType, ValueRange{numTileSlices, numTileSlices});
461 
462  Value input = transposeOp.getVector();
463 
464  // Store input tile.
465  auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
466  loc, input, buffer, ValueRange{c0, c0});
467 
468  // Reload input tile vertically.
469  rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
470  transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
471  arm_sme::TileSliceLayout::Vertical);
472 
473  return success();
474  }
475 };
476 
477 /// Conversion pattern for vector.outerproduct.
478 ///
479 /// If the vector.outerproduct is masked (and the mask is from a
480 /// vector.create_mask), then the mask is decomposed into two 1-D masks for the
481 /// operands.
482 ///
483 /// Example:
484 ///
485 /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
486 /// %result = vector.mask %mask {
487 /// vector.outerproduct %vecA, %vecB
488 /// : vector<[4]xf32>, vector<[4]xf32>
489 /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
490 ///
491 /// is converted to:
492 ///
493 /// %maskA = vector.create_mask %dimA : vector<[4]xi1>
494 /// %maskB = vector.create_mask %dimB : vector<[4]xi1>
495 /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
496 /// : vector<[4]xf32>, vector<[4]xf32>
497 ///
498 /// Unmasked outerproducts can be directly replaced with the arm_sme op.
499 ///
500 /// Example:
501 ///
502 /// %result = vector.outerproduct %vecA, %vecB
503 /// : vector<[4]xf32>, vector<[4]xf32>
504 ///
505 /// is converted to:
506 ///
507 /// %result = arm_sme.outerproduct %vecA, %vecB
508 /// : vector<[4]xf32>, vector<[4]xf32>
509 ///
510 struct VectorOuterProductToArmSMELowering
511  : public OpRewritePattern<vector::OuterProductOp> {
512 
514 
515  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
516  PatternRewriter &rewriter) const override {
517 
518  // We don't yet support lowering AXPY operations to SME. These could be
519  // lowered by masking out all but the first element of the LHS.
520  if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
521  return outerProductOp.emitError("AXPY operations not supported");
522 
524  outerProductOp.getResultVectorType()))
525  return outerProductOp.emitError(
526  "outer product does not fit into SME tile");
527 
528  auto kind = outerProductOp.getKind();
529  if (kind != vector::CombiningKind::ADD)
530  return outerProductOp.emitError(
531  "unsupported kind (lowering to SME only supports ADD at the moment)");
532 
533  Value lhsMask = {};
534  Value rhsMask = {};
535  Operation *rootOp = outerProductOp;
536  auto loc = outerProductOp.getLoc();
537  if (outerProductOp.isMasked()) {
538  auto maskOp = outerProductOp.getMaskingOp();
539  rewriter.setInsertionPoint(maskOp);
540  rootOp = maskOp;
541  auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
542  if (failed(operandMasks))
543  return failure();
544  std::tie(lhsMask, rhsMask) = *operandMasks;
545  }
546 
547  rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
548  rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
549  outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
550 
551  return success();
552  }
553 
555  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
556  // Attempt to extract masks from vector.create_mask.
557  // TODO: Add support for other mask sources.
558  auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
559  if (!createMaskOp)
560  return failure();
561 
562  auto maskType = createMaskOp.getVectorType();
563  Value lhsMaskDim = createMaskOp.getOperand(0);
564  Value rhsMaskDim = createMaskOp.getOperand(1);
565 
566  VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
567  Value lhsMask =
568  rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
569  Value rhsMask =
570  rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
571 
572  return std::make_pair(lhsMask, rhsMask);
573  }
574 };
575 
576 /// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
577 ///
578 /// Example:
579 /// ```
580 /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
581 /// ```
582 /// Becomes:
583 /// ```
584 /// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
585 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
586 /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
587 /// ```
588 struct VectorExtractToArmSMELowering
589  : public OpRewritePattern<vector::ExtractOp> {
591 
592  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
593  PatternRewriter &rewriter) const override {
594  VectorType sourceType = extractOp.getSourceVectorType();
595  if (!arm_sme::isValidSMETileVectorType(sourceType))
596  return failure();
597 
598  auto loc = extractOp.getLoc();
599  auto position = extractOp.getMixedPosition();
600 
601  Value sourceVector = extractOp.getVector();
602 
603  // Extract entire vector. Should be handled by folder, but just to be safe.
604  if (position.empty()) {
605  rewriter.replaceOp(extractOp, sourceVector);
606  return success();
607  }
608 
609  Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
610  auto moveTileSliceToVector =
611  rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
612  sliceIndex);
613 
614  if (position.size() == 1) {
615  // Single index case: Extracts a 1D slice.
616  rewriter.replaceOp(extractOp, moveTileSliceToVector);
617  return success();
618  }
619 
620  // Two indices case: Extracts a single element.
621  assert(position.size() == 2);
622  rewriter.replaceOpWithNewOp<vector::ExtractOp>(
623  extractOp, moveTileSliceToVector, position[1]);
624 
625  return success();
626  }
627 };
628 
629 /// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
630 /// `arm_sme.move_tile_slice_to_vector`.
631 ///
632 /// Example:
633 /// ```
634 /// %new_tile = vector.insert %el, %tile[%row, %col]
635 /// : i32 into vector<[4]x[4]xi32>
636 /// ```
637 /// Becomes:
638 /// ```
639 /// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
640 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
641 /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
642 /// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
643 /// : vector<[4]xi32> into vector<[4]x[4]xi32>
644 /// ```
645 struct VectorInsertToArmSMELowering
646  : public OpRewritePattern<vector::InsertOp> {
648 
649  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
650  PatternRewriter &rewriter) const override {
651  VectorType resultType = insertOp.getResult().getType();
652 
653  if (!arm_sme::isValidSMETileVectorType(resultType))
654  return failure();
655 
656  auto loc = insertOp.getLoc();
657  auto position = insertOp.getMixedPosition();
658 
659  Value source = insertOp.getSource();
660 
661  // Overwrite entire vector with value. Should be handled by folder, but
662  // just to be safe.
663  if (position.empty()) {
664  rewriter.replaceOp(insertOp, source);
665  return success();
666  }
667 
668  Value tileSlice = source;
669  Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
670  if (position.size() == 2) {
671  // Two indices case: Insert single element into tile.
672  // We need to first extract the existing slice and update the element.
673  tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
674  loc, insertOp.getDest(), sliceIndex);
675  tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
676  position[1]);
677  }
678 
679  // Insert the slice into the destination tile.
680  rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
681  insertOp, tileSlice, insertOp.getDest(), sliceIndex);
682  return success();
683  }
684 };
685 
686 } // namespace
687 
689  MLIRContext &ctx) {
690  patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
691  SplatOpToArmSMELowering, TransferReadToArmSMELowering,
692  TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
693  VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
694  VectorOuterProductToArmSMELowering,
695  VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
696  &ctx);
697 }
static arm_sme::CastTileToVector getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc, VectorType type)
Returns a tile of the given vector type.
static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Type eltType)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index.
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:374
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:323
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
IndexType getIndexType()
Definition: Builders.cpp:71
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:330
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:21
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:32
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Definition: VectorOps.cpp:306
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:334
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:824
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx)
Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM intrinsics.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357