MLIR  20.0.0git
ArmSMEToLLVM.cpp
Go to the documentation of this file.
1 //===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to LLVM intrinsics.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
26 #include "mlir/Pass/Pass.h"
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 namespace {
37 
/// Name of the discardable attribute used to tag allocas that act as in-memory
/// spill slots for SME tiles (see `getOrCreateAllocaForTile`).
static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
39 
40 /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
41 static Operation *createLoadTileSliceIntrinsic(
42  RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
43  arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
44  IntegerAttr tileId, Value tileSliceI32) {
45  if (layout == arm_sme::TileSliceLayout::Horizontal) {
46  switch (type) {
47  case arm_sme::ArmSMETileType::ZAB:
48  return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
49  loc, maskOp, ptr, tileId, tileSliceI32);
50  case arm_sme::ArmSMETileType::ZAH:
51  return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
52  loc, maskOp, ptr, tileId, tileSliceI32);
53  case arm_sme::ArmSMETileType::ZAS:
54  return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
55  loc, maskOp, ptr, tileId, tileSliceI32);
56  case arm_sme::ArmSMETileType::ZAD:
57  return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
58  loc, maskOp, ptr, tileId, tileSliceI32);
59  case arm_sme::ArmSMETileType::ZAQ:
60  return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
61  loc, maskOp, ptr, tileId, tileSliceI32);
62  }
63  } else {
64  switch (type) {
65  case arm_sme::ArmSMETileType::ZAB:
66  return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
67  loc, maskOp, ptr, tileId, tileSliceI32);
68  case arm_sme::ArmSMETileType::ZAH:
69  return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
70  loc, maskOp, ptr, tileId, tileSliceI32);
71  case arm_sme::ArmSMETileType::ZAS:
72  return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
73  loc, maskOp, ptr, tileId, tileSliceI32);
74  case arm_sme::ArmSMETileType::ZAD:
75  return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
76  loc, maskOp, ptr, tileId, tileSliceI32);
77  case arm_sme::ArmSMETileType::ZAQ:
78  return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
79  loc, maskOp, ptr, tileId, tileSliceI32);
80  break;
81  }
82  }
83 }
84 
85 /// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
86 static Operation *createStoreTileSliceIntrinsic(
87  RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
88  arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
89  IntegerAttr tileId, Value tileSliceI32) {
90  if (layout == arm_sme::TileSliceLayout::Horizontal) {
91  switch (type) {
92  case arm_sme::ArmSMETileType::ZAB:
93  return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
94  loc, maskOp, ptr, tileId, tileSliceI32);
95  case arm_sme::ArmSMETileType::ZAH:
96  return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
97  loc, maskOp, ptr, tileId, tileSliceI32);
98  case arm_sme::ArmSMETileType::ZAS:
99  return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
100  loc, maskOp, ptr, tileId, tileSliceI32);
101  case arm_sme::ArmSMETileType::ZAD:
102  return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
103  loc, maskOp, ptr, tileId, tileSliceI32);
104  case arm_sme::ArmSMETileType::ZAQ:
105  return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
106  loc, maskOp, ptr, tileId, tileSliceI32);
107  }
108  } else {
109  switch (type) {
110  case arm_sme::ArmSMETileType::ZAB:
111  return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
112  loc, maskOp, ptr, tileId, tileSliceI32);
113  case arm_sme::ArmSMETileType::ZAH:
114  return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
115  loc, maskOp, ptr, tileId, tileSliceI32);
116  case arm_sme::ArmSMETileType::ZAS:
117  return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
118  loc, maskOp, ptr, tileId, tileSliceI32);
119  case arm_sme::ArmSMETileType::ZAD:
120  return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
121  loc, maskOp, ptr, tileId, tileSliceI32);
122  case arm_sme::ArmSMETileType::ZAQ:
123  return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
124  loc, maskOp, ptr, tileId, tileSliceI32);
125  }
126  }
127 }
128 
129 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
130  auto tileId = op.getTileId();
131  if (!tileId)
132  op.emitOpError(
133  "expected tile ID to be allocated before conversion to LLVM");
134  return tileId;
135 }
136 
137 /// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
138 /// placed in the first block of the function.
139 static memref::AllocaOp
140 createAllocaForTile(RewriterBase &rewriter, Location loc,
141  FunctionOpInterface func,
142  arm_sme::ArmSMETileOpInterface tileOp) {
143  RewriterBase::InsertionGuard g(rewriter);
144  // Move to the first operation in the function.
145  rewriter.setInsertionPointToStart(&func.getBlocks().front());
146  // Create an alloca matching the tile size of the `tileOp`.
147  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
148  auto tileElementType = tileOp.getTileType().getElementType();
149  auto memrefType = MemRefType::get(
150  {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
151  unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
152  auto minElementsOp =
153  rewriter.create<arith::ConstantIndexOp>(loc, minElements);
154  auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
155  auto alloca = rewriter.create<memref::AllocaOp>(
156  loc, memrefType, ValueRange{vectorLen, vectorLen});
157  return alloca;
158 }
159 
160 /// Finds or creates an alloca for a spill of a tile.
161 static memref::AllocaOp getOrCreateAllocaForTile(
162  RewriterBase &rewriter, Location loc, FunctionOpInterface func,
163  arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
164  // Find an alloca at the top of the function tagged with a
165  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
166  for (auto &op : func.getBlocks().front()) {
167  auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
168  if (!alloca)
169  continue;
170  auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
171  alloca->getDiscardableAttr(kInMemoryTileIdAttr));
172  if (!inMemoryTileId)
173  continue;
174  if (inMemoryTileId.getInt() == tileId)
175  return alloca;
176  }
177  // Otherwise, create a new alloca:
178  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
179  alloca->setDiscardableAttr(kInMemoryTileIdAttr,
180  rewriter.getI32IntegerAttr(tileId));
181  return alloca;
182 }
183 
184 /// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
185 /// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
186 /// the op to tile 0, then emitting a full tile swap between ZA and memory
187 /// before + after the tile op.
188 ///
189 /// Example:
190 ///
191 /// // Note: <IN MEMORY TILE> = tile ID >= 16.
192 /// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
193 ///
194 /// is converted to:
195 /// // At function entry:
196 /// %spill = memref.alloca ... : memref<?x?xty>
197 ///
198 /// // Around op:
199 /// scf.for %slice_idx {
200 /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
201 /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
202 /// vector.store %slice_to_save, %spill[%slice_idx, %c0]
203 /// }
204 /// arm_sme.tile_op { tile_id = 0 }
205 /// scf.for %slice_idx {
206 /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
207 /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
208 /// vector.store %slice_to_save, %spill[%slice_idx, %c0]
209 /// }
210 ///
/// Note that these spills/fills are not inserted earlier, as the concept of a
/// register (and the need to swap its contents) can't really be represented
/// correctly at a high level in MLIR.
214 ///
215 /// TODO: Reduce the spills/reloads to single slices where possible (and omit
216 /// redundant reloads). This could be done via a method on the
217 /// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
218 ///
219 /// `tileOp.getZaUsage()` could return:
220 ///
221 /// struct ArmSMEOpZAUsage {
222 /// enum class Kind {
223 /// TileRead, // Omit store after tile operation.
224 /// TileWrite, // Omit load before tile operation.
225 /// TileReadWrite, // Needs both tile load and store.
226 /// SliceRead, // Spill single slice and omit store after operation.
227 /// SliceWrite, // Spill single slice and omit load before operation.
228 /// SliceReadWrite // Spill single slice.
229 /// };
230 /// Value sliceIndex {};
231 /// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
232 /// };
233 ///
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {

  /// Note: constructed per root op *name* (rather than per op type) so the
  /// same pattern can be registered for any op implementing
  /// `ArmSMETileOpInterface` (see `addArmSMEConversionPattern`).
  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
                                    const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
                             typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
    // Tile has a real (hardware) tile. No spills/reloads required.
    if (!tileOp.isInMemoryTile())
      return failure();

    tileOp->emitWarning(
        "failed to allocate SME virtual tile to operation, tile value will go "
        "through memory, expect degraded performance");

    // Step 1. Create an alloca for the tile at the top of the function (if one
    // does not already exist).
    auto loc = tileOp.getLoc();
    auto func = tileOp->getParentOfType<FunctionOpInterface>();
    auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
                                               tileOp.getTileId().getInt());

    // Step 2. Assign the op a real tile ID.
    // For simplicity, we always use tile 0 (which always exists).
    auto zeroTileId = rewriter.getI32IntegerAttr(0);
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

    VectorType tileVectorType = tileOp.getTileType();
    // A slice type is the tile vector type minus its leading (row) dimension.
    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
    auto swapInMemoryTileWithSMETileZero = [&] {
      emitFullTileSwap(rewriter, loc, tileAlloca,
                       *arm_sme::getSMETileType(tileVectorType), sliceType,
                       zeroTileId);
    };

    // Step 3. Emit tile swaps before and after the op.
    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
    // touches (i.e. a single tile slice).
    {
      rewriter.setInsertionPoint(op);
      // Swap the contents of ZA and the in-memory tile before the op.
      swapInMemoryTileWithSMETileZero();
      rewriter.setInsertionPointAfter(op);
      // Swap the tile back out to memory again after the op.
      swapInMemoryTileWithSMETileZero();
    }

    return success();
  }

  /// Extracts a pointer to a slice of an in-memory tile.
  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
                                Value tileMemory, Value sliceIndex) const {
    // Cast the memref value to its converted LLVM descriptor type so the
    // strided element pointer can be computed on the lowered representation.
    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
    auto descriptor =
        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(), sliceIndex);
    // Index [sliceIndex, 0], i.e. the start of row `sliceIndex`.
    return getStridedElementPtr(
        loc, llvm::cast<MemRefType>(tileMemory.getType()),
        descriptor.getResult(0), {sliceIndexI64, zero},
        static_cast<ConversionPatternRewriter &>(rewriter));
  }

  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
  /// tile-sized memref (`tileAlloca`).
  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
                     IntegerAttr tileId, Value sliceIndex) const {
    // Cast the slice index to an i32.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);
    // Create an all-true predicate for the slice.
    auto predicateType = sliceType.clone(rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));
    // Create padding vector (never used due to all-true predicate).
    auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
    // Get a pointer to the current slice.
    auto slicePtr =
        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
    // Read the value of the current slice from ZA (before the load below
    // overwrites it — together these three ops form the swap).
    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
        loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
    // Load the new tile slice back from memory into ZA.
    createLoadTileSliceIntrinsic(
        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
        allTruePredicate, slicePtr, tileId, sliceIndexI32);
    // Store the current tile slice to memory.
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
                                     ValueRange{sliceIndex, zero});
  }

  /// Emits a full in-place swap of the contents of a tile in ZA and a
  /// tile-sized memref (`tileAlloca`).
  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
                        IntegerAttr tileId) const {
    RewriterBase::InsertionGuard guard(rewriter);
    // Create an scf.for over all tile slices.
    // Upper bound = minNumElts * vscale, the runtime number of slices.
    auto minNumElts =
        rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = rewriter.create<arith::MulIOp>(
        loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    // Emit a swap for each tile slice.
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto sliceIndex = forOp.getInductionVar();
    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
                  sliceIndex);
  }
};
355 
356 enum class RequiresSpillsAndFills { Yes, No };
357 
358 /// Base class for ArmSME to LLVM conversion patterns. By default, this adds
359 /// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
360 /// disabled by setting the `requiresSpillsAndFills` template parameter to
361 /// `RequiresSpillsAndFills::No`.
362 template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
363  RequiresSpillsAndFills::Yes>
364 struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
365  using ArmSMEOp = SourceOp;
367 
368  static constexpr bool requiresSpillsAndFillsConversion() {
369  return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
370  }
371 };
372 
/// Registers `Pattern`, plus (when applicable) a spill/fill pattern for the
/// same op with a much higher benefit, so in-memory tiles are handled first.
template <typename Pattern>
static void addArmSMEConversionPattern(RewritePatternSet &patterns,
                                       LLVMTypeConverter const &typeConverter) {
  // Register spills/fills for ops that implement the
  // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
  // `RequiresSpillsAndFills::Yes`.
  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
                std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
                                      typename Pattern::ArmSMEOp>,
                                  typename Pattern::ArmSMEOp>) {
    // Add spill/fill conversions with a very high benefit to ensure
    // they are lowered first.
    patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
        Pattern::ArmSMEOp::getOperationName(), typeConverter,
        /*benefit=*/1337);
  }
  patterns.add<Pattern>(typeConverter);
}
391 
392 /// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
/// Expands to one `addArmSMEConversionPattern<P>` call per pattern via a
/// comma-fold expression.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
399 
400 /// Lower 'arm_sme.zero' to SME intrinsics.
401 ///
402 /// BEFORE:
403 /// ```mlir
404 /// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
405 /// ```
406 ///
407 /// AFTER:
408 /// ```mlir
409 /// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
410 /// %v = arm_sme.get_tile : vector<[4]x[4]xi32>
411 /// ```
412 ///
413 /// The 'arm_sme.get_tile' (which models the return) will fold away once all
414 /// ArmSME ops have been converted to LLVM intrinsics.
415 struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
416  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
417 
418  LogicalResult
419  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
420  ConversionPatternRewriter &rewriter) const override {
421  auto loc = zero.getLoc();
422 
423  auto tileId = getTileIdOrError(zero);
424  if (!tileId)
425  return failure();
426 
427  // Get the base mask for tile based on the element size.
428  // The base mask is just the mask to zero the first tile (of a size).
429  // These masks are derived from:
430  // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
431  arm_sme::ArmSMETileType tileType =
432  *arm_sme::getSMETileType(zero.getTileType());
433  auto baseMaskForSize = [&] {
434  switch (tileType) {
435  case arm_sme::ArmSMETileType::ZAB:
436  // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
437  // 64-bit element tiles named ZA0.D to ZA7.D.
438  return 0b1111'1111;
439  case arm_sme::ArmSMETileType::ZAH:
440  // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
441  // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
442  // once for ZA1.H.
443  return 0b0101'0101;
444  case arm_sme::ArmSMETileType::ZAS:
445  // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
446  // element tiles named ZA0.D and ZA4.D.
447  // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
448  return 0b0001'0001;
449  case arm_sme::ArmSMETileType::ZAD:
450  // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
451  // setting the bit for that tile.
452  return 0b0000'0001;
453  default:
454  llvm_unreachable("bad element size");
455  }
456  }();
457 
458  // The actual mask is just the base mask shifted by the tile ID.
459  // This will be folded to a constant after tile allocation.
460  //
461  // The shift is just derived from the layout of the tiles, and that the tile
462  // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
463  // tiles:
464  //
465  // ZA0.S = ZA0.D and ZA4.D
466  // * Tile ID -> 0
467  // * Mask -> 00010001 = (00010001 << 0)
468  // ZA1.S = ZA1.D and ZA5.D
469  // * Tile ID -> 1
470  // * Mask -> 00100010 = (00010001 << 1)
471  // ZA2.S = ZA2.D and ZA6.D
472  // * Tile ID -> 2
473  // * Mask -> 01000100 = (00010001 << 2)
474  // ZA3.S = ZA3.D and ZA7.D
475  // * Tile ID -> 3
476  // * Mask -> 10001000 = (00010001 << 3)
477  //
478  // This holds for all tile sizes.
479  int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
480  rewriter.create<arm_sme::aarch64_sme_zero>(
481  loc, rewriter.getI32IntegerAttr(zeroMask));
482 
483  // Create a placeholder op to preserve dataflow.
484  rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
485 
486  return success();
487  }
488 };
489 
490 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
                  arm_sme::LoadTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = loadTileSliceOp.getLoc();
    auto tileId = getTileIdOrError(loadTileSliceOp);
    if (!tileId)
      return failure();

    // Compute the address of the slice's source element within the memref.
    Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
                                           adaptor.getBase(),
                                           adaptor.getIndices(), rewriter);

    auto tileSlice = loadTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // The op's predicate mask for the load.
    auto maskOp = loadTileSliceOp.getMask();

    auto tileVectorType = loadTileSliceOp.getVectorType();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
    arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();

    // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
    createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
                                 tileId, tileSliceI32);

    // The load intrinsics have no result, replace 'arm_sme.load_tile_slice'
    // with the input tile to preserve dataflow.
    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());

    return success();
  }
};
532 
533 /// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
struct StoreTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
                  arm_sme::StoreTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = storeTileSliceOp.getLoc();
    auto tileVectorType = storeTileSliceOp.getVectorType();

    auto tileId = getTileIdOrError(storeTileSliceOp);
    if (!tileId)
      return failure();

    // Create 'arm_sme.intr.st1*.(horiz|vert)' intrinsic to store ZA tile
    // slice. First compute the destination address within the memref.
    Value ptr = this->getStridedElementPtr(
        loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices(), rewriter);

    auto tileSlice = storeTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // The op's predicate mask for the store.
    auto maskOp = storeTileSliceOp.getMask();

    arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);

    rewriter.replaceOp(storeTileSliceOp,
                       createStoreTileSliceIntrinsic(rewriter, loc, tileType,
                                                     layout, maskOp, ptr,
                                                     tileId, tileSliceI32));

    return success();
  }
};
573 
574 /// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
struct MoveVectorToTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
                  arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = moveVectorToTileSliceOp.getLoc();
    auto tileType = moveVectorToTileSliceOp.getTileType();

    auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
    if (!tileId)
      return failure();

    auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();

    // Cast tile slice from index to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Create all active predicate mask: splat an i1 'true' across a scalable
    // vector as long as one tile dimension (the whole slice is written).
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI1Type(),
        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                  /*scalableDims=*/{true});
    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
    switch (moveVectorToTileSliceOp.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
          loc, tileId, tileSliceI32, allActiveMask,
          moveVectorToTileSliceOp.getVector());
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.create<arm_sme::aarch64_sme_write_vert>(
          loc, tileId, tileSliceI32, allActiveMask,
          moveVectorToTileSliceOp.getVector());
      break;
    }

    // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice'
    // with the input tile to preserve dataflow.
    rewriter.replaceOp(moveVectorToTileSliceOp,
                       moveVectorToTileSliceOp.getTile());

    return success();
  }
};
626 
627 /// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
struct MoveTileSliceToVectorConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
                  OpAdaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = moveTileSliceToVector.getLoc();
    auto sliceType = moveTileSliceToVector.getSliceType();
    auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();

    auto tileId = getTileIdOrError(moveTileSliceToVector);
    if (!tileId)
      return failure();

    // Create an 'all true' predicate for the tile slice.
    auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));

    // Zero destination/fallback for tile slice extraction (never selected,
    // since the predicate above is all-true).
    auto zeroVector = rewriter.create<arith::ConstantOp>(
        loc, sliceType, rewriter.getZeroAttr(sliceType));

    // Cast tile slice from index to i32 for intrinsic.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);

    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
    switch (moveTileSliceToVector.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
          tileId, sliceIndexI32);
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
          tileId, sliceIndexI32);
      break;
    }

    return success();
  }
};
674 
675 /// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
676 ///
677 /// Example:
678 ///
679 /// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
680 /// : vector<[4]xf32>, vector<[4]xf32>
681 ///
682 /// is converted to:
683 ///
684 /// "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
685 /// : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
686 /// vector<[4]xf32>) -> ()
687 ///
688 /// Currently only supports FMOPA and BFMOPA (non-widening).
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                  arm_sme::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(outerProductOp);
    if (!tileId)
      return failure();

    // Returns true for 2-D, fully-scalable FP vector types whose shape exactly
    // matches the minimum streaming vector length for the element size (i.e.
    // the shape of one SME tile).
    auto isSupportedType = [](VectorType vectorType) {
      // TODO: the FP outer product instruction variants are predicated on
      // different features [1]:
      //
      // * FMOPA (non-widening)
      //   * half-precision   - +sme2p1,+sme-f16f16
      //   * single-precision - +sme
      //   * double-precision - +sme-f64f64
      // * BFMOPA
      //   * half-precision   - +sme2p1,+b16b16
      //
      // It should be possible to control lowering based on target features.
      // [1]
      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
        return false;

      auto elementType = vectorType.getElementType();

      if (!elementType.isF16() && !elementType.isBF16() &&
          !elementType.isF32() && !elementType.isF64())
        return false;

      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      return vectorType.getShape() ==
             ArrayRef<int64_t>({minNumElts, minNumElts});
    };

    // TODO: Support CombiningKind::Sub for outer products.
    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
      return outerProductOp.emitError("unsupported kind");

    auto resultVectorType = outerProductOp.getResultType();
    if (!isSupportedType(resultVectorType))
      return outerProductOp.emitError("unsupported type");

    auto loc = outerProductOp.getLoc();

    Value acc = outerProductOp.getAcc();
    if (!acc) {
      // Initialize accumulator with zero. The zero op is given the same tile
      // ID so it zeroes the tile this outer product accumulates into.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = outerProductOp.getLhsMask();
    Value rhsMask = outerProductOp.getRhsMask();

    if (!lhsMask || !rhsMask) {
      // No masks supplied: predicate all lanes active on both operands.
      auto predTy =
          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    // Create 'arm_sme.intr.mopa' outer product intrinsic.
    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                               outerProductOp.getLhs(),
                                               outerProductOp.getRhs());

    // The outerproduct intrinsics have no result, replace
    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
    rewriter.replaceOp(outerProductOp, acc);

    return success();
  }
};
772 
773 /// Lower 2-way and 4-way widening outer products to intrinsics.
/// Lower 2-way and 4-way widening outer products to intrinsics.
/// `OuterProductWideningOp` is the ArmSME op being lowered and
/// `OuterProductWideningIntrOp` the intrinsic op that replaces it.
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();
    Value acc = op.getAcc();
    if (!acc) {
      // Initialize accumulator with zero (tagged with the same tile ID so it
      // zeroes the tile this op accumulates into).
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    if (!lhsMask || !rhsMask) {
      // No masks supplied: predicate all lanes active on both operands.
      auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    rewriter.create<OuterProductWideningIntrOp>(
        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());

    // The outer product intrinsics have no result, replace the widening outer
    // product op with the input tile to preserve dataflow.
    rewriter.replaceOp(op, acc);

    return success();
  }
};
817 
818 /// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
819 ///
820 /// Example:
821 ///
822 /// %0 = arm_sme.streaming_vl <half>
823 ///
824 /// is converted to:
825 ///
826 /// %cnt = "arm_sme.intr.cntsh"() : () -> i64
827 /// %0 = arith.index_cast %cnt : i64 to index
828 ///
829 struct StreamingVLOpConversion
830  : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
831  RequiresSpillsAndFills::No> {
832  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
833 
834  LogicalResult
835  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
836  arm_sme::StreamingVLOp::Adaptor adaptor,
837  ConversionPatternRewriter &rewriter) const override {
838  auto loc = streamingVlOp.getLoc();
839  auto i64Type = rewriter.getI64Type();
840  auto *intrOp = [&]() -> Operation * {
841  switch (streamingVlOp.getTypeSize()) {
842  case arm_sme::TypeSize::Byte:
843  return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
844  case arm_sme::TypeSize::Half:
845  return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
846  case arm_sme::TypeSize::Word:
847  return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
848  case arm_sme::TypeSize::Double:
849  return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
850  }
851  }();
852  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
853  streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
854  return success();
855  }
856 };
857 
858 } // namespace
859 
860 namespace {
861 
862 struct ConvertArmSMEToLLVMPass
863  : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
864  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
865  this->dumpTileLiveRanges = dumpTileLiveRanges;
866  }
867  void runOnOperation() override {
868  auto function = getOperation();
869 
870  if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
871  return signalPassFailure();
872 
874  RewritePatternSet patterns(&getContext());
875  LLVMTypeConverter converter(&getContext());
877  populateArmSMEToLLVMConversionPatterns(converter, patterns);
878 
879  if (failed(applyPartialConversion(function, target, std::move(patterns))))
880  signalPassFailure();
881 
882  // Walk the function and fail if there are unexpected operations on SME
883  // tile types after conversion.
884  function->walk([&](Operation *op) {
885  // These ops are legal post conversion, skip these.
886  if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
887  !op->isRegistered())
888  return;
889  auto isSMETileType = [](Type type) {
890  return arm_sme::isValidSMETileVectorType(type);
891  };
892  if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
893  llvm::any_of(op->getOperandTypes(), isSMETileType)) {
894  op->emitOpError("unexpected operation with SME tile type after "
895  "conversion to LLVM");
896  signalPassFailure();
897  }
898  });
899  }
900 };
901 
902 } // namespace
903 
905  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
906  target.addLegalOp<
907  arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
908  arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
909  arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
910  arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
911  arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
912  arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
913  arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
914  arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
915  arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
916  arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
917  arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
918  arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
919  arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
920  arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
921  arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
922  arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
923  arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
924  arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
925  arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
926  arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
927  arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
928  arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
929  arm_sme::aarch64_sme_cntsd>();
930  target.addLegalDialect<arith::ArithDialect,
931  /* The following are used to lower tile spills/fills */
932  vector::VectorDialect, scf::SCFDialect,
933  memref::MemRefDialect>();
934  // Pseudo operations. These cannot be code-generated but may exist in the
935  // input IR, or be generated during the conversion. They need to be eliminated
936  // before the final conversion to LLVM IR (and likely will be due to DCE).
937  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
938  UnrealizedConversionCastOp>();
939 }
940 
942  RewritePatternSet &patterns) {
943  converter.addConversion([&](VectorType type) -> std::optional<Type> {
944  // There's no LLVM type for SME tiles, but after lowering to intrinsics all
945  // SME vector types should be eliminated.
947  return type;
948  return std::nullopt;
949  });
950 
951  addArmSMEConversionPatterns<
952  LoadTileSliceConversion, MoveTileSliceToVectorConversion,
953  MoveVectorToTileSliceConversion, StoreTileSliceConversion,
954  StreamingVLOpConversion, OuterProductOpConversion,
955  OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
956  arm_sme::aarch64_sme_mopa_wide>,
957  OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
958  arm_sme::aarch64_sme_mops_wide>,
959  OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
960  arm_sme::aarch64_sme_smopa_za32>,
961  OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
962  arm_sme::aarch64_sme_smops_za32>,
963  OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
964  arm_sme::aarch64_sme_umopa_za32>,
965  OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
966  arm_sme::aarch64_sme_umops_za32>,
967  OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
968  arm_sme::aarch64_sme_smopa_wide>,
969  OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
970  arm_sme::aarch64_sme_smops_wide>,
971  OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
972  arm_sme::aarch64_sme_umopa_wide>,
973  OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
974  arm_sme::aarch64_sme_umops_wide>,
975  OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
976  arm_sme::aarch64_sme_sumopa_wide>,
977  OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
978  arm_sme::aarch64_sme_sumops_wide>,
979  OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
980  arm_sme::aarch64_sme_usmopa_wide>,
981  OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
982  arm_sme::aarch64_sme_usmops_wide>,
983  ZeroOpConversion>(patterns, converter);
984 }
985 
986 std::unique_ptr<Pass>
987 mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
988  return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
989 }
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:220
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
IntegerType getI64Type()
Definition: Builders.cpp:89
IntegerType getI32Type()
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:335
IntegerType getI1Type()
Definition: Builders.cpp:77
IndexType getIndexType()
Definition: Builders.cpp:75
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:41
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:415
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
Definition: Operation.h:129
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:311
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:336
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition: Utils.cpp:44
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:29
constexpr unsigned MinStreamingVectorLengthInBits
Definition: Utils.h:33
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges=false)
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.