//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of ArmSME operations to LLVM intrinsics.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ScopeExit.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
/// Helper to create an 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
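///
/// For example, loading a horizontal slice of a 32-bit (ZAS) tile creates
/// roughly the following (a sketch, not verified output; SSA names are
/// illustrative):
///
/// ```mlir
/// "arm_sme.intr.ld1w.horiz"(%mask, %ptr, %slice_i32) <{tile_id = 0 : i32}>
///   : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
/// ```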
static Operation *createLoadTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
  llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
}

/// Helper to create an 'arm_sme.intr.st1*.(horiz|vert)' intrinsic.
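///
/// For example, storing a vertical slice of an 8-bit (ZAB) tile creates
/// roughly the following (a sketch; SSA names are illustrative):
///
/// ```mlir
/// "arm_sme.intr.st1b.vert"(%mask, %ptr, %slice_i32) <{tile_id = 0 : i32}>
///   : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
/// ```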
static Operation *createStoreTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
  llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
}

IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
  auto tileId = op.getTileId();
  if (!tileId)
    op.emitOpError(
        "expected tile ID to be allocated before conversion to LLVM");
  return tileId;
}

/// Creates an alloca matching the size of the tile used by `tileOp`. The
/// alloca is placed in the first block of the function.
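///
/// A sketch of the IR this creates, assuming an f32 tile (SSA names are
/// illustrative):
///
/// ```mlir
/// %vscale = vector.vscale
/// %min_elts = arith.constant 4 : index
/// %len = arith.muli %vscale, %min_elts : index
/// %alloca = memref.alloca(%len, %len) : memref<?x?xf32>
/// ```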
static memref::AllocaOp
createAllocaForTile(RewriterBase &rewriter, Location loc,
                    FunctionOpInterface func,
                    arm_sme::ArmSMETileOpInterface tileOp) {
  RewriterBase::InsertionGuard g(rewriter);
  // Move to the first operation in the function.
  rewriter.setInsertionPointToStart(&func.getBlocks().front());
  // Create an alloca matching the tile size of the `tileOp`.
  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
  auto tileElementType = tileOp.getTileType().getElementType();
  auto memrefType = MemRefType::get(
      {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
  unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
  auto minElementsOp =
      rewriter.create<arith::ConstantIndexOp>(loc, minElements);
  auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
  auto alloca = rewriter.create<memref::AllocaOp>(
      loc, memrefType, ValueRange{vectorLen, vectorLen});
  return alloca;
}

/// Finds or creates an alloca for a spill of a tile.
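///
/// Lookup is keyed off the discardable 'arm_sme.in_memory_tile_id' attribute,
/// so a reusable alloca looks roughly like (a sketch; values illustrative):
///
/// ```mlir
/// %spill = memref.alloca(%len, %len) {arm_sme.in_memory_tile_id = 16 : i32}
///   : memref<?x?xf32>
/// ```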
static memref::AllocaOp getOrCreateAllocaForTile(
    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
  // Find an alloca at the top of the function tagged with a
  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
  for (auto &op : func.getBlocks().front()) {
    auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
    if (!alloca)
      continue;
    auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
        alloca->getDiscardableAttr(kInMemoryTileIdAttr));
    if (!inMemoryTileId)
      continue;
    if (inMemoryTileId.getInt() == tileId)
      return alloca;
  }
  // Otherwise, create a new alloca:
  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
  alloca->setDiscardableAttr(kInMemoryTileIdAttr,
                             rewriter.getI32IntegerAttr(tileId));
  return alloca;
}

/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
/// the op to tile 0, then emitting a full tile swap between ZA and memory
/// before + after the tile op.
///
/// Example:
///
///   // Note: <IN MEMORY TILE> = tile ID >= 16.
///   arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
///
/// is converted to:
///   // At function entry:
///   %spill = memref.alloca ... : memref<?x?xty>
///
///   // Around op:
///   scf.for %slice_idx {
///     %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///     "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///     vector.store %slice_to_save, %spill[%slice_idx, %c0]
///   }
///   arm_sme.tile_op { tile_id = 0 }
///   scf.for %slice_idx {
///     %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///     "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///     vector.store %slice_to_save, %spill[%slice_idx, %c0]
///   }
///
/// Note that these spills/fills are not inserted earlier, as the concept of a
/// register, and the need to swap its contents, can't really be represented
/// correctly at a high level in MLIR.
///
/// TODO: Reduce the spills/reloads to single slices where possible (and omit
/// redundant reloads). This could be done via a method on the
/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
///
/// `tileOp.getZaUsage()` could return:
///
/// struct ArmSMEOpZAUsage {
///   enum class Kind {
///     TileRead,       // Omit store after tile operation.
///     TileWrite,      // Omit load before tile operation.
///     TileReadWrite,  // Needs both tile load and store.
///     SliceRead,      // Spill single slice and omit store after operation.
///     SliceWrite,     // Spill single slice and omit load before operation.
///     SliceReadWrite  // Spill single slice.
///   };
///   Value sliceIndex {};
///   TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
/// };
///
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {

  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
                                    const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
                             typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
    // Tile has a real (hardware) tile. No spills/reloads required.
    if (!tileOp.isInMemoryTile())
      return failure();

    tileOp->emitWarning(
        "failed to allocate SME virtual tile to operation, tile value will go "
        "through memory, expect degraded performance");

    // Step 1. Create an alloca for the tile at the top of the function (if one
    // does not already exist).
    auto loc = tileOp.getLoc();
    auto func = tileOp->getParentOfType<FunctionOpInterface>();
    auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
                                               tileOp.getTileId().getInt());

    // Step 2. Assign the op a real tile ID.
    // For simplicity, we always use tile 0 (which always exists).
    auto zeroTileId = rewriter.getI32IntegerAttr(0);
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

    VectorType tileVectorType = tileOp.getTileType();
    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
    auto swapInMemoryTileWithSMETileZero = [&] {
      emitFullTileSwap(rewriter, loc, tileAlloca,
                       *arm_sme::getSMETileType(tileVectorType), sliceType,
                       zeroTileId);
    };

    // Step 3. Emit tile swaps before and after the op.
    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
    // touches (i.e. a single tile slice).
    {
      rewriter.setInsertionPoint(op);
      // Swap the contents of ZA and the in-memory tile before the op.
      swapInMemoryTileWithSMETileZero();
      rewriter.setInsertionPointAfter(op);
      // Swap the tile back out to memory again after the op.
      swapInMemoryTileWithSMETileZero();
    }

    return success();
  }

  /// Extracts a pointer to a slice of an in-memory tile.
  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
                                Value tileMemory, Value sliceIndex) const {
    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
    auto descriptor =
        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(), sliceIndex);
    return getStridedElementPtr(
        loc, llvm::cast<MemRefType>(tileMemory.getType()),
        descriptor.getResult(0), {sliceIndexI64, zero},
        static_cast<ConversionPatternRewriter &>(rewriter));
  }

  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
  /// tile-sized memref (`tileAlloca`).
  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
                     IntegerAttr tileId, Value sliceIndex) const {
    // Cast the slice index to an i32.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);
    // Create an all-true predicate for the slice.
    auto predicateType = sliceType.clone(rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));
    // Create padding vector (never used due to all-true predicate).
    auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
    // Get a pointer to the current slice.
    auto slicePtr =
        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
    // Read the value of the current slice from ZA.
    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
        loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
    // Load the new tile slice back from memory into ZA.
    createLoadTileSliceIntrinsic(
        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
        allTruePredicate, slicePtr, tileId, sliceIndexI32);
    // Store the current tile slice to memory.
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
                                     ValueRange{sliceIndex, zero});
  }

  /// Emits a full in-place swap of the contents of a tile in ZA and a
  /// tile-sized memref (`tileAlloca`).
  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
                        IntegerAttr tileId) const {
    RewriterBase::InsertionGuard guard(rewriter);
    // Create an scf.for over all tile slices.
    auto minNumElts =
        rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = rewriter.create<arith::MulIOp>(
        loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    // Emit a swap for each tile slice.
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto sliceIndex = forOp.getInductionVar();
    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
                  sliceIndex);
  }
};

enum class RequiresSpillsAndFills { Yes, No };

/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  static constexpr bool requiresSpillsAndFillsConversion() {
    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
  }
};
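
// Illustrative usage (hypothetical op name): a pattern that should not receive
// automatic spills/fills opts out via the second template parameter, e.g.:
//
//   struct MyOpConversion
//       : ConvertArmSMEOpToLLVMPattern<arm_sme::SomeOp,
//                                      RequiresSpillsAndFills::No> {
//     using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
//   };
//
// `StreamingVLOpConversion` below is a real example of opting out.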

template <typename Pattern>
static void addArmSMEConversionPattern(RewritePatternSet &patterns,
                                       LLVMTypeConverter const &typeConverter) {
  // Register spills/fills for ops that implement the
  // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
  // `RequiresSpillsAndFills::Yes`.
  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
                std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
                                      typename Pattern::ArmSMEOp>,
                                  typename Pattern::ArmSMEOp>) {
    // Add spill/fill conversions with a very high benefit to ensure
    // they are lowered first.
    patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
        Pattern::ArmSMEOp::getOperationName(), typeConverter,
        /*benefit=*/1337);
  }
  patterns.add<Pattern>(typeConverter);
}

/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}

/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
/// ```mlir
/// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
/// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
/// %v = arm_sme.get_tile : vector<[4]x[4]xi32>
/// ```
///
/// The 'arm_sme.get_tile' (which models the return) will fold away once all
/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = zero.getLoc();

    auto tileId = getTileIdOrError(zero);
    if (!tileId)
      return failure();

    // Get the base mask for the tile, based on the element size.
    // The base mask is just the mask to zero the first tile of that size.
    // These masks are derived from:
    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
    arm_sme::ArmSMETileType tileType =
        *arm_sme::getSMETileType(zero.getTileType());
    auto baseMaskForSize = [&] {
      switch (tileType) {
      case arm_sme::ArmSMETileType::ZAB:
        // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
        // 64-bit element tiles named ZA0.D to ZA7.D.
        return 0b1111'1111;
      case arm_sme::ArmSMETileType::ZAH:
        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
        // once for ZA1.H.
        return 0b0101'0101;
      case arm_sme::ArmSMETileType::ZAS:
        // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D and ZA4.D.
        // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
        return 0b0001'0001;
      case arm_sme::ArmSMETileType::ZAD:
        // Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires
        // setting the bit for that tile.
        return 0b0000'0001;
      default:
        llvm_unreachable("bad element size");
      }
    }();

    // The actual mask is just the base mask shifted by the tile ID.
    // This will be folded to a constant after tile allocation.
    //
    // The shift follows from the layout of the tiles and from the tile ID
    // being the index of the tile. For example, looking at the 32-bit ZAx.S
    // tiles:
    //
    // ZA0.S = ZA0.D and ZA4.D
    //  * Tile ID -> 0
    //  * Mask    -> 00010001 = (00010001 << 0)
    // ZA1.S = ZA1.D and ZA5.D
    //  * Tile ID -> 1
    //  * Mask    -> 00100010 = (00010001 << 1)
    // ZA2.S = ZA2.D and ZA6.D
    //  * Tile ID -> 2
    //  * Mask    -> 01000100 = (00010001 << 2)
    // ZA3.S = ZA3.D and ZA7.D
    //  * Tile ID -> 3
    //  * Mask    -> 10001000 = (00010001 << 3)
    //
    // This holds for all tile sizes.
    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        loc, rewriter.getI32IntegerAttr(zeroMask));

    // Create a placeholder op to preserve dataflow.
    // Note: Place the `get_tile` op at the start of the block. This ensures
    // that if there are multiple `zero` ops the intrinsics will be consecutive.
    rewriter.setInsertionPointToStart(zero->getBlock());
    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());

    return success();
  }
};

/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
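///
/// A sketch of the lowering, assuming an i32 tile and horizontal layout
/// (the IR shown is illustrative, not verified output):
///
/// BEFORE:
/// ```mlir
/// %t = arm_sme.load_tile_slice %base[%i], %mask, %tile, %idx
///   : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
/// "arm_sme.intr.ld1w.horiz"(%mask, %ptr, %idx_i32) <{tile_id = 0 : i32}>
///   : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
/// ```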
struct LoadTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
                  arm_sme::LoadTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = loadTileSliceOp.getLoc();
    auto tileId = getTileIdOrError(loadTileSliceOp);
    if (!tileId)
      return failure();

    Value ptr = this->getStridedElementPtr(
        loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices(), rewriter);

    auto tileSlice = loadTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Get the predicate mask for the load.
    auto maskOp = loadTileSliceOp.getMask();

    auto tileVectorType = loadTileSliceOp.getVectorType();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
    arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();

    // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
    createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
                                 tileId, tileSliceI32);

    // The load intrinsics have no result, replace 'arm_sme.load_tile_slice'
    // with the input tile to preserve dataflow.
    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());

    return success();
  }
};

/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
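///
/// A sketch of the lowering, assuming an i32 tile and horizontal layout
/// (illustrative IR):
///
/// BEFORE:
/// ```mlir
/// arm_sme.store_tile_slice %tile, %idx, %mask, %base[%i]
///   : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
/// "arm_sme.intr.st1w.horiz"(%mask, %ptr, %idx_i32) <{tile_id = 0 : i32}>
///   : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
/// ```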
struct StoreTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
                  arm_sme::StoreTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = storeTileSliceOp.getLoc();
    auto tileVectorType = storeTileSliceOp.getVectorType();

    auto tileId = getTileIdOrError(storeTileSliceOp);
    if (!tileId)
      return failure();

    // Create 'arm_sme.intr.st1*.(horiz|vert)' intrinsic to store ZA tile
    // slice.
    Value ptr = this->getStridedElementPtr(
        loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices(), rewriter);

    auto tileSlice = storeTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    auto maskOp = storeTileSliceOp.getMask();

    arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);

    rewriter.replaceOp(storeTileSliceOp,
                       createStoreTileSliceIntrinsic(rewriter, loc, tileType,
                                                     layout, maskOp, ptr,
                                                     tileId, tileSliceI32));

    return success();
  }
};


/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
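///
/// A sketch of the emitted intrinsic for a horizontal layout (illustrative
/// IR; an all-active mask is created for the whole slice):
///
/// ```mlir
/// "arm_sme.intr.write.horiz"(%idx_i32, %all_active, %vector)
///   <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
/// ```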
struct InsertTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
                  arm_sme::InsertTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertTileSliceOp.getLoc();
    auto tileType = insertTileSliceOp.getTileType();

    auto tileId = getTileIdOrError(insertTileSliceOp);
    if (!tileId)
      return failure();

    auto tileSlice = insertTileSliceOp.getTileSliceIndex();

    // Cast tile slice from index to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Create all active predicate mask.
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI1Type(),
        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                  /*scalableDims=*/{true});
    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
    switch (insertTileSliceOp.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.create<arm_sme::aarch64_sme_write_vert>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    }

    // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
    // the input tile to preserve dataflow.
    rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());

    return success();
  }
};

/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
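///
/// A sketch of the emitted intrinsic for a horizontal layout (illustrative
/// IR; the zero vector is the fallback for inactive lanes):
///
/// ```mlir
/// %slice = "arm_sme.intr.read.horiz"(%zero_vec, %all_true, %idx_i32)
///   <{tile_id = 0 : i32}>
///   : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
/// ```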
struct ExtractTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractTileSlice.getLoc();
    auto sliceType = extractTileSlice.getSliceType();
    auto sliceIndex = extractTileSlice.getTileSliceIndex();

    auto tileId = getTileIdOrError(extractTileSlice);
    if (!tileId)
      return failure();

    // Create an 'all true' predicate for the tile slice.
    auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));

    // Zero destination/fallback for tile slice extraction.
    auto zeroVector = rewriter.create<arith::ConstantOp>(
        loc, sliceType, rewriter.getZeroAttr(sliceType));

    // Cast tile slice from index to i32 for intrinsic.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);

    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
    switch (extractTileSlice.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    }

    return success();
  }
};


/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                  arm_sme::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(outerProductOp);
    if (!tileId)
      return failure();

    auto isSupportedType = [](VectorType vectorType) {
      // TODO: the FP outer product instruction variants are predicated on
      // different features [1]:
      //
      // * FMOPA (non-widening)
      //   * half-precision   - +sme2p1,+sme-f16f16
      //   * single-precision - +sme
      //   * double-precision - +sme-f64f64
      // * BFMOPA
      //   * half-precision   - +sme2p1,+b16b16
      //
      // It should be possible to control lowering based on target features.
      // [1]
      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
        return false;

      auto elementType = vectorType.getElementType();

      if (!elementType.isF16() && !elementType.isBF16() &&
          !elementType.isF32() && !elementType.isF64())
        return false;

      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      return vectorType.getShape() ==
             ArrayRef<int64_t>({minNumElts, minNumElts});
    };

    // TODO: Support CombiningKind::Sub for outer products.
    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
      return outerProductOp.emitError("unsupported kind");

    auto resultVectorType = outerProductOp.getResultType();
    if (!isSupportedType(resultVectorType))
      return outerProductOp.emitError("unsupported type");

    auto loc = outerProductOp.getLoc();

    Value acc = outerProductOp.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = outerProductOp.getLhsMask();
    Value rhsMask = outerProductOp.getRhsMask();

    if (!lhsMask || !rhsMask) {
      auto predTy =
          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    // Create 'arm_sme.intr.mopa' outer product intrinsic.
    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                               outerProductOp.getLhs(),
                                               outerProductOp.getRhs());

    // The outer product intrinsics have no result, replace
    // 'arm_sme.outerproduct' with the accumulator tile to preserve dataflow.
    rewriter.replaceOp(outerProductOp, acc);

    return success();
  }
};


/// Lower 2-way and 4-way widening outer products to intrinsics.
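///
/// For example (a sketch; the op's assembly format and types here are
/// illustrative, assuming f16 inputs widening into an f32 tile):
///
/// ```mlir
/// %0 = arm_sme.fmopa_2way %lhs, %rhs
///   : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
/// ```
///
/// becomes roughly:
///
/// ```mlir
/// "arm_sme.intr.mopa.wide"(%mask, %mask, %lhs, %rhs) <{tile_id = 0 : i32}>
///   : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>,
///      vector<[8]xf16>) -> ()
/// ```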
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();
    Value acc = op.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    if (!lhsMask || !rhsMask) {
      auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    rewriter.create<OuterProductWideningIntrOp>(
        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());

    // The outer product intrinsics have no result, replace the widening op
    // with the accumulator tile to preserve dataflow.
    rewriter.replaceOp(op, acc);

    return success();
  }
};


/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
///
/// Example:
///
///   %0 = arm_sme.streaming_vl <half>
///
/// is converted to:
///
///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
///   %0 = arith.index_cast %cnt : i64 to index
///
struct StreamingVLOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
                                          RequiresSpillsAndFills::No> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
                  arm_sme::StreamingVLOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = streamingVlOp.getLoc();
    auto i64Type = rewriter.getI64Type();
    auto *intrOp = [&]() -> Operation * {
      switch (streamingVlOp.getTypeSize()) {
      case arm_sme::TypeSize::Byte:
        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
      case arm_sme::TypeSize::Half:
        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
      case arm_sme::TypeSize::Word:
        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
      case arm_sme::TypeSize::Double:
        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
      }
      llvm_unreachable("unknown type size in StreamingVLOpConversion");
    }();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
    return success();
  }
};


/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
/// or-ing the zero masks. Note: In future the backend _should_ handle this.
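///
/// For example (tile masks are illustrative):
///
/// ```mlir
/// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
/// "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
/// ```
///
/// becomes:
///
/// ```mlir
/// "arm_sme.intr.zero"() <{tile_mask = 51 : i32}> : () -> ()
/// ```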
static void mergeConsecutiveTileZerosInBlock(Block *block) {
  uint32_t mergedZeroMask = 0;
  SmallVector<arm_sme::aarch64_sme_zero> zeroOpsToMerge;
  auto replaceMergedZeroOps = [&] {
    auto cleanup = llvm::make_scope_exit([&] {
      mergedZeroMask = 0;
      zeroOpsToMerge.clear();
    });
    if (zeroOpsToMerge.size() <= 1)
      return;
    IRRewriter rewriter(zeroOpsToMerge.front());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        zeroOpsToMerge.front().getLoc(),
        rewriter.getI32IntegerAttr(mergedZeroMask));
    for (auto zeroOp : zeroOpsToMerge)
      rewriter.eraseOp(zeroOp);
  };
  for (Operation &op : *block) {
    if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
      mergedZeroMask |= zeroOp.getTileMask();
      zeroOpsToMerge.push_back(zeroOp);
    } else {
      replaceMergedZeroOps();
    }
  }
  replaceMergedZeroOps();
}

} // namespace

namespace {

struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
    this->dumpTileLiveRanges = dumpTileLiveRanges;
  }
  void runOnOperation() override {
    auto function = getOperation();

    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      return signalPassFailure();

    ConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    configureArmSMEToLLVMConversionLegality(target);
    populateArmSMEToLLVMConversionPatterns(converter, patterns);

    if (failed(applyPartialConversion(function, target, std::move(patterns))))
      signalPassFailure();

    function->walk(mergeConsecutiveTileZerosInBlock);

    // Walk the function and fail if there are unexpected operations on SME
    // tile types after conversion.
    function->walk([&](Operation *op) {
      // These ops are legal post conversion, skip these.
      if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
          !op->isRegistered())
        return;
      auto isSMETileType = [](Type type) {
        return arm_sme::isValidSMETileVectorType(type);
      };
      if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
          llvm::any_of(op->getOperandTypes(), isSMETileType)) {
        op->emitOpError("unexpected operation with SME tile type after "
                        "conversion to LLVM");
        signalPassFailure();
      }
    });
  }
};

} // namespace


void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
  target.addLegalOp<
      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
      arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
      arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
      arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
      arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
      arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
      arm_sme::aarch64_sme_cntsd>();
  target.addLegalDialect<arith::ArithDialect,
                         /* The following are used to lower tile spills/fills */
                         vector::VectorDialect, scf::SCFDialect,
                         memref::MemRefDialect>();
  // Pseudo operations. These cannot be code-generated but may exist in the
  // input IR, or be generated during the conversion. They need to be
  // eliminated before the final conversion to LLVM IR (and likely will be due
  // to DCE).
  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
                    UnrealizedConversionCastOp>();
}

void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  converter.addConversion([&](VectorType type) -> std::optional<Type> {
    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
    // SME vector types should be eliminated.
    if (arm_sme::isValidSMETileVectorType(type))
      return type;
    return std::nullopt;
  });

  addArmSMEConversionPatterns<
      LoadTileSliceConversion, ExtractTileSliceConversion,
      InsertTileSliceConversion, StoreTileSliceConversion,
      StreamingVLOpConversion, OuterProductOpConversion,
      OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
                                       arm_sme::aarch64_sme_mopa_wide>,
      OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
                                       arm_sme::aarch64_sme_mops_wide>,
      OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
                                       arm_sme::aarch64_sme_smopa_za32>,
      OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
                                       arm_sme::aarch64_sme_smops_za32>,
      OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
                                       arm_sme::aarch64_sme_umopa_za32>,
      OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
                                       arm_sme::aarch64_sme_umops_za32>,
      OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
                                       arm_sme::aarch64_sme_smopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
                                       arm_sme::aarch64_sme_smops_wide>,
      OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
                                       arm_sme::aarch64_sme_umopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
                                       arm_sme::aarch64_sme_umops_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
                                       arm_sme::aarch64_sme_sumopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
                                       arm_sme::aarch64_sme_sumops_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
                                       arm_sme::aarch64_sme_usmopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
                                       arm_sme::aarch64_sme_usmops_wide>,
      ZeroOpConversion>(patterns, converter);
}

std::unique_ptr<Pass>
mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
  return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
}
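
// Typical usage, as a sketch (assumes the pass is registered under the
// `convert-arm-sme-to-llvm` argument in Passes.td):
//
//   mlir-opt -convert-arm-sme-to-llvm input.mlir
//
// or added to a pass manager programmatically, e.g.:
//
//   passManager.addPass(createConvertArmSMEToLLVMPass());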