//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of ArmSME operations to LLVM intrinsics.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ScopeExit.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");

/// Helper to create an 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
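/// For example, loading a horizontal slice of a 32-bit element tile emits IR
/// along these lines (generic form; types and tile ID are illustrative):
///
///   "arm_sme.intr.ld1w.horiz"(%mask, %ptr, %tile_slice_i32)
///     <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()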
static Operation *createLoadTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
}

/// Helper to create an 'arm_sme.intr.st1*.(horiz|vert)' intrinsic.
static Operation *createStoreTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
}

IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
  auto tileId = op.getTileId();
  if (!tileId)
    op.emitOpError(
        "expected tile ID to be allocated before conversion to LLVM");
  return tileId;
}

/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
/// placed in the first block of the function.
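/// For an i32 tile this emits roughly the following IR (illustrative):
///
///   %vscale = vector.vscale
///   %min_elts = arith.constant 4 : index
///   %dim = arith.muli %vscale, %min_elts : index
///   %tile_alloca = memref.alloca(%dim, %dim) : memref<?x?xi32>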
static memref::AllocaOp
createAllocaForTile(RewriterBase &rewriter, Location loc,
                    FunctionOpInterface func,
                    arm_sme::ArmSMETileOpInterface tileOp) {
  RewriterBase::InsertionGuard g(rewriter);
  // Move to the first operation in the function.
  rewriter.setInsertionPointToStart(&func.getBlocks().front());
  // Create an alloca matching the tile size of the `tileOp`.
  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
  auto tileElementType = tileOp.getTileType().getElementType();
  auto memrefType = MemRefType::get(
      {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
  unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
  auto minElementsOp =
      rewriter.create<arith::ConstantIndexOp>(loc, minElements);
  auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
  auto alloca = rewriter.create<memref::AllocaOp>(
      loc, memrefType, ValueRange{vectorLen, vectorLen});
  return alloca;
}

/// Finds or creates an alloca for a spill of a tile.
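/// Spill slots are recognized by their tag, e.g. (attribute value
/// illustrative; in-memory tile IDs start at 16):
///
///   %spill = memref.alloca(%dim, %dim) {arm_sme.in_memory_tile_id = 16 : i32}
///     : memref<?x?xf32>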
static memref::AllocaOp getOrCreateAllocaForTile(
    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
  // Find an alloca at the top of the function tagged with a
  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
  for (auto &op : func.getBlocks().front()) {
    auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
    if (!alloca)
      continue;
    auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
        alloca->getDiscardableAttr(kInMemoryTileIdAttr));
    if (!inMemoryTileId)
      continue;
    if (inMemoryTileId.getInt() == tileId)
      return alloca;
  }
  // Otherwise, create a new alloca:
  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
  alloca->setDiscardableAttr(kInMemoryTileIdAttr,
                             rewriter.getI32IntegerAttr(tileId));
  return alloca;
}

/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
/// the op to tile 0, then emitting a full tile swap between ZA and memory
/// before + after the tile op.
///
/// Example:
///
///    // Note: <IN MEMORY TILE> = tile ID >= 16.
///    arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
///
/// is converted to:
///    // At function entry:
///    %spill = memref.alloca ... : memref<?x?xty>
///
///    // Around op:
///    scf.for %slice_idx {
///      %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///      "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
///      vector.store %slice_to_save, %spill[%slice_idx, %c0]
///    }
///    arm_sme.tile_op { tile_id = 0 }
///    scf.for %slice_idx {
///      %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///      "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
///      vector.store %slice_to_save, %spill[%slice_idx, %c0]
///    }
///
/// Note that these spills/fills are not inserted earlier, as the concept of a
/// register, and the need to swap its contents, can't really be represented
/// correctly at a high level in MLIR.
///
/// TODO: Reduce the spills/reloads to single slices where possible (and omit
/// redundant reloads). This could be done via a method on the
/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
///
/// `tileOp.getZaUsage()` could return:
///
/// struct ArmSMEOpZAUsage {
///   enum class Kind {
///     TileRead,       // Omit store after tile operation.
///     TileWrite,      // Omit load before tile operation.
///     TileReadWrite,  // Needs both tile load and store.
///     SliceRead,      // Spill single slice and omit store after operation.
///     SliceWrite,     // Spill single slice and omit load before operation.
///     SliceReadWrite  // Spill single slice.
///   };
///   Value sliceIndex {};
///   TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
/// };
///
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {

  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
                                    const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
                             typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
    // Tile has a real (hardware) tile. No spills/reloads required.
    if (!tileOp.isInMemoryTile())
      return failure();

    tileOp->emitWarning(
        "failed to allocate SME virtual tile to operation, tile value will go "
        "through memory, expect degraded performance");

    // Step 1. Create an alloca for the tile at the top of the function (if one
    // does not already exist).
    auto loc = tileOp.getLoc();
    auto func = tileOp->getParentOfType<FunctionOpInterface>();
    auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
                                               tileOp.getTileId().getInt());

    // Step 2. Assign the op a real tile ID.
    // For simplicity, we always use tile 0 (which always exists).
    auto zeroTileId = rewriter.getI32IntegerAttr(0);
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

    VectorType tileVectorType = tileOp.getTileType();
    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
    auto swapInMemoryTileWithSMETileZero = [&] {
      emitFullTileSwap(rewriter, loc, tileAlloca,
                       *arm_sme::getSMETileType(tileVectorType), sliceType,
                       zeroTileId);
    };

    // Step 3. Emit tile swaps before and after the op.
    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
    // touches (i.e. a single tile slice).
    {
      rewriter.setInsertionPoint(op);
      // Swap the contents of ZA and the in-memory tile before the op.
      swapInMemoryTileWithSMETileZero();
      rewriter.setInsertionPointAfter(op);
      // Swap the tile back out to memory again after the op.
      swapInMemoryTileWithSMETileZero();
    }

    return success();
  }

  /// Extracts a pointer to a slice of an in-memory tile.
  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
                                Value tileMemory, Value sliceIndex) const {
    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
    auto descriptor =
        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(), sliceIndex);
    return getStridedElementPtr(
        loc, llvm::cast<MemRefType>(tileMemory.getType()),
        descriptor.getResult(0), {sliceIndexI64, zero},
        static_cast<ConversionPatternRewriter &>(rewriter));
  }

  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
  /// tile-sized memref (`tileAlloca`).
  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
                     IntegerAttr tileId, Value sliceIndex) const {
    // Cast the slice index to an i32.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);
    // Create an all-true predicate for the slice.
    auto predicateType = sliceType.clone(rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));
    // Create padding vector (never used due to all-true predicate).
    auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
    // Get a pointer to the current slice.
    auto slicePtr =
        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
    // Read the value of the current slice from ZA.
    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
        loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
    // Load the new tile slice back from memory into ZA.
    createLoadTileSliceIntrinsic(
        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
        allTruePredicate, slicePtr, tileId, sliceIndexI32);
    // Store the current tile slice to memory.
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
                                     ValueRange{sliceIndex, zero});
  }

  /// Emits a full in-place swap of the contents of a tile in ZA and a
  /// tile-sized memref (`tileAlloca`).
  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
                        IntegerAttr tileId) const {
    RewriterBase::InsertionGuard guard(rewriter);
    // Create an scf.for over all tile slices.
    auto minNumElts =
        rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = rewriter.create<arith::MulIOp>(
        loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    // Emit a swap for each tile slice.
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto sliceIndex = forOp.getInductionVar();
    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
                  sliceIndex);
  }
};

enum class RequiresSpillsAndFills { Yes, No };

/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
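/// For example, a pattern for a hypothetical `FooOp` that opts out of the
/// spill/fill conversion would be declared as:
///
///   struct FooOpConversion
///       : ConvertArmSMEOpToLLVMPattern<arm_sme::FooOp,
///                                      RequiresSpillsAndFills::No> {...};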
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  static constexpr bool requiresSpillsAndFillsConversion() {
    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
  }
};

template <typename Pattern>
static void addArmSMEConversionPattern(RewritePatternSet &patterns,
                                       LLVMTypeConverter const &typeConverter) {
  // Register spills/fills for ops that implement the
  // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
  // `RequiresSpillsAndFills::Yes`.
  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
                std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
                                      typename Pattern::ArmSMEOp>,
                                  typename Pattern::ArmSMEOp>) {
    // Add spill/fill conversions with a very high benefit to ensure
    // they are lowered first.
    patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
        Pattern::ArmSMEOp::getOperationName(), typeConverter,
        /*benefit=*/1337);
  }
  patterns.add<Pattern>(typeConverter);
}

/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}

/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
/// ```mlir
///    %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
///    "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///    %v = arm_sme.get_tile : vector<[4]x[4]xi32>
/// ```
///
/// The 'arm_sme.get_tile' (which models the return) will fold away once all
/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = zero.getLoc();

    auto tileId = getTileIdOrError(zero);
    if (!tileId)
      return failure();

    // Get the base mask for the tile based on the element size.
    // The base mask is just the mask to zero the first tile (of a size).
    // These masks are derived from:
    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
    arm_sme::ArmSMETileType tileType =
        *arm_sme::getSMETileType(zero.getTileType());
    auto baseMaskForSize = [&] {
      switch (tileType) {
      case arm_sme::ArmSMETileType::ZAB:
        // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
        // 64-bit element tiles named ZA0.D to ZA7.D.
        return 0b1111'1111;
      case arm_sme::ArmSMETileType::ZAH:
        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
        // once for ZA1.H.
        return 0b0101'0101;
      case arm_sme::ArmSMETileType::ZAS:
        // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D and ZA4.D.
        // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
        return 0b0001'0001;
      case arm_sme::ArmSMETileType::ZAD:
        // Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires
        // setting the bit for that tile.
        return 0b0000'0001;
      default:
        llvm_unreachable("bad element size");
      }
    }();

    // The actual mask is just the base mask shifted by the tile ID.
    // This will be folded to a constant after tile allocation.
    //
    // The shift is just derived from the layout of the tiles, and that the
    // tile ID is the index of the tile. For example, looking at the 32-bit
    // ZAx.S tiles:
    //
    // ZA0.S = ZA0.D and ZA4.D
    //  * Tile ID -> 0
    //  * Mask    -> 00010001 = (00010001 << 0)
    // ZA1.S = ZA1.D and ZA5.D
    //  * Tile ID -> 1
    //  * Mask    -> 00100010 = (00010001 << 1)
    // ZA2.S = ZA2.D and ZA6.D
    //  * Tile ID -> 2
    //  * Mask    -> 01000100 = (00010001 << 2)
    // ZA3.S = ZA3.D and ZA7.D
    //  * Tile ID -> 3
    //  * Mask    -> 10001000 = (00010001 << 3)
    //
    // This holds for all tile sizes.
    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        loc, rewriter.getI32IntegerAttr(zeroMask));

    // Create a placeholder op to preserve dataflow.
    // Note: Place the `get_tile` op at the start of the block. This ensures
    // that if there are multiple `zero` ops the intrinsics will be consecutive.
    rewriter.setInsertionPointToStart(zero->getBlock());
    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());

    return success();
  }
};

/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
                  arm_sme::LoadTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = loadTileSliceOp.getLoc();
    auto tileId = getTileIdOrError(loadTileSliceOp);
    if (!tileId)
      return failure();

    Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
                                           adaptor.getBase(),
                                           adaptor.getIndices(), rewriter);

    auto tileSlice = loadTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Get the predicate mask for the load.
    auto maskOp = loadTileSliceOp.getMask();

    auto tileVectorType = loadTileSliceOp.getVectorType();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
    arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();

    // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
    createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
                                 tileId, tileSliceI32);

    // The load intrinsics have no result, replace 'arm_sme.load_tile_slice'
    // with the input tile to preserve dataflow.
    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());

    return success();
  }
};

/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
struct StoreTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
                  arm_sme::StoreTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = storeTileSliceOp.getLoc();
    auto tileVectorType = storeTileSliceOp.getVectorType();

    auto tileId = getTileIdOrError(storeTileSliceOp);
    if (!tileId)
      return failure();

    // Create 'arm_sme.intr.st1*.(horiz|vert)' intrinsic to store ZA tile
    // slice.
    Value ptr = this->getStridedElementPtr(
        loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices(), rewriter);

    auto tileSlice = storeTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    auto maskOp = storeTileSliceOp.getMask();

    arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);

    rewriter.replaceOp(storeTileSliceOp,
                       createStoreTileSliceIntrinsic(rewriter, loc, tileType,
                                                     layout, maskOp, ptr,
                                                     tileId, tileSliceI32));

    return success();
  }
};

/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
struct InsertTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
                  arm_sme::InsertTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertTileSliceOp.getLoc();
    auto tileType = insertTileSliceOp.getTileType();

    auto tileId = getTileIdOrError(insertTileSliceOp);
    if (!tileId)
      return failure();

    auto tileSlice = insertTileSliceOp.getTileSliceIndex();

    // Cast tile slice from index to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Create all active predicate mask.
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI1Type(),
        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                  /*scalableDims=*/{true});
    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
    switch (insertTileSliceOp.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.create<arm_sme::aarch64_sme_write_vert>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    }

    // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
    // the input tile to preserve dataflow.
    rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());

    return success();
  }
};

/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
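/// Example (sketch; types, layout, and tile ID are illustrative):
///
///   BEFORE:
///   %slice = arm_sme.extract_tile_slice %tile[%idx] {tile_id = 0 : i32}
///     : vector<[4]xf32> from vector<[4]x[4]xf32>
///
///   AFTER:
///   %slice = "arm_sme.intr.read.horiz"(%zero_vec, %ptrue, %idx_i32)
///     <{tile_id = 0 : i32}>
///     : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>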
struct ExtractTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractTileSlice.getLoc();
    auto sliceType = extractTileSlice.getSliceType();
    auto sliceIndex = extractTileSlice.getTileSliceIndex();

    auto tileId = getTileIdOrError(extractTileSlice);
    if (!tileId)
      return failure();

    // Create an 'all true' predicate for the tile slice.
    auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));

    // Zero destination/fallback for tile slice extraction.
    auto zeroVector = rewriter.create<arith::ConstantOp>(
        loc, sliceType, rewriter.getZeroAttr(sliceType));

    // Cast tile slice from index to i32 for intrinsic.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);

    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
    switch (extractTileSlice.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    }

    return success();
  }
};

/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                  arm_sme::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(outerProductOp);
    if (!tileId)
      return failure();

    auto isSupportedType = [](VectorType vectorType) {
      // TODO: the FP outer product instruction variants are predicated on
      // different features [1]:
      //
      // * FMOPA (non-widening)
      //   * half-precision   - +sme2p1,+sme-f16f16
      //   * single-precision - +sme
      //   * double-precision - +sme-f64f64
      // * BFMOPA
      //   * half-precision   - +sme2p1,+b16b16
      //
      // It should be possible to control lowering based on target features.
      // [1]
      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
        return false;

      auto elementType = vectorType.getElementType();

      if (!elementType.isF16() && !elementType.isBF16() &&
          !elementType.isF32() && !elementType.isF64())
        return false;

      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      return vectorType.getShape() ==
             ArrayRef<int64_t>({minNumElts, minNumElts});
    };

    // TODO: Support CombiningKind::Sub for outer products.
    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
      return outerProductOp.emitError("unsupported kind");

    auto resultVectorType = outerProductOp.getResultType();
    if (!isSupportedType(resultVectorType))
      return outerProductOp.emitError("unsupported type");

    auto loc = outerProductOp.getLoc();

    Value acc = outerProductOp.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = outerProductOp.getLhsMask();
    Value rhsMask = outerProductOp.getRhsMask();

    if (!lhsMask || !rhsMask) {
      auto predTy =
          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    // Create 'arm_sme.intr.mopa' outer product intrinsic.
    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                               outerProductOp.getLhs(),
                                               outerProductOp.getRhs());

    // The outerproduct intrinsics have no result, replace
    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
    rewriter.replaceOp(outerProductOp, acc);

    return success();
  }
};

/// Lower 2-way and 4-way widening outer products to intrinsics.
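/// For example, a 2-way f16 to f32 widening outer product is lowered roughly
/// as follows (sketch; types and tile ID are illustrative):
///
///   BEFORE:
///   %0 = arm_sme.fmopa_2way %lhs, %rhs
///     : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
///
///   AFTER:
///   "arm_sme.intr.mopa.wide"(%ptrue, %ptrue, %lhs, %rhs)
///     <{tile_id = 0 : i32}>
///     : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
///       -> ()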
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();
    Value acc = op.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    if (!lhsMask || !rhsMask) {
      auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    rewriter.create<OuterProductWideningIntrOp>(
        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());

    // The outer product intrinsics have no result, replace the widening
    // outer product op with the input tile to preserve dataflow.
    rewriter.replaceOp(op, acc);

    return success();
  }
};

/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
///
/// Example:
///
///   %0 = arm_sme.streaming_vl <half>
///
/// is converted to:
///
///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
///   %0 = arith.index_cast %cnt : i64 to index
///
struct StreamingVLOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
                                          RequiresSpillsAndFills::No> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
                  arm_sme::StreamingVLOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = streamingVlOp.getLoc();
    auto i64Type = rewriter.getI64Type();
    auto *intrOp = [&]() -> Operation * {
      switch (streamingVlOp.getTypeSize()) {
      case arm_sme::TypeSize::Byte:
        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
      case arm_sme::TypeSize::Half:
        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
      case arm_sme::TypeSize::Word:
        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
      case arm_sme::TypeSize::Double:
        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
      }
      llvm_unreachable("unknown type size in StreamingVLOpConversion");
    }();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
    return success();
  }
};

/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
/// or-ing the zero masks. Note: In future the backend _should_ handle this.
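/// Example (masks illustrative; 17 | 34 == 51):
///
///   BEFORE:
///   "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///   "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
///
///   AFTER:
///   "arm_sme.intr.zero"() <{tile_mask = 51 : i32}> : () -> ()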
static void mergeConsecutiveTileZerosInBlock(Block *block) {
  uint32_t mergedZeroMask = 0;
  SmallVector<arm_sme::aarch64_sme_zero> zeroOpsToMerge;
  auto replaceMergedZeroOps = [&] {
    auto cleanup = llvm::make_scope_exit([&] {
      mergedZeroMask = 0;
      zeroOpsToMerge.clear();
    });
    if (zeroOpsToMerge.size() <= 1)
      return;
    IRRewriter rewriter(zeroOpsToMerge.front());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        zeroOpsToMerge.front().getLoc(),
        rewriter.getI32IntegerAttr(mergedZeroMask));
    for (auto zeroOp : zeroOpsToMerge)
      rewriter.eraseOp(zeroOp);
  };
  for (Operation &op : *block) {
    if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
      mergedZeroMask |= zeroOp.getTileMask();
      zeroOpsToMerge.push_back(zeroOp);
    } else {
      replaceMergedZeroOps();
    }
  }
  replaceMergedZeroOps();
}

} // namespace

namespace {

struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
    this->dumpTileLiveRanges = dumpTileLiveRanges;
  }
  void runOnOperation() override {
    auto function = getOperation();

    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      return signalPassFailure();

    LLVMConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    configureArmSMEToLLVMConversionLegality(target);
    populateArmSMEToLLVMConversionPatterns(converter, patterns);

    if (failed(applyPartialConversion(function, target, std::move(patterns))))
      signalPassFailure();

    function->walk(mergeConsecutiveTileZerosInBlock);

    // Walk the function and fail if there are unexpected operations on SME
    // tile types after conversion.
    function->walk([&](Operation *op) {
      // These ops are legal post conversion, skip these.
      if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
          !op->isRegistered())
        return;
      auto isSMETileType = [](Type type) {
        return arm_sme::isValidSMETileVectorType(type);
      };
      if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
          llvm::any_of(op->getOperandTypes(), isSMETileType)) {
        op->emitOpError("unexpected operation with SME tile type after "
                        "conversion to LLVM");
        signalPassFailure();
      }
    });
  }
};

} // namespace

void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
  target.addLegalOp<
      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
      arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
      arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
      arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
      arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
      arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
      arm_sme::aarch64_sme_cntsd>();
  target.addLegalDialect<arith::ArithDialect,
                         /* The following are used to lower tile spills/fills */
                         vector::VectorDialect, scf::SCFDialect,
                         memref::MemRefDialect>();
  // Pseudo operations. These cannot be code-generated but may exist in the
  // input IR, or be generated during the conversion. They need to be
  // eliminated before the final conversion to LLVM IR (and likely will be due
  // to DCE).
  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
                    UnrealizedConversionCastOp>();
}

void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  converter.addConversion([&](VectorType type) -> std::optional<Type> {
    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
    // SME vector types should be eliminated.
    if (arm_sme::isValidSMETileVectorType(type))
      return type;
    return std::nullopt;
  });

  addArmSMEConversionPatterns<
      LoadTileSliceConversion, ExtractTileSliceConversion,
      InsertTileSliceConversion, StoreTileSliceConversion,
      StreamingVLOpConversion, OuterProductOpConversion,
      OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
                                       arm_sme::aarch64_sme_mopa_wide>,
      OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
                                       arm_sme::aarch64_sme_mops_wide>,
      OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
                                       arm_sme::aarch64_sme_smopa_za32>,
      OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
                                       arm_sme::aarch64_sme_smops_za32>,
      OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
                                       arm_sme::aarch64_sme_umopa_za32>,
      OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
                                       arm_sme::aarch64_sme_umops_za32>,
      OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
                                       arm_sme::aarch64_sme_smopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
                                       arm_sme::aarch64_sme_smops_wide>,
      OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
                                       arm_sme::aarch64_sme_umopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
                                       arm_sme::aarch64_sme_umops_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
                                       arm_sme::aarch64_sme_sumopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
                                       arm_sme::aarch64_sme_sumops_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
                                       arm_sme::aarch64_sme_usmopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
                                       arm_sme::aarch64_sme_usmops_wide>,
      ZeroOpConversion>(patterns, converter);
}

std::unique_ptr<Pass>
mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
  return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
}