MLIR  22.0.0git
ArmSMEToLLVM.cpp
Go to the documentation of this file.
1 //===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to LLVM intrinsics.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
25 #include "mlir/Pass/Pass.h"
27 #include "llvm/ADT/ScopeExit.h"
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 namespace {
37 
38 static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
39 
40 /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
41 static Operation *createLoadTileSliceIntrinsic(
42  RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
43  arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
44  IntegerAttr tileId, Value tileSliceI32) {
45  if (layout == arm_sme::TileSliceLayout::Horizontal) {
46  switch (type) {
47  case arm_sme::ArmSMETileType::ZAB:
48  return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr,
49  tileId, tileSliceI32);
50  case arm_sme::ArmSMETileType::ZAH:
51  return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr,
52  tileId, tileSliceI32);
53  case arm_sme::ArmSMETileType::ZAS:
54  return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr,
55  tileId, tileSliceI32);
56  case arm_sme::ArmSMETileType::ZAD:
57  return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr,
58  tileId, tileSliceI32);
59  case arm_sme::ArmSMETileType::ZAQ:
60  return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr,
61  tileId, tileSliceI32);
62  }
63  } else {
64  switch (type) {
65  case arm_sme::ArmSMETileType::ZAB:
66  return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr,
67  tileId, tileSliceI32);
68  case arm_sme::ArmSMETileType::ZAH:
69  return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr,
70  tileId, tileSliceI32);
71  case arm_sme::ArmSMETileType::ZAS:
72  return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr,
73  tileId, tileSliceI32);
74  case arm_sme::ArmSMETileType::ZAD:
75  return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr,
76  tileId, tileSliceI32);
77  case arm_sme::ArmSMETileType::ZAQ:
78  return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr,
79  tileId, tileSliceI32);
80  break;
81  }
82  }
83  llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
84 }
85 
86 /// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
87 static Operation *createStoreTileSliceIntrinsic(
88  RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
89  arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
90  IntegerAttr tileId, Value tileSliceI32) {
91  if (layout == arm_sme::TileSliceLayout::Horizontal) {
92  switch (type) {
93  case arm_sme::ArmSMETileType::ZAB:
94  return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr,
95  tileId, tileSliceI32);
96  case arm_sme::ArmSMETileType::ZAH:
97  return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr,
98  tileId, tileSliceI32);
99  case arm_sme::ArmSMETileType::ZAS:
100  return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr,
101  tileId, tileSliceI32);
102  case arm_sme::ArmSMETileType::ZAD:
103  return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr,
104  tileId, tileSliceI32);
105  case arm_sme::ArmSMETileType::ZAQ:
106  return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr,
107  tileId, tileSliceI32);
108  }
109  } else {
110  switch (type) {
111  case arm_sme::ArmSMETileType::ZAB:
112  return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr,
113  tileId, tileSliceI32);
114  case arm_sme::ArmSMETileType::ZAH:
115  return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr,
116  tileId, tileSliceI32);
117  case arm_sme::ArmSMETileType::ZAS:
118  return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr,
119  tileId, tileSliceI32);
120  case arm_sme::ArmSMETileType::ZAD:
121  return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr,
122  tileId, tileSliceI32);
123  case arm_sme::ArmSMETileType::ZAQ:
124  return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr,
125  tileId, tileSliceI32);
126  }
127  }
128  llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
129 }
130 
131 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
132  auto tileId = op.getTileId();
133  if (!tileId)
134  op.emitOpError(
135  "expected tile ID to be allocated before conversion to LLVM");
136  return tileId;
137 }
138 
139 /// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
140 /// placed in the first block of the function.
141 static memref::AllocaOp
142 createAllocaForTile(RewriterBase &rewriter, Location loc,
143  FunctionOpInterface func,
144  arm_sme::ArmSMETileOpInterface tileOp) {
145  RewriterBase::InsertionGuard g(rewriter);
146  // Move to the first operation in the function.
147  rewriter.setInsertionPointToStart(&func.getBlocks().front());
148  // Create an alloca matching the tile size of the `tileOp`.
149  auto vscale = vector::VectorScaleOp::create(rewriter, loc);
150  auto tileElementType = tileOp.getTileType().getElementType();
151  auto memrefType = MemRefType::get(
152  {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
153  unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
154  auto minElementsOp =
155  arith::ConstantIndexOp::create(rewriter, loc, minElements);
156  auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp);
157  auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType,
158  ValueRange{vectorLen, vectorLen});
159  return alloca;
160 }
161 
162 /// Finds or creates an alloca for a spill of a tile.
163 static memref::AllocaOp getOrCreateAllocaForTile(
164  RewriterBase &rewriter, Location loc, FunctionOpInterface func,
165  arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
166  // Find an alloca at the top of the function tagged with a
167  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
168  for (auto &op : func.getBlocks().front()) {
169  auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
170  if (!alloca)
171  continue;
172  auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
173  alloca->getDiscardableAttr(kInMemoryTileIdAttr));
174  if (!inMemoryTileId)
175  continue;
176  if (inMemoryTileId.getInt() == tileId)
177  return alloca;
178  }
179  // Otherwise, create a new alloca:
180  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
181  alloca->setDiscardableAttr(kInMemoryTileIdAttr,
182  rewriter.getI32IntegerAttr(tileId));
183  return alloca;
184 }
185 
186 /// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
187 /// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
188 /// the op to tile 0, then emitting a full tile swap between ZA and memory
189 /// before + after the tile op.
190 ///
191 /// Example:
192 ///
193 /// // Note: <IN MEMORY TILE> = tile ID >= 16.
194 /// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
195 ///
196 /// is converted to:
197 /// // At function entry:
198 /// %spill = memref.alloca ... : memref<?x?xty>
199 ///
200 /// // Around op:
201 /// scf.for %slice_idx {
202 /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
203 /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
204 /// vector.store %slice_to_save, %spill[%slice_idx, %c0]
205 /// }
206 /// arm_sme.tile_op { tile_id = 0 }
207 /// scf.for %slice_idx {
208 /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
209 /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
210 /// vector.store %slice_to_save, %spill[%slice_idx, %c0]
211 /// }
212 ///
/// Note that these spills/fills are not inserted earlier because the concept
/// of a register, and the need to swap its contents, can't really be
/// represented correctly at a high level in MLIR.
216 ///
217 /// TODO: Reduce the spills/reloads to single slices where possible (and omit
218 /// redundant reloads). This could be done via a method on the
219 /// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
220 ///
221 /// `tileOp.getZaUsage()` could return:
222 ///
223 /// struct ArmSMEOpZAUsage {
224 /// enum class Kind {
225 /// TileRead, // Omit store after tile operation.
226 /// TileWrite, // Omit load before tile operation.
227 /// TileReadWrite, // Needs both tile load and store.
228 /// SliceRead, // Spill single slice and omit store after operation.
229 /// SliceWrite, // Spill single slice and omit load before operation.
230 /// SliceReadWrite // Spill single slice.
231 /// };
232 /// Value sliceIndex {};
233 /// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
234 /// };
235 ///
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {

  /// `rootOpName` selects which ArmSME op this pattern matches; the benefit
  /// is set very high by the registration helper so spills/fills run first.
  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
                                    const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
                             typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
    // Tile has a real (hardware) tile. No spills/reloads required.
    if (!tileOp.isInMemoryTile())
      return failure();

    tileOp->emitWarning(
        "failed to allocate SME virtual tile to operation, tile value will go "
        "through memory, expect degraded performance");

    // Step 1. Create an alloca for the tile at the top of the function (if one
    // does not already exist).
    auto loc = tileOp.getLoc();
    auto func = tileOp->getParentOfType<FunctionOpInterface>();
    auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
                                               tileOp.getTileId().getInt());

    // Step 2. Assign the op a real tile ID.
    // For simplicity, we always use tile 0 (which always exists).
    auto zeroTileId = rewriter.getI32IntegerAttr(0);
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

    VectorType tileVectorType = tileOp.getTileType();
    // A slice is the tile type minus its leading (row) dimension.
    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
    auto swapInMemoryTileWithSMETileZero = [&] {
      emitFullTileSwap(rewriter, loc, tileAlloca,
                       *arm_sme::getSMETileType(tileVectorType), sliceType,
                       zeroTileId);
    };

    // Step 3. Emit tile swaps before and after the op.
    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
    // touches (i.e. a single tile slice).
    {
      rewriter.setInsertionPoint(op);
      // Swap the contents of ZA and the in-memory tile before the op.
      swapInMemoryTileWithSMETileZero();
      rewriter.setInsertionPointAfter(op);
      // Swap the tile back out to memory again after the op.
      swapInMemoryTileWithSMETileZero();
    }

    return success();
  }

  /// Extracts a pointer to a slice of an in-memory tile.
  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
                                Value tileMemory, Value sliceIndex) const {
    // Cast the memref to its converted LLVM descriptor form so the strided
    // element address can be computed.
    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
    auto descriptor =
        UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory);
    auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64);
    auto sliceIndexI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), sliceIndex);
    // Address of element [sliceIndex, 0], i.e. the start of the slice row.
    return getStridedElementPtr(
        static_cast<ConversionPatternRewriter &>(rewriter), loc,
        llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
        {sliceIndexI64, zero});
  }

  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
  /// tile-sized memref (`tileAlloca`).
  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
                     IntegerAttr tileId, Value sliceIndex) const {
    // Cast the slice index to an i32.
    auto sliceIndexI32 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI32Type(), sliceIndex);
    // Create an all-true predicate for the slice.
    auto predicateType = sliceType.clone(rewriter.getI1Type());
    auto allTruePredicate = arith::ConstantOp::create(
        rewriter, loc, DenseElementsAttr::get(predicateType, true));
    // Create padding vector (never used due to all-true predicate).
    auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType);
    // Get a pointer to the current slice.
    auto slicePtr =
        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
    // Read the value of the current slice from ZA (into a register value).
    auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(
        rewriter, loc, sliceType, padVector, allTruePredicate, tileId,
        sliceIndexI32);
    // Load the new tile slice back from memory into ZA.
    createLoadTileSliceIntrinsic(
        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
        allTruePredicate, slicePtr, tileId, sliceIndexI32);
    // Store the current tile slice to memory (completing the swap).
    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca,
                            ValueRange{sliceIndex, zero});
  }

  /// Emits a full in-place swap of the contents of a tile in ZA and a
  /// tile-sized memref (`tileAlloca`).
  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
                        IntegerAttr tileId) const {
    RewriterBase::InsertionGuard guard(rewriter);
    // Create an scf.for over all tile slices.
    // Upper bound = minimum slice count * vscale (the runtime slice count).
    auto minNumElts =
        arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0));
    auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto upperBound =
        arith::MulIOp::create(rewriter, loc, minNumElts,
                              vector::VectorScaleOp::create(rewriter, loc));
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
    auto forOp =
        scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
    // Emit a swap for each tile slice.
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto sliceIndex = forOp.getInductionVar();
    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
                  sliceIndex);
  }
};
360 
/// Whether a conversion pattern also needs spill/fill support for ops whose
/// tile could not be assigned to a physical SME tile.
enum class RequiresSpillsAndFills { Yes, No };
362 
363 /// Base class for ArmSME to LLVM conversion patterns. By default, this adds
364 /// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
365 /// disabled by setting the `requiresSpillsAndFills` template parameter to
366 /// `RequiresSpillsAndFills::No`.
367 template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
368  RequiresSpillsAndFills::Yes>
369 struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
370  using ArmSMEOp = SourceOp;
372 
373  static constexpr bool requiresSpillsAndFillsConversion() {
374  return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
375  }
376 };
377 
378 template <typename Pattern>
379 static void addArmSMEConversionPattern(RewritePatternSet &patterns,
380  LLVMTypeConverter const &typeConverter) {
381  // Register spills/fills for ops that implement the
382  // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
383  // `RequiresSpillsAndFills::Yes`.
384  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
385  std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
386  typename Pattern::ArmSMEOp>,
387  typename Pattern::ArmSMEOp>) {
388  // Add spill/fill conversions with a very high benefit to ensure
389  // they are lowered first.
390  patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
391  Pattern::ArmSMEOp::getOperationName(), typeConverter,
392  /*benefit=*/1337);
393  }
394  patterns.add<Pattern>(typeConverter);
395 }
396 
/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  // Fold expression: registers each pattern (plus any required spill/fill
  // pattern) in order.
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
404 
/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
/// ```mlir
///    %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
///    "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///    %v = arm_sme.get_tile : vector<[4]x[4]xi32>
/// ```
///
/// The 'arm_sme.get_tile' (which models the return) will fold away once all
/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = zero.getLoc();

    auto tileId = getTileIdOrError(zero);
    if (!tileId)
      return failure();

    // Get the base mask for tile based on the element size.
    // The base mask is just the mask to zero the first tile (of a size).
    // These masks are derived from:
    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
    arm_sme::ArmSMETileType tileType =
        *arm_sme::getSMETileType(zero.getTileType());
    auto baseMaskForSize = [&] {
      switch (tileType) {
      case arm_sme::ArmSMETileType::ZAB:
        // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
        // 64-bit element tiles named ZA0.D to ZA7.D.
        return 0b1111'1111;
      case arm_sme::ArmSMETileType::ZAH:
        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
        // once for ZA1.H.
        return 0b0101'0101;
      case arm_sme::ArmSMETileType::ZAS:
        // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D and ZA4.D.
        // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
        return 0b0001'0001;
      case arm_sme::ArmSMETileType::ZAD:
        // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
        // setting the bit for that tile.
        return 0b0000'0001;
      default:
        // Note: 128-bit (ZAQ) tiles are not supported by this lowering.
        llvm_unreachable("bad element size");
      }
    }();

    // The actual mask is just the base mask shifted by the tile ID.
    // This will be folded to a constant after tile allocation.
    //
    // The shift is just derived from the layout of the tiles, and that the tile
    // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
    // tiles:
    //
    // ZA0.S = ZA0.D and ZA4.D
    //  * Tile ID -> 0
    //  * Mask    -> 00010001 = (00010001 << 0)
    // ZA1.S = ZA1.D and ZA5.D
    //  * Tile ID -> 1
    //  * Mask    -> 00100010 = (00010001 << 1)
    // ZA2.S = ZA2.D and ZA6.D
    //  * Tile ID -> 2
    //  * Mask    -> 01000100 = (00010001 << 2)
    // ZA3.S = ZA3.D and ZA7.D
    //  * Tile ID -> 3
    //  * Mask    -> 10001000 = (00010001 << 3)
    //
    // This holds for all tile sizes.
    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
    arm_sme::aarch64_sme_zero::create(rewriter, loc,
                                      rewriter.getI32IntegerAttr(zeroMask));

    // Create a placeholder op to preserve dataflow.
    // Note: Place the `get_tile` op at the start of the block. This ensures
    // that if there are multiple `zero` ops the intrinsics will be consecutive.
    rewriter.setInsertionPointToStart(zero->getBlock());
    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());

    return success();
  }
};
497 
498 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
499 struct LoadTileSliceConversion
500  : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
501  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
502 
503  LogicalResult
504  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
505  arm_sme::LoadTileSliceOp::Adaptor adaptor,
506  ConversionPatternRewriter &rewriter) const override {
507  auto loc = loadTileSliceOp.getLoc();
508  auto tileId = getTileIdOrError(loadTileSliceOp);
509  if (!tileId)
510  return failure();
511 
512  Value ptr = this->getStridedElementPtr(
513  rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
514  adaptor.getIndices());
515 
516  auto tileSlice = loadTileSliceOp.getTileSliceIndex();
517 
518  // Cast tile slice to i32 for intrinsic.
519  auto tileSliceI32 = arith::IndexCastUIOp::create(
520  rewriter, loc, rewriter.getI32Type(), tileSlice);
521 
522  // Create all active predicate mask.
523  auto maskOp = loadTileSliceOp.getMask();
524 
525  auto tileVectorType = loadTileSliceOp.getVectorType();
526  arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
527  arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
528 
529  // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
530  createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
531  tileId, tileSliceI32);
532 
533  // The load intrinsics have no result, replace 'arm_sme.tile_load' with
534  // the input tile to preserve dataflow.
535  rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
536 
537  return success();
538  }
539 };
540 
541 /// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
542 struct StoreTileSliceConversion
543  : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
544  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
545 
546  LogicalResult
547  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
548  arm_sme::StoreTileSliceOp::Adaptor adaptor,
549  ConversionPatternRewriter &rewriter) const override {
550  auto loc = storeTileSliceOp.getLoc();
551  auto tileVectorType = storeTileSliceOp.getVectorType();
552 
553  auto tileId = getTileIdOrError(storeTileSliceOp);
554  if (!tileId)
555  return failure();
556 
557  // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
558  Value ptr = this->getStridedElementPtr(
559  rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
560  adaptor.getIndices());
561 
562  auto tileSlice = storeTileSliceOp.getTileSliceIndex();
563 
564  // Cast tile slice to i32 for intrinsic.
565  auto tileSliceI32 = arith::IndexCastUIOp::create(
566  rewriter, loc, rewriter.getI32Type(), tileSlice);
567 
568  auto maskOp = storeTileSliceOp.getMask();
569 
570  arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
571  arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
572 
573  rewriter.replaceOp(storeTileSliceOp,
574  createStoreTileSliceIntrinsic(rewriter, loc, tileType,
575  layout, maskOp, ptr,
576  tileId, tileSliceI32));
577 
578  return success();
579  }
580 };
581 
/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
struct InsertTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
                  arm_sme::InsertTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertTileSliceOp.getLoc();
    auto tileType = insertTileSliceOp.getTileType();

    auto tileId = getTileIdOrError(insertTileSliceOp);
    if (!tileId)
      return failure();

    auto tileSlice = insertTileSliceOp.getTileSliceIndex();

    // Cast tile slice from index to i32 for intrinsic.
    auto tileSliceI32 = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), tileSlice);

    // Create all active predicate mask: a scalable i1 vector of the slice
    // length, built by broadcasting the constant `true`.
    auto one = arith::ConstantOp::create(
        rewriter, loc, rewriter.getI1Type(),
        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                  /*scalableDims=*/{true});
    auto allActiveMask =
        vector::BroadcastOp::create(rewriter, loc, predTy, one);

    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
    switch (insertTileSliceOp.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId,
                                               tileSliceI32, allActiveMask,
                                               insertTileSliceOp.getVector());
      break;
    case arm_sme::TileSliceLayout::Vertical:
      arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId,
                                              tileSliceI32, allActiveMask,
                                              insertTileSliceOp.getVector());
      break;
    }

    // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
    // the input tile to preserve dataflow.
    rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());

    return success();
  }
};
634 
/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
struct ExtractTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractTileSlice.getLoc();
    auto sliceType = extractTileSlice.getSliceType();
    auto sliceIndex = extractTileSlice.getTileSliceIndex();

    auto tileId = getTileIdOrError(extractTileSlice);
    if (!tileId)
      return failure();

    // Create an 'all true' predicate for the tile slice.
    auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
    auto allTruePredicate = arith::ConstantOp::create(
        rewriter, loc, DenseElementsAttr::get(predicateType, true));

    // Zero destination/fallback for tile slice extraction. With an all-true
    // predicate this value is never selected, but the intrinsic requires it.
    auto zeroVector = arith::ConstantOp::create(
        rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType));

    // Cast tile slice from index to i32 for intrinsic.
    auto sliceIndexI32 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI32Type(), sliceIndex);

    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
    switch (extractTileSlice.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    }

    return success();
  }
};
681 
/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                  arm_sme::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(outerProductOp);
    if (!tileId)
      return failure();

    // Returns true iff the result type maps to a single non-widening
    // floating-point MOPA: rank-2, fully scalable, FP element type, and
    // exactly one SME tile in size.
    auto isSupportedType = [](VectorType vectorType) {
      // TODO: the FP outer product instruction variants are predicated on
      // different features [1]:
      //
      // * FMOPA (non-widening)
      //   * half-precision   - +sme2p1,+sme-f16f16
      //   * single-precision - +sme
      //   * double-precision - +sme-f64f64
      // * BFMOPA
      //   * half-precision   - +sme2p1,+b16b16
      //
      // It should be possible to control lowering based on target features.
      // [1]
      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
        return false;

      auto elementType = vectorType.getElementType();

      if (!elementType.isF16() && !elementType.isBF16() &&
          !elementType.isF32() && !elementType.isF64())
        return false;

      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      return vectorType.getShape() ==
             ArrayRef<int64_t>({minNumElts, minNumElts});
    };

    // TODO: Support CombiningKind::Sub for outer products.
    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
      return outerProductOp.emitError("unsupported kind");

    auto resultVectorType = outerProductOp.getResultType();
    if (!isSupportedType(resultVectorType))
      return outerProductOp.emitError("unsupported type");

    auto loc = outerProductOp.getLoc();

    Value acc = outerProductOp.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType);
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = outerProductOp.getLhsMask();
    Value rhsMask = outerProductOp.getRhsMask();

    // If the op carries no masks, predicate all lanes active.
    if (!lhsMask || !rhsMask) {
      auto predTy =
          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    // Create 'arm_sme.intr.mopa' outer product intrinsic.
    arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask,
                                      outerProductOp.getLhs(),
                                      outerProductOp.getRhs());

    // The outerproduct intrinsics have no result, replace
    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
    rewriter.replaceOp(outerProductOp, acc);

    return success();
  }
};
779 
/// Lower 2-way and 4-way widening outer products to intrinsics.
///
/// `OuterProductWideningOp` is the ArmSME op being lowered and
/// `OuterProductWideningIntrOp` is the matching intrinsic to emit.
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();
    Value acc = op.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType());
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    // If the op carries no masks, predicate all lanes active.
    if (!lhsMask || !rhsMask) {
      auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask,
                                       adaptor.getLhs(), adaptor.getRhs());

    // The outer product intrinsics have no result; replace the widening
    // outer product op with the input tile to preserve dataflow.
    rewriter.replaceOp(op, acc);

    return success();
  }
};
824 
825 /// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
826 ///
827 /// Example:
828 ///
829 /// %0 = arm_sme.streaming_vl <half>
830 ///
831 /// is converted to:
832 ///
833 /// %cnt = "arm_sme.intr.cntsh"() : () -> i64
834 /// %0 = arith.index_cast %cnt : i64 to index
835 ///
836 struct StreamingVLOpConversion
837  : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
838  RequiresSpillsAndFills::No> {
839  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
840 
841  LogicalResult
842  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
843  arm_sme::StreamingVLOp::Adaptor adaptor,
844  ConversionPatternRewriter &rewriter) const override {
845  auto loc = streamingVlOp.getLoc();
846  auto i64Type = rewriter.getI64Type();
847  auto *intrOp = [&]() -> Operation * {
848  switch (streamingVlOp.getTypeSize()) {
849  case arm_sme::TypeSize::Byte:
850  return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
851  case arm_sme::TypeSize::Half:
852  return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
853  case arm_sme::TypeSize::Word:
854  return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
855  case arm_sme::TypeSize::Double:
856  return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
857  }
858  llvm_unreachable("unknown type size in StreamingVLOpConversion");
859  }();
860  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
861  streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
862  return success();
863  }
864 };
865 
866 /// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
867 /// or-ing the zero masks. Note: In future the backend _should_ handle this.
868 static void mergeConsecutiveTileZerosInBlock(Block *block) {
869  uint32_t mergedZeroMask = 0;
871  auto replaceMergedZeroOps = [&] {
872  auto cleanup = llvm::make_scope_exit([&] {
873  mergedZeroMask = 0;
874  zeroOpsToMerge.clear();
875  });
876  if (zeroOpsToMerge.size() <= 1)
877  return;
878  IRRewriter rewriter(zeroOpsToMerge.front());
879  arm_sme::aarch64_sme_zero::create(
880  rewriter, zeroOpsToMerge.front().getLoc(),
881  rewriter.getI32IntegerAttr(mergedZeroMask));
882  for (auto zeroOp : zeroOpsToMerge)
883  rewriter.eraseOp(zeroOp);
884  };
885  for (Operation &op : *block) {
886  if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
887  mergedZeroMask |= zeroOp.getTileMask();
888  zeroOpsToMerge.push_back(zeroOp);
889  } else {
890  replaceMergedZeroOps();
891  }
892  }
893  replaceMergedZeroOps();
894 }
895 
896 } // namespace
897 
898 namespace {
899 
900 struct ConvertArmSMEToLLVMPass
901  : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
902  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
903  this->dumpTileLiveRanges = dumpTileLiveRanges;
904  }
905  void runOnOperation() override {
906  auto function = getOperation();
907 
908  if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
909  return signalPassFailure();
910 
913  LLVMTypeConverter converter(&getContext());
916 
917  if (failed(applyPartialConversion(function, target, std::move(patterns))))
918  signalPassFailure();
919 
920  function->walk(mergeConsecutiveTileZerosInBlock);
921 
922  // Walk the function and fail if there are unexpected operations on SME
923  // tile types after conversion.
924  function->walk([&](Operation *op) {
925  // These ops are legal post conversion, skip these.
926  if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
927  !op->isRegistered())
928  return;
929  auto isSMETileType = [](Type type) {
930  return arm_sme::isValidSMETileVectorType(type);
931  };
932  if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
933  llvm::any_of(op->getOperandTypes(), isSMETileType)) {
934  op->emitOpError("unexpected operation with SME tile type after "
935  "conversion to LLVM");
936  signalPassFailure();
937  }
938  });
939  }
940 };
941 
942 } // namespace
943 
945  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
946  target.addLegalOp<
947  arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
948  arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
949  arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
950  arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
951  arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
952  arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
953  arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
954  arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
955  arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
956  arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
957  arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
958  arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
959  arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
960  arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
961  arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
962  arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
963  arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
964  arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
965  arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
966  arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
967  arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
968  arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
969  arm_sme::aarch64_sme_cntsd>();
970  target.addLegalDialect<arith::ArithDialect,
971  /* The following are used to lower tile spills/fills */
972  vector::VectorDialect, scf::SCFDialect,
973  memref::MemRefDialect>();
974  // Pseudo operations. These cannot be code-generated but may exist in the
975  // input IR, or be generated during the conversion. They need to be eliminated
976  // before the final conversion to LLVM IR (and likely will be due to DCE).
977  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
978  UnrealizedConversionCastOp>();
979 }
980 
983  converter.addConversion([&](VectorType type) -> std::optional<Type> {
984  // There's no LLVM type for SME tiles, but after lowering to intrinsics all
985  // SME vector types should be eliminated.
987  return type;
988  return std::nullopt;
989  });
990 
991  addArmSMEConversionPatterns<
992  LoadTileSliceConversion, ExtractTileSliceConversion,
993  InsertTileSliceConversion, StoreTileSliceConversion,
994  StreamingVLOpConversion, OuterProductOpConversion,
995  OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
996  arm_sme::aarch64_sme_mopa_wide>,
997  OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
998  arm_sme::aarch64_sme_mops_wide>,
999  OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
1000  arm_sme::aarch64_sme_smopa_za32>,
1001  OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
1002  arm_sme::aarch64_sme_smops_za32>,
1003  OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
1004  arm_sme::aarch64_sme_umopa_za32>,
1005  OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
1006  arm_sme::aarch64_sme_umops_za32>,
1007  OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1008  arm_sme::aarch64_sme_smopa_wide>,
1009  OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1010  arm_sme::aarch64_sme_smops_wide>,
1011  OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1012  arm_sme::aarch64_sme_umopa_wide>,
1013  OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1014  arm_sme::aarch64_sme_umops_wide>,
1015  OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1016  arm_sme::aarch64_sme_sumopa_wide>,
1017  OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1018  arm_sme::aarch64_sme_sumops_wide>,
1019  OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1020  arm_sme::aarch64_sme_usmopa_wide>,
1021  OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1022  arm_sme::aarch64_sme_usmops_wide>,
1023  ZeroOpConversion>(patterns, converter);
1024 }
1025 
1026 std::unique_ptr<Pass>
1027 mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
1028  return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
1029 }
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:33
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:88
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:764
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
Definition: Operation.h:129
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:286
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:311
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition: ArithOps.cpp:258
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition: Pattern.cpp:478
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition: Utils.cpp:43
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:17
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:28
constexpr unsigned MinStreamingVectorLengthInBits
Definition: Utils.h:33
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges=false)
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.