MLIR  19.0.0git
ArmSMEToLLVM.cpp
Go to the documentation of this file.
1 //===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to LLVM intrinsics.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
24 #include "mlir/Pass/Pass.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 
34 namespace {
35 
36 static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
37 
38 /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
39 static Operation *createLoadTileSliceIntrinsic(
40  RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
41  arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
42  IntegerAttr tileId, Value tileSliceI32) {
43  if (layout == arm_sme::TileSliceLayout::Horizontal) {
44  switch (type) {
45  case arm_sme::ArmSMETileType::ZAB:
46  return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
47  loc, maskOp, ptr, tileId, tileSliceI32);
48  case arm_sme::ArmSMETileType::ZAH:
49  return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
50  loc, maskOp, ptr, tileId, tileSliceI32);
51  case arm_sme::ArmSMETileType::ZAS:
52  return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
53  loc, maskOp, ptr, tileId, tileSliceI32);
54  case arm_sme::ArmSMETileType::ZAD:
55  return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
56  loc, maskOp, ptr, tileId, tileSliceI32);
57  case arm_sme::ArmSMETileType::ZAQ:
58  return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
59  loc, maskOp, ptr, tileId, tileSliceI32);
60  }
61  } else {
62  switch (type) {
63  case arm_sme::ArmSMETileType::ZAB:
64  return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
65  loc, maskOp, ptr, tileId, tileSliceI32);
66  case arm_sme::ArmSMETileType::ZAH:
67  return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
68  loc, maskOp, ptr, tileId, tileSliceI32);
69  case arm_sme::ArmSMETileType::ZAS:
70  return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
71  loc, maskOp, ptr, tileId, tileSliceI32);
72  case arm_sme::ArmSMETileType::ZAD:
73  return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
74  loc, maskOp, ptr, tileId, tileSliceI32);
75  case arm_sme::ArmSMETileType::ZAQ:
76  return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
77  loc, maskOp, ptr, tileId, tileSliceI32);
78  break;
79  }
80  }
81 }
82 
83 /// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
84 static Operation *createStoreTileSliceIntrinsic(
85  RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
86  arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
87  IntegerAttr tileId, Value tileSliceI32) {
88  if (layout == arm_sme::TileSliceLayout::Horizontal) {
89  switch (type) {
90  case arm_sme::ArmSMETileType::ZAB:
91  return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
92  loc, maskOp, ptr, tileId, tileSliceI32);
93  case arm_sme::ArmSMETileType::ZAH:
94  return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
95  loc, maskOp, ptr, tileId, tileSliceI32);
96  case arm_sme::ArmSMETileType::ZAS:
97  return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
98  loc, maskOp, ptr, tileId, tileSliceI32);
99  case arm_sme::ArmSMETileType::ZAD:
100  return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
101  loc, maskOp, ptr, tileId, tileSliceI32);
102  case arm_sme::ArmSMETileType::ZAQ:
103  return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
104  loc, maskOp, ptr, tileId, tileSliceI32);
105  }
106  } else {
107  switch (type) {
108  case arm_sme::ArmSMETileType::ZAB:
109  return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
110  loc, maskOp, ptr, tileId, tileSliceI32);
111  case arm_sme::ArmSMETileType::ZAH:
112  return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
113  loc, maskOp, ptr, tileId, tileSliceI32);
114  case arm_sme::ArmSMETileType::ZAS:
115  return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
116  loc, maskOp, ptr, tileId, tileSliceI32);
117  case arm_sme::ArmSMETileType::ZAD:
118  return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
119  loc, maskOp, ptr, tileId, tileSliceI32);
120  case arm_sme::ArmSMETileType::ZAQ:
121  return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
122  loc, maskOp, ptr, tileId, tileSliceI32);
123  }
124  }
125 }
126 
127 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
128  auto tileId = op.getTileId();
129  if (!tileId)
130  op.emitOpError(
131  "expected tile ID to be allocated before conversion to LLVM");
132  return tileId;
133 }
134 
135 /// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
136 /// placed in the first block of the function.
137 static memref::AllocaOp
138 createAllocaForTile(RewriterBase &rewriter, Location loc,
139  FunctionOpInterface func,
140  arm_sme::ArmSMETileOpInterface tileOp) {
141  RewriterBase::InsertionGuard g(rewriter);
142  // Move to the first operation in the function.
143  rewriter.setInsertionPointToStart(&func.getBlocks().front());
144  // Create an alloca matching the tile size of the `tileOp`.
145  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
146  auto tileElementType = tileOp.getTileType().getElementType();
147  auto memrefType = MemRefType::get(
148  {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
149  unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
150  auto minElementsOp =
151  rewriter.create<arith::ConstantIndexOp>(loc, minElements);
152  auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
153  auto alloca = rewriter.create<memref::AllocaOp>(
154  loc, memrefType, ValueRange{vectorLen, vectorLen});
155  return alloca;
156 }
157 
158 /// Finds or creates an alloca for a spill of a tile.
159 static memref::AllocaOp getOrCreateAllocaForTile(
160  RewriterBase &rewriter, Location loc, FunctionOpInterface func,
161  arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
162  // Find an alloca at the top of the function tagged with a
163  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
164  for (auto &op : func.getBlocks().front()) {
165  auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
166  if (!alloca)
167  continue;
168  auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
169  alloca->getDiscardableAttr(kInMemoryTileIdAttr));
170  if (!inMemoryTileId)
171  continue;
172  if (inMemoryTileId.getInt() == tileId)
173  return alloca;
174  }
175  // Otherwise, create a new alloca:
176  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
177  alloca->setDiscardableAttr(kInMemoryTileIdAttr,
178  rewriter.getI32IntegerAttr(tileId));
179  return alloca;
180 }
181 
182 /// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
183 /// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
184 /// the op to tile 0, then emitting a full tile swap between ZA and memory
185 /// before + after the tile op.
186 ///
187 /// Example:
188 ///
189 /// // Note: <IN MEMORY TILE> = tile ID >= 16.
190 /// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
191 ///
192 /// is converted to:
193 /// // At function entry:
194 /// %spill = memref.alloca ... : memref<?x?xty>
195 ///
196 /// // Around op:
197 /// scf.for %slice_idx {
198 /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
199 /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
200 /// vector.store %slice_to_save, %spill[%slice_idx, %c0]
201 /// }
202 /// arm_sme.tile_op { tile_id = 0 }
203 /// scf.for %slice_idx {
204 /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
205 /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
206 /// vector.store %slice_to_save, %spill[%slice_idx, %c0]
207 /// }
208 ///
209 /// Note that these spills/fills are not inserted earlier as concept of a
210 /// register, and the need to swap the contents, can't really be represented
211 /// correctly at a high level in MLIR.
212 ///
213 /// TODO: Reduce the spills/reloads to single slices where possible (and omit
214 /// redundant reloads). This could be done via a method on the
215 /// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
216 ///
217 /// `tileOp.getZaUsage()` could return:
218 ///
219 /// struct ArmSMEOpZAUsage {
220 /// enum class Kind {
221 /// TileRead, // Omit store after tile operation.
222 /// TileWrite, // Omit load before tile operation.
223 /// TileReadWrite, // Needs both tile load and store.
224 /// SliceRead, // Spill single slice and omit store after operation.
225 /// SliceWrite, // Spill single slice and omit load before operation.
226 /// SliceReadWrite // Spill single slice.
227 /// };
228 /// Value sliceIndex {};
229 /// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
230 /// };
231 ///
232 struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
233 
234  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
235  const LLVMTypeConverter &typeConverter,
236  PatternBenefit benefit)
237  : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
238  typeConverter, benefit) {}
239 
241  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
242  ConversionPatternRewriter &rewriter) const override {
243  auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
244  // Tile has a real (hardware) tile. No spills/reloads required.
245  if (!tileOp.isInMemoryTile())
246  return failure();
247 
248  // Step 1. Create an alloca for the tile at the top of the function (if one
249  // does not already exist).
250  auto loc = tileOp.getLoc();
251  auto func = tileOp->getParentOfType<FunctionOpInterface>();
252  auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
253  tileOp.getTileId().getInt());
254 
255  // Step 2. Assign the op a real tile ID.
256  // For simplicity, we always use tile 0 (which always exists).
257  auto zeroTileId = rewriter.getI32IntegerAttr(0);
258  rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
259 
260  VectorType tileVectorType = tileOp.getTileType();
261  auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
262  auto swapInMemoryTileWithSMETileZero = [&] {
263  emitFullTileSwap(rewriter, loc, tileAlloca,
264  *arm_sme::getSMETileType(tileVectorType), sliceType,
265  zeroTileId);
266  };
267 
268  // Step 3. Emit tile swaps before and after the op.
269  // TODO: Reduce the amount spilled to the amount of data the `tileOp`
270  // touches (i.e. a single tile slice).
271  {
272  rewriter.setInsertionPoint(op);
273  // Swap the contents of ZA and the in-memory tile before the op.
274  swapInMemoryTileWithSMETileZero();
275  rewriter.setInsertionPointAfter(op);
276  // Swap the tile back out to memory again after the op.
277  swapInMemoryTileWithSMETileZero();
278  }
279 
280  return success();
281  }
282 
283  /// Extracts a pointer to a slice of an in-memory tile.
284  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
285  Value tileMemory, Value sliceIndex) const {
286  auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
287  auto descriptor =
288  rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
289  auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
290  auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
291  loc, rewriter.getI64Type(), sliceIndex);
292  return getStridedElementPtr(
293  loc, llvm::cast<MemRefType>(tileMemory.getType()),
294  descriptor.getResult(0), {sliceIndexI64, zero},
295  static_cast<ConversionPatternRewriter &>(rewriter));
296  }
297 
298  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
299  /// tile-sized memref (`tileAlloca`).
300  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
301  arm_sme::ArmSMETileType tileType, VectorType sliceType,
302  IntegerAttr tileId, Value sliceIndex) const {
303  // Cast the slice index to an i32.
304  auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
305  loc, rewriter.getI32Type(), sliceIndex);
306  // Create an all-true predicate for the slice.
307  auto predicateType = sliceType.clone(rewriter.getI1Type());
308  auto allTruePredicate = rewriter.create<arith::ConstantOp>(
309  loc, DenseElementsAttr::get(predicateType, true));
310  // Create padding vector (never used due to all-true predicate).
311  auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
312  // Get a pointer to the current slice.
313  auto slicePtr =
314  getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
315  // Read the value of the current slice from ZA.
316  auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
317  loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
318  // Load the new tile slice back from memory into ZA.
319  createLoadTileSliceIntrinsic(
320  rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
321  allTruePredicate, slicePtr, tileId, sliceIndexI32);
322  // Store the current tile slice to memory.
323  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
324  rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
325  ValueRange{sliceIndex, zero});
326  }
327 
328  /// Emits a full in-place swap of the contents of a tile in ZA and a
329  /// tile-sized memref (`tileAlloca`).
330  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
331  arm_sme::ArmSMETileType tileType, VectorType sliceType,
332  IntegerAttr tileId) const {
333  RewriterBase::InsertionGuard guard(rewriter);
334  // Create an scf.for over all tile slices.
335  auto minNumElts =
336  rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
337  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
338  auto upperBound = rewriter.create<arith::MulIOp>(
339  loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
340  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
341  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
342  // Emit a swap for each tile slice.
343  rewriter.setInsertionPointToStart(forOp.getBody());
344  auto sliceIndex = forOp.getInductionVar();
345  emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
346  sliceIndex);
347  }
348 };
349 
350 enum class RequiresSpillsAndFills { Yes, No };
351 
352 /// Base class for ArmSME to LLVM conversion patterns. By default, this adds
353 /// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
354 /// disabled by setting the `requiresSpillsAndFills` template parameter to
355 /// `RequiresSpillsAndFills::No`.
356 template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
357  RequiresSpillsAndFills::Yes>
358 struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
359  using ArmSMEOp = SourceOp;
361 
362  static constexpr bool requiresSpillsAndFillsConversion() {
363  return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
364  }
365 };
366 
367 template <typename Pattern>
368 static void addArmSMEConversionPattern(RewritePatternSet &patterns,
369  LLVMTypeConverter const &typeConverter) {
370  // Register spills/fills for ops that implement the
371  // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
372  // `RequiresSpillsAndFills::Yes`.
373  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
374  std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
375  typename Pattern::ArmSMEOp>,
376  typename Pattern::ArmSMEOp>) {
377  // Add spill/fill conversions with a very high benefit to ensure
378  // they are lowered first.
379  patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
380  Pattern::ArmSMEOp::getOperationName(), typeConverter,
381  /*benefit=*/1337);
382  }
383  patterns.add<Pattern>(typeConverter);
384 }
385 
386 /// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
387 template <typename... Patterns>
388 static void
389 addArmSMEConversionPatterns(RewritePatternSet &patterns,
390  LLVMTypeConverter const &typeConverter) {
391  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
392 }
393 
394 struct GetTileConversion
395  : public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp,
396  RequiresSpillsAndFills::No> {
397  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
398 
400  matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
401  ConversionPatternRewriter &rewriter) const override {
402  rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
403  getTile, getTile.getTileType());
404  return success();
405  }
406 };
407 
408 /// Lower 'arm_sme.zero' to SME intrinsics.
409 ///
410 /// BEFORE:
411 /// ```mlir
412 /// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
413 /// ```
414 ///
415 /// AFTER:
416 /// ```mlir
417 /// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
418 /// %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
419 /// ```
420 ///
421 /// The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
422 /// once all ArmSME ops have been converted to LLVM intrinsics.
423 struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
424  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
425 
427  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
428  ConversionPatternRewriter &rewriter) const override {
429  auto loc = zero.getLoc();
430 
431  auto tileId = getTileIdOrError(zero);
432  if (!tileId)
433  return failure();
434 
435  // Get the base mask for tile based on the element size.
436  // The base mask is just the mask to zero the first tile (of a size).
437  // These masks are derived from:
438  // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
439  arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
440  auto baseMaskForSize = [&] {
441  switch (tileType) {
442  case arm_sme::ArmSMETileType::ZAB:
443  // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
444  // 64-bit element tiles named ZA0.D to ZA7.D.
445  return 0b1111'1111;
446  case arm_sme::ArmSMETileType::ZAH:
447  // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
448  // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
449  // once for ZA1.H.
450  return 0b0101'0101;
451  case arm_sme::ArmSMETileType::ZAS:
452  // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
453  // element tiles named ZA0.D and ZA4.D.
454  // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
455  return 0b0001'0001;
456  case arm_sme::ArmSMETileType::ZAD:
457  // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
458  // setting the bit for that tile.
459  return 0b0000'0001;
460  default:
461  llvm_unreachable("bad element size");
462  }
463  }();
464 
465  // The actual mask is just the base mask shifted by the tile ID.
466  // This will be folded to a constant after tile allocation.
467  //
468  // The shift is just derived from the layout of the tiles, and that the tile
469  // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
470  // tiles:
471  //
472  // ZA0.S = ZA0.D and ZA4.D
473  // * Tile ID -> 0
474  // * Mask -> 00010001 = (00010001 << 0)
475  // ZA1.S = ZA1.D and ZA5.D
476  // * Tile ID -> 1
477  // * Mask -> 00100010 = (00010001 << 1)
478  // ZA2.S = ZA2.D and ZA6.D
479  // * Tile ID -> 2
480  // * Mask -> 01000100 = (00010001 << 2)
481  // ZA3.S = ZA3.D and ZA7.D
482  // * Tile ID -> 3
483  // * Mask -> 10001000 = (00010001 << 3)
484  //
485  // This holds for all tile sizes.
486  int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
487  rewriter.create<arm_sme::aarch64_sme_zero>(
488  loc, rewriter.getI32IntegerAttr(zeroMask));
489 
490  // Create a placeholder op to preserve dataflow.
491  rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
492  zero, zero.getVectorType());
493 
494  return success();
495  }
496 };
497 
498 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
499 struct LoadTileSliceConversion
500  : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
501  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
502 
504  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
505  arm_sme::LoadTileSliceOp::Adaptor adaptor,
506  ConversionPatternRewriter &rewriter) const override {
507  auto loc = loadTileSliceOp.getLoc();
508  auto tileId = getTileIdOrError(loadTileSliceOp);
509  if (!tileId)
510  return failure();
511 
512  Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
513  adaptor.getBase(),
514  adaptor.getIndices(), rewriter);
515 
516  auto tileSlice = loadTileSliceOp.getTileSliceIndex();
517 
518  // Cast tile slice to i32 for intrinsic.
519  auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
520  loc, rewriter.getI32Type(), tileSlice);
521 
522  // Create all active predicate mask.
523  auto maskOp = loadTileSliceOp.getMask();
524 
525  auto tileVectorType = loadTileSliceOp.getVectorType();
526  arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
527  arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
528 
529  // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
530  createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
531  tileId, tileSliceI32);
532 
533  // The load intrinsics have no result, replace 'arm_sme.tile_load' with
534  // the input tile to preserve dataflow.
535  rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
536 
537  return success();
538  }
539 };
540 
541 /// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
542 struct StoreTileSliceConversion
543  : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
544  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
545 
547  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
548  arm_sme::StoreTileSliceOp::Adaptor adaptor,
549  ConversionPatternRewriter &rewriter) const override {
550  auto loc = storeTileSliceOp.getLoc();
551  auto tileVectorType = storeTileSliceOp.getVectorType();
552 
553  auto tileId = getTileIdOrError(storeTileSliceOp);
554  if (!tileId)
555  return failure();
556 
557  // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
558  Value ptr = this->getStridedElementPtr(
559  loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
560  adaptor.getIndices(), rewriter);
561 
562  auto tileSlice = storeTileSliceOp.getTileSliceIndex();
563 
564  // Cast tile slice to i32 for intrinsic.
565  auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
566  loc, rewriter.getI32Type(), tileSlice);
567 
568  auto maskOp = storeTileSliceOp.getMask();
569 
570  arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
571  arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
572 
573  rewriter.replaceOp(storeTileSliceOp,
574  createStoreTileSliceIntrinsic(rewriter, loc, tileType,
575  layout, maskOp, ptr,
576  tileId, tileSliceI32));
577 
578  return success();
579  }
580 };
581 
582 /// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
583 struct MoveVectorToTileSliceConversion
584  : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
585  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
586 
588  matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
589  arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor,
590  ConversionPatternRewriter &rewriter) const override {
591  auto loc = moveVectorToTileSliceOp.getLoc();
592  auto tileType = moveVectorToTileSliceOp.getTileType();
593 
594  auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
595  if (!tileId)
596  return failure();
597 
598  auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
599 
600  // Cast tile slice from index to i32 for intrinsic.
601  auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
602  loc, rewriter.getI32Type(), tileSlice);
603 
604  // Create all active predicate mask.
605  auto one = rewriter.create<arith::ConstantOp>(
606  loc, rewriter.getI1Type(),
607  rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
608  auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
609  /*scalableDims=*/{true});
610  auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
611 
612  // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
613  switch (moveVectorToTileSliceOp.getLayout()) {
614  case arm_sme::TileSliceLayout::Horizontal:
615  rewriter.create<arm_sme::aarch64_sme_write_horiz>(
616  loc, tileId, tileSliceI32, allActiveMask,
617  moveVectorToTileSliceOp.getVector());
618  break;
619  case arm_sme::TileSliceLayout::Vertical:
620  rewriter.create<arm_sme::aarch64_sme_write_vert>(
621  loc, tileId, tileSliceI32, allActiveMask,
622  moveVectorToTileSliceOp.getVector());
623  break;
624  }
625 
626  // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
627  // the input tile to preserve dataflow.
628  rewriter.replaceOp(moveVectorToTileSliceOp,
629  moveVectorToTileSliceOp.getTile());
630 
631  return success();
632  }
633 };
634 
635 /// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
636 struct MoveTileSliceToVectorConversion
637  : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
638  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
639 
641  matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
642  OpAdaptor,
643  ConversionPatternRewriter &rewriter) const override {
644  auto loc = moveTileSliceToVector.getLoc();
645  auto sliceType = moveTileSliceToVector.getSliceType();
646  auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
647 
648  auto tileId = getTileIdOrError(moveTileSliceToVector);
649  if (!tileId)
650  return failure();
651 
652  // Create an 'all true' predicate for the tile slice.
653  auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
654  auto allTruePredicate = rewriter.create<arith::ConstantOp>(
655  loc, DenseElementsAttr::get(predicateType, true));
656 
657  // Zero destination/fallback for tile slice extraction.
658  auto zeroVector = rewriter.create<arith::ConstantOp>(
659  loc, sliceType, rewriter.getZeroAttr(sliceType));
660 
661  // Cast tile slice from index to i32 for intrinsic.
662  auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
663  loc, rewriter.getI32Type(), sliceIndex);
664 
665  // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
666  switch (moveTileSliceToVector.getLayout()) {
667  case arm_sme::TileSliceLayout::Horizontal:
668  rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
669  moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
670  tileId, sliceIndexI32);
671  break;
672  case arm_sme::TileSliceLayout::Vertical:
673  rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
674  moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
675  tileId, sliceIndexI32);
676  break;
677  }
678 
679  return success();
680  }
681 };
682 
683 /// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
684 ///
685 /// Example:
686 ///
687 /// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
688 /// : vector<[4]xf32>, vector<[4]xf32>
689 ///
690 /// is converted to:
691 ///
692 /// "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
693 /// : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
694 /// vector<[4]xf32>) -> ()
695 ///
696 /// Currently only supports FMOPA and BFMOPA (non-widening).
697 struct OuterProductOpConversion
698  : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
699  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
700 
702  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
703  arm_sme::OuterProductOp::Adaptor adaptor,
704  ConversionPatternRewriter &rewriter) const override {
705  auto tileId = getTileIdOrError(outerProductOp);
706  if (!tileId)
707  return failure();
708 
709  auto isSupportedType = [](VectorType vectorType) {
710  // TODO: the FP outer product instruction variants are predicated on
711  // different features [1]:
712  //
713  // * FMOPA (non-widening)
714  // * half-precision - +sme2p1,+sme-f16f16
715  // * single-precision - +sme
716  // * double-precision - +sme-f64f64
717  // * BFMOPA
718  // * half-precision - +sme2p1,+b16b16
719  //
720  // It should be possible to control lowering based on target features.
721  // [1]
722  // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
723  if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
724  return false;
725 
726  auto elementType = vectorType.getElementType();
727 
728  if (!elementType.isF16() && !elementType.isBF16() &&
729  !elementType.isF32() && !elementType.isF64())
730  return false;
731 
732  unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
733  vectorType.getElementTypeBitWidth();
734  return vectorType.getShape() ==
735  ArrayRef<int64_t>({minNumElts, minNumElts});
736  };
737 
738  // TODO: Support CombiningKind::Sub for outer products.
739  if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
740  return outerProductOp.emitError("unsupported kind");
741 
742  auto resultVectorType = outerProductOp.getResultType();
743  if (!isSupportedType(resultVectorType))
744  return outerProductOp.emitError("unsupported type");
745 
746  auto loc = outerProductOp.getLoc();
747 
748  Value acc = outerProductOp.getAcc();
749  if (!acc)
750  // Initalize accumulator with zero.
751  acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
752  rewriter, loc, resultVectorType);
753 
754  Value lhsMask = outerProductOp.getLhsMask();
755  Value rhsMask = outerProductOp.getRhsMask();
756 
757  if (!lhsMask || !rhsMask) {
758  auto predTy =
759  outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
760  Value allActiveMask = rewriter.create<arith::ConstantOp>(
761  loc, DenseElementsAttr::get(predTy, true));
762  lhsMask = allActiveMask;
763  rhsMask = allActiveMask;
764  }
765 
766  // Create 'arm_sme.intr.mopa' outer product intrinsic.
767  rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
768  outerProductOp.getLhs(),
769  outerProductOp.getRhs());
770 
771  // The outerproduct intrinsics have no result, replace
772  // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
773  rewriter.replaceOp(outerProductOp, acc);
774 
775  return success();
776  }
777 };
778 
779 /// Lower 2-way and 4-way widening outer products to intrinsics.
780 template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
781 struct OuterProductWideningOpConversion
782  : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
783  using ConvertArmSMEOpToLLVMPattern<
784  OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
785 
787  matchAndRewrite(OuterProductWideningOp op,
788  typename OuterProductWideningOp::Adaptor adaptor,
789  ConversionPatternRewriter &rewriter) const override {
790  auto tileId = getTileIdOrError(op);
791  if (!tileId)
792  return failure();
793 
794  Value acc = op.getAcc();
795  if (!acc)
796  // Initalize accumulator with zero.
797  acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>(
798  rewriter, op.getLoc(), op.getResultType());
799 
800  Value lhsMask = op.getLhsMask();
801  Value rhsMask = op.getRhsMask();
802  if (!lhsMask || !rhsMask) {
803  auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
804  Value allActiveMask = rewriter.create<arith::ConstantOp>(
805  op.getLoc(), DenseElementsAttr::get(predTy, true));
806  lhsMask = allActiveMask;
807  rhsMask = allActiveMask;
808  }
809 
810  rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask,
811  rhsMask, adaptor.getLhs(),
812  adaptor.getRhs());
813 
814  // The outerproduct intrinsics have no result, replace
815  // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
816  rewriter.replaceOp(op, acc);
817 
818  return success();
819  }
820 };
821 
822 /// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
823 ///
824 /// Example:
825 ///
826 /// %0 = arm_sme.streaming_vl <half>
827 ///
828 /// is converted to:
829 ///
830 /// %cnt = "arm_sme.intr.cntsh"() : () -> i64
831 /// %0 = arith.index_cast %cnt : i64 to index
832 ///
833 struct StreamingVLOpConversion
834  : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
835  RequiresSpillsAndFills::No> {
836  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
837 
839  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
840  arm_sme::StreamingVLOp::Adaptor adaptor,
841  ConversionPatternRewriter &rewriter) const override {
842  auto loc = streamingVlOp.getLoc();
843  auto i64Type = rewriter.getI64Type();
844  auto *intrOp = [&]() -> Operation * {
845  switch (streamingVlOp.getTypeSize()) {
846  case arm_sme::TypeSize::Byte:
847  return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
848  case arm_sme::TypeSize::Half:
849  return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
850  case arm_sme::TypeSize::Word:
851  return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
852  case arm_sme::TypeSize::Double:
853  return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
854  }
855  }();
856  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
857  streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
858  return success();
859  }
860 };
861 
862 } // namespace
863 
864 namespace {
865 
866 struct ConvertArmSMEToLLVMPass
867  : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
868  void runOnOperation() override {
870  RewritePatternSet patterns(&getContext());
871  LLVMTypeConverter converter(&getContext());
873  populateArmSMEToLLVMConversionPatterns(converter, patterns);
874 
875  if (failed(applyPartialConversion(getOperation(), target,
876  std::move(patterns))))
877  signalPassFailure();
878  }
879 };
880 
881 } // namespace
882 
884  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
885  target.addLegalOp<
886  arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
887  arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
888  arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
889  arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
890  arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
891  arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
892  arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
893  arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
894  arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
895  arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
896  arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
897  arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
898  arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
899  arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
900  arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide,
901  arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide,
902  arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide,
903  arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32,
904  arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32,
905  arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide,
906  arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide,
907  arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
908  arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
909  target.addLegalDialect<arith::ArithDialect,
910  /* The following are used to lower tile spills/fills */
911  vector::VectorDialect, scf::SCFDialect,
912  memref::MemRefDialect>();
913  target.addLegalOp<UnrealizedConversionCastOp>();
914 }
915 
917  RewritePatternSet &patterns) {
918  converter.addConversion([&](VectorType type) -> std::optional<Type> {
919  // There's no LLVM type for SME tiles, but after lowering to intrinsics all
920  // SME vector types should be eliminated.
922  return type;
923  return std::nullopt;
924  });
925 
926  addArmSMEConversionPatterns<
927  LoadTileSliceConversion, MoveTileSliceToVectorConversion,
928  MoveVectorToTileSliceConversion, StoreTileSliceConversion,
929  StreamingVLOpConversion, OuterProductOpConversion,
930  OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
931  arm_sme::aarch64_sme_mopa_wide>,
932  OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
933  arm_sme::aarch64_sme_mops_wide>,
934  OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
935  arm_sme::aarch64_sme_smopa_za32>,
936  OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
937  arm_sme::aarch64_sme_smops_za32>,
938  OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
939  arm_sme::aarch64_sme_umopa_za32>,
940  OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
941  arm_sme::aarch64_sme_umops_za32>,
942  OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
943  arm_sme::aarch64_sme_smopa_wide>,
944  OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
945  arm_sme::aarch64_sme_smops_wide>,
946  OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
947  arm_sme::aarch64_sme_umopa_wide>,
948  OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
949  arm_sme::aarch64_sme_umops_wide>,
950  OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
951  arm_sme::aarch64_sme_sumopa_wide>,
952  OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
953  arm_sme::aarch64_sme_sumops_wide>,
954  OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
955  arm_sme::aarch64_sme_usmopa_wide>,
956  OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
957  arm_sme::aarch64_sme_usmops_wide>,
958  ZeroOpConversion, GetTileConversion>(patterns, converter);
959 }
960 
961 std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
962  return std::make_unique<ConvertArmSMEToLLVMPass>();
963 }
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e. operations of this dialect are not supported by the target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
Base class for operation conversions targeting the LLVM IR dialect.
Definition: Pattern.h:41
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:414
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition: Operation.cpp:717
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Definition: PatternMatch.h:73
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
void addConversion(FnT &&callback)
Register a conversion function.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:330
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit within an SME tile.
Definition: Utils.cpp:44
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:29
constexpr unsigned MinStreamingVectorLengthInBits
Definition: Utils.h:31
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass()
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26