MLIR 22.0.0git
ArmSMEToLLVM.cpp
Go to the documentation of this file.
1//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of ArmSME operations to LLVM intrinsics.
10//
11//===----------------------------------------------------------------------===//
12
14
25#include "mlir/Pass/Pass.h"
27#include "llvm/ADT/ScopeExit.h"
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
38static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
39
40/// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
41static Operation *createLoadTileSliceIntrinsic(
42 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
43 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
44 IntegerAttr tileId, Value tileSliceI32) {
45 if (layout == arm_sme::TileSliceLayout::Horizontal) {
46 switch (type) {
47 case arm_sme::ArmSMETileType::ZAB:
48 return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr,
49 tileId, tileSliceI32);
50 case arm_sme::ArmSMETileType::ZAH:
51 return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr,
52 tileId, tileSliceI32);
53 case arm_sme::ArmSMETileType::ZAS:
54 return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr,
55 tileId, tileSliceI32);
56 case arm_sme::ArmSMETileType::ZAD:
57 return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr,
58 tileId, tileSliceI32);
59 case arm_sme::ArmSMETileType::ZAQ:
60 return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr,
61 tileId, tileSliceI32);
62 }
63 } else {
64 switch (type) {
65 case arm_sme::ArmSMETileType::ZAB:
66 return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr,
67 tileId, tileSliceI32);
68 case arm_sme::ArmSMETileType::ZAH:
69 return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr,
70 tileId, tileSliceI32);
71 case arm_sme::ArmSMETileType::ZAS:
72 return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr,
73 tileId, tileSliceI32);
74 case arm_sme::ArmSMETileType::ZAD:
75 return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr,
76 tileId, tileSliceI32);
77 case arm_sme::ArmSMETileType::ZAQ:
78 return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr,
79 tileId, tileSliceI32);
80 break;
81 }
82 }
83 llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
84}
85
86/// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
87static Operation *createStoreTileSliceIntrinsic(
88 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
89 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
90 IntegerAttr tileId, Value tileSliceI32) {
91 if (layout == arm_sme::TileSliceLayout::Horizontal) {
92 switch (type) {
93 case arm_sme::ArmSMETileType::ZAB:
94 return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr,
95 tileId, tileSliceI32);
96 case arm_sme::ArmSMETileType::ZAH:
97 return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr,
98 tileId, tileSliceI32);
99 case arm_sme::ArmSMETileType::ZAS:
100 return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr,
101 tileId, tileSliceI32);
102 case arm_sme::ArmSMETileType::ZAD:
103 return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr,
104 tileId, tileSliceI32);
105 case arm_sme::ArmSMETileType::ZAQ:
106 return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr,
107 tileId, tileSliceI32);
108 }
109 } else {
110 switch (type) {
111 case arm_sme::ArmSMETileType::ZAB:
112 return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr,
113 tileId, tileSliceI32);
114 case arm_sme::ArmSMETileType::ZAH:
115 return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr,
116 tileId, tileSliceI32);
117 case arm_sme::ArmSMETileType::ZAS:
118 return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr,
119 tileId, tileSliceI32);
120 case arm_sme::ArmSMETileType::ZAD:
121 return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr,
122 tileId, tileSliceI32);
123 case arm_sme::ArmSMETileType::ZAQ:
124 return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr,
125 tileId, tileSliceI32);
126 }
127 }
128 llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
129}
130
131IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
132 auto tileId = op.getTileId();
133 if (!tileId)
134 op.emitOpError(
135 "expected tile ID to be allocated before conversion to LLVM");
136 return tileId;
137}
138
139/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
140/// placed in the first block of the function.
141static memref::AllocaOp
142createAllocaForTile(RewriterBase &rewriter, Location loc,
143 FunctionOpInterface func,
144 arm_sme::ArmSMETileOpInterface tileOp) {
146 // Move to the first operation in the function.
147 rewriter.setInsertionPointToStart(&func.getBlocks().front());
148 // Create an alloca matching the tile size of the `tileOp`.
149 auto vscale = vector::VectorScaleOp::create(rewriter, loc);
150 auto tileElementType = tileOp.getTileType().getElementType();
151 auto memrefType = MemRefType::get(
152 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
153 unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
154 auto minElementsOp =
155 arith::ConstantIndexOp::create(rewriter, loc, minElements);
156 auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp);
157 auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType,
158 ValueRange{vectorLen, vectorLen});
159 return alloca;
160}
161
162/// Finds or creates an alloca for a spill of a tile.
163static memref::AllocaOp getOrCreateAllocaForTile(
164 RewriterBase &rewriter, Location loc, FunctionOpInterface func,
165 arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
166 // Find an alloca at the top of the function tagged with a
167 // 'arm_sme.in_memory_tile_id' that matches `tileId`.
168 for (auto &op : func.getBlocks().front()) {
169 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
170 if (!alloca)
171 continue;
172 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
173 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
174 if (!inMemoryTileId)
175 continue;
176 if (inMemoryTileId.getInt() == tileId)
177 return alloca;
178 }
179 // Otherwise, create a new alloca:
180 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
181 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
182 rewriter.getI32IntegerAttr(tileId));
183 return alloca;
184}
185
186/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
187/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
188/// the op to tile 0, then emitting a full tile swap between ZA and memory
189/// before + after the tile op.
190///
191/// Example:
192///
193/// // Note: <IN MEMORY TILE> = tile ID >= 16.
194/// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
195///
196/// is converted to:
197/// // At function entry:
198/// %spill = memref.alloca ... : memref<?x?xty>
199///
200/// // Around op:
201/// scf.for %slice_idx {
202/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
203/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
204/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
205/// }
206/// arm_sme.tile_op { tile_id = 0 }
207/// scf.for %slice_idx {
208/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
209/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
210/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
211/// }
212///
213/// Note that these spills/fills are not inserted earlier as concept of a
214/// register, and the need to swap the contents, can't really be represented
215/// correctly at a high level in MLIR.
216///
217/// TODO: Reduce the spills/reloads to single slices where possible (and omit
218/// redundant reloads). This could be done via a method on the
219/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
220///
221/// `tileOp.getZaUsage()` could return:
222///
223/// struct ArmSMEOpZAUsage {
224/// enum class Kind {
225/// TileRead, // Omit store after tile operation.
226/// TileWrite, // Omit load before tile operation.
227/// TileReadWrite, // Needs both tile load and store.
228/// SliceRead, // Spill single slice and omit store after operation.
229/// SliceWrite, // Spill single slice and omit load before operation.
230/// SliceReadWrite // Spill single slice.
231/// };
232/// Value sliceIndex {};
233/// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
234/// };
235///
236struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
237
238 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
239 const LLVMTypeConverter &typeConverter,
240 PatternBenefit benefit)
241 : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
242 typeConverter, benefit) {}
243
244 LogicalResult
245 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
246 ConversionPatternRewriter &rewriter) const override {
247 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
248 // Tile has a real (hardware) tile. No spills/reloads required.
249 if (!tileOp.isInMemoryTile())
250 return failure();
251
252 tileOp->emitWarning(
253 "failed to allocate SME virtual tile to operation, tile value will go "
254 "through memory, expect degraded performance");
255
256 // Step 1. Create an alloca for the tile at the top of the function (if one
257 // does not already exist).
258 auto loc = tileOp.getLoc();
259 auto func = tileOp->getParentOfType<FunctionOpInterface>();
260 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
261 tileOp.getTileId().getInt());
262
263 // Step 2. Assign the op a real tile ID.
264 // For simplicity, we always use tile 0 (which always exists).
265 auto zeroTileId = rewriter.getI32IntegerAttr(0);
266 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
267
268 VectorType tileVectorType = tileOp.getTileType();
269 auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
270 auto swapInMemoryTileWithSMETileZero = [&] {
271 emitFullTileSwap(rewriter, loc, tileAlloca,
272 *arm_sme::getSMETileType(tileVectorType), sliceType,
273 zeroTileId);
274 };
275
276 // Step 3. Emit tile swaps before and after the op.
277 // TODO: Reduce the amount spilled to the amount of data the `tileOp`
278 // touches (i.e. a single tile slice).
279 {
280 rewriter.setInsertionPoint(op);
281 // Swap the contents of ZA and the in-memory tile before the op.
282 swapInMemoryTileWithSMETileZero();
283 rewriter.setInsertionPointAfter(op);
284 // Swap the tile back out to memory again after the op.
285 swapInMemoryTileWithSMETileZero();
286 }
287
288 return success();
289 }
290
291 /// Extracts a pointer to a slice of an in-memory tile.
292 Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
293 Value tileMemory, Value sliceIndex) const {
294 auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
295 auto descriptor =
296 UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory);
297 auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64);
298 auto sliceIndexI64 = arith::IndexCastOp::create(
299 rewriter, loc, rewriter.getI64Type(), sliceIndex);
301 static_cast<ConversionPatternRewriter &>(rewriter), loc,
302 llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
303 {sliceIndexI64, zero});
304 }
305
306 /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
307 /// tile-sized memref (`tileAlloca`).
308 void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
309 arm_sme::ArmSMETileType tileType, VectorType sliceType,
310 IntegerAttr tileId, Value sliceIndex) const {
311 // Cast the slice index to an i32.
312 auto sliceIndexI32 = arith::IndexCastOp::create(
313 rewriter, loc, rewriter.getI32Type(), sliceIndex);
314 // Create an all-true predicate for the slice.
315 auto predicateType = sliceType.clone(rewriter.getI1Type());
316 auto allTruePredicate = arith::ConstantOp::create(
317 rewriter, loc, DenseElementsAttr::get(predicateType, true));
318 // Create padding vector (never used due to all-true predicate).
319 auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType);
320 // Get a pointer to the current slice.
321 auto slicePtr =
322 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
323 // Read the value of the current slice from ZA.
324 auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(
325 rewriter, loc, sliceType, padVector, allTruePredicate, tileId,
326 sliceIndexI32);
327 // Load the new tile slice back from memory into ZA.
328 createLoadTileSliceIntrinsic(
329 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
330 allTruePredicate, slicePtr, tileId, sliceIndexI32);
331 // Store the current tile slice to memory.
332 auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
333 vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca,
334 ValueRange{sliceIndex, zero});
335 }
336
337 /// Emits a full in-place swap of the contents of a tile in ZA and a
338 /// tile-sized memref (`tileAlloca`).
339 void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
340 arm_sme::ArmSMETileType tileType, VectorType sliceType,
341 IntegerAttr tileId) const {
342 RewriterBase::InsertionGuard guard(rewriter);
343 // Create an scf.for over all tile slices.
344 auto minNumElts =
345 arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0));
346 auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
347 auto upperBound =
348 arith::MulIOp::create(rewriter, loc, minNumElts,
349 vector::VectorScaleOp::create(rewriter, loc));
350 auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
351 auto forOp =
352 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
353 // Emit a swap for each tile slice.
354 rewriter.setInsertionPointToStart(forOp.getBody());
355 auto sliceIndex = forOp.getInductionVar();
356 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
357 sliceIndex);
358 }
359};
360
/// Whether a conversion pattern should also register the spills/fills
/// companion pattern (ConvertArmSMESpillsAndFillsToLLVM) for its root op.
enum class RequiresSpillsAndFills { Yes, No };
362
363/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
364/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
365/// disabled by setting the `requiresSpillsAndFills` template parameter to
366/// `RequiresSpillsAndFills::No`.
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  // Exposes the source op type so registration helpers (see
  // addArmSMEConversionPattern) can query it from the pattern type alone.
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  // Compile-time query: does this pattern want the spills/fills companion
  // pattern registered for its root op? Defaults to yes.
  static constexpr bool requiresSpillsAndFillsConversion() {
    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
  }
};
377
378template <typename Pattern>
379static void addArmSMEConversionPattern(RewritePatternSet &patterns,
380 LLVMTypeConverter const &typeConverter) {
381 // Register spills/fills for ops that implement the
382 // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
383 // `RequiresSpillsAndFills::Yes`.
384 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
385 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
386 typename Pattern::ArmSMEOp>,
387 typename Pattern::ArmSMEOp>) {
388 // Add spill/fill conversions with a very high benefit to ensure
389 // they are lowered first.
390 patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
391 Pattern::ArmSMEOp::getOperationName(), typeConverter,
392 /*benefit=*/1337);
393 }
394 patterns.add<Pattern>(typeConverter);
395}
396
397/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  // Fold expression: registers each pattern (plus, where applicable, its
  // spills/fills companion) in the order given.
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
404
405/// Lower 'arm_sme.zero' to SME intrinsics.
406///
407/// BEFORE:
408/// ```mlir
409/// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
410/// ```
411///
412/// AFTER:
413/// ```mlir
414/// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
415/// %v = arm_sme.get_tile : vector<[4]x[4]xi32>
416/// ```
417///
418/// The 'arm_sme.get_tile' (which models the return) will fold away once all
419/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = zero.getLoc();

    // Fails (with an emitted error) if tile allocation has not run yet.
    auto tileId = getTileIdOrError(zero);
    if (!tileId)
      return failure();

    // Get the base mask for tile based on the element size.
    // The base mask is just the mask to zero the first tile (of a size).
    // These masks are derived from:
    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
    arm_sme::ArmSMETileType tileType =
        *arm_sme::getSMETileType(zero.getTileType());
    auto baseMaskForSize = [&] {
      switch (tileType) {
      case arm_sme::ArmSMETileType::ZAB:
        // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
        // 64-bit element tiles named ZA0.D to ZA7.D.
        return 0b1111'1111;
      case arm_sme::ArmSMETileType::ZAH:
        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
        // once for ZA1.H.
        return 0b0101'0101;
      case arm_sme::ArmSMETileType::ZAS:
        // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D and ZA4.D.
        // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
        return 0b0001'0001;
      case arm_sme::ArmSMETileType::ZAD:
        // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
        // setting the bit for that tile.
        return 0b0000'0001;
      default:
        // NOTE(review): ZAQ (128-bit) tiles fall through to here — presumably
        // unsupported by this lowering; confirm against the ZeroOp verifier.
        llvm_unreachable("bad element size");
      }
    }();

    // The actual mask is just the base mask shifted by the tile ID.
    // This will be folded to a constant after tile allocation.
    //
    // The shift is just derived from the layout of the tiles, and that the tile
    // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
    // tiles:
    //
    // ZA0.S = ZA0.D and ZA4.D
    //  * Tile ID -> 0
    //  * Mask    -> 00010001 = (00010001 << 0)
    // ZA1.S = ZA1.D and ZA5.D
    //  * Tile ID -> 1
    //  * Mask    -> 00100010 = (00010001 << 1)
    // ZA2.S = ZA2.D and ZA6.D
    //  * Tile ID -> 2
    //  * Mask    -> 01000100 = (00010001 << 2)
    // ZA3.S = ZA3.D and ZA7.D
    //  * Tile ID -> 3
    //  * Mask    -> 10001000 = (00010001 << 3)
    //
    // This holds for all tile sizes.
    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
    arm_sme::aarch64_sme_zero::create(rewriter, loc,
                                      rewriter.getI32IntegerAttr(zeroMask));

    // Create a placeholder op to preserve dataflow.
    // Note: Place the `get_tile` op at the start of the block. This ensures
    // that if there are multiple `zero` ops the intrinsics will be consecutive.
    rewriter.setInsertionPointToStart(zero->getBlock());
    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());

    return success();
  }
};
497
498/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
499struct LoadTileSliceConversion
500 : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
501 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
502
503 LogicalResult
504 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
505 arm_sme::LoadTileSliceOp::Adaptor adaptor,
506 ConversionPatternRewriter &rewriter) const override {
507 auto loc = loadTileSliceOp.getLoc();
508 auto tileId = getTileIdOrError(loadTileSliceOp);
509 if (!tileId)
510 return failure();
511
512 Value ptr = this->getStridedElementPtr(
513 rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
514 adaptor.getIndices());
515
516 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
517
518 // Cast tile slice to i32 for intrinsic.
519 auto tileSliceI32 = arith::IndexCastUIOp::create(
520 rewriter, loc, rewriter.getI32Type(), tileSlice);
521
522 // Create all active predicate mask.
523 auto maskOp = loadTileSliceOp.getMask();
524
525 auto tileVectorType = loadTileSliceOp.getVectorType();
526 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
527 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
528
529 // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
530 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
531 tileId, tileSliceI32);
532
533 // The load intrinsics have no result, replace 'arm_sme.tile_load' with
534 // the input tile to preserve dataflow.
535 rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
536
537 return success();
538 }
539};
540
541/// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
542struct StoreTileSliceConversion
543 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
544 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
545
546 LogicalResult
547 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
548 arm_sme::StoreTileSliceOp::Adaptor adaptor,
549 ConversionPatternRewriter &rewriter) const override {
550 auto loc = storeTileSliceOp.getLoc();
551 auto tileVectorType = storeTileSliceOp.getVectorType();
552
553 auto tileId = getTileIdOrError(storeTileSliceOp);
554 if (!tileId)
555 return failure();
556
557 // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
558 Value ptr = this->getStridedElementPtr(
559 rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
560 adaptor.getIndices());
561
562 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
563
564 // Cast tile slice to i32 for intrinsic.
565 auto tileSliceI32 = arith::IndexCastUIOp::create(
566 rewriter, loc, rewriter.getI32Type(), tileSlice);
567
568 auto maskOp = storeTileSliceOp.getMask();
569
570 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
571 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
572
573 rewriter.replaceOp(storeTileSliceOp,
574 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
575 layout, maskOp, ptr,
576 tileId, tileSliceI32));
577
578 return success();
579 }
580};
581
582/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
583struct InsertTileSliceConversion
584 : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
585 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
586
587 LogicalResult
588 matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
589 arm_sme::InsertTileSliceOp::Adaptor adaptor,
590 ConversionPatternRewriter &rewriter) const override {
591 auto loc = insertTileSliceOp.getLoc();
592 auto tileType = insertTileSliceOp.getTileType();
593
594 auto tileId = getTileIdOrError(insertTileSliceOp);
595 if (!tileId)
596 return failure();
597
598 auto tileSlice = insertTileSliceOp.getTileSliceIndex();
599
600 // Cast tile slice from index to i32 for intrinsic.
601 auto tileSliceI32 = arith::IndexCastUIOp::create(
602 rewriter, loc, rewriter.getI32Type(), tileSlice);
603
604 // Create all active predicate mask.
605 auto one = arith::ConstantOp::create(
606 rewriter, loc, rewriter.getI1Type(),
607 rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
608 auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
609 /*scalableDims=*/{true});
610 auto allActiveMask =
611 vector::BroadcastOp::create(rewriter, loc, predTy, one);
612
613 // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
614 switch (insertTileSliceOp.getLayout()) {
615 case arm_sme::TileSliceLayout::Horizontal:
616 arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId,
617 tileSliceI32, allActiveMask,
618 insertTileSliceOp.getVector());
619 break;
620 case arm_sme::TileSliceLayout::Vertical:
621 arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId,
622 tileSliceI32, allActiveMask,
623 insertTileSliceOp.getVector());
624 break;
625 }
626
627 // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
628 // the input tile to preserve dataflow.
629 rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
630
631 return success();
632 }
633};
634
635/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
636struct ExtractTileSliceConversion
637 : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
638 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
639
640 LogicalResult
641 matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
642 ConversionPatternRewriter &rewriter) const override {
643 auto loc = extractTileSlice.getLoc();
644 auto sliceType = extractTileSlice.getSliceType();
645 auto sliceIndex = extractTileSlice.getTileSliceIndex();
646
647 auto tileId = getTileIdOrError(extractTileSlice);
648 if (!tileId)
649 return failure();
650
651 // Create an 'all true' predicate for the tile slice.
652 auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
653 auto allTruePredicate = arith::ConstantOp::create(
654 rewriter, loc, DenseElementsAttr::get(predicateType, true));
655
656 // Zero destination/fallback for tile slice extraction.
657 auto zeroVector = arith::ConstantOp::create(
658 rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType));
659
660 // Cast tile slice from index to i32 for intrinsic.
661 auto sliceIndexI32 = arith::IndexCastOp::create(
662 rewriter, loc, rewriter.getI32Type(), sliceIndex);
663
664 // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
665 switch (extractTileSlice.getLayout()) {
666 case arm_sme::TileSliceLayout::Horizontal:
667 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
668 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
669 sliceIndexI32);
670 break;
671 case arm_sme::TileSliceLayout::Vertical:
672 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
673 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
674 sliceIndexI32);
675 break;
676 }
677
678 return success();
679 }
680};
681
682/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
683///
684/// Example:
685///
686/// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
687/// : vector<[4]xf32>, vector<[4]xf32>
688///
689/// is converted to:
690///
691/// "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
692/// : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
693/// vector<[4]xf32>) -> ()
694///
695/// Currently only supports FMOPA and BFMOPA (non-widening).
696struct OuterProductOpConversion
697 : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
698 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
699
700 LogicalResult
701 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
702 arm_sme::OuterProductOp::Adaptor adaptor,
703 ConversionPatternRewriter &rewriter) const override {
704 auto tileId = getTileIdOrError(outerProductOp);
705 if (!tileId)
706 return failure();
707
708 auto isSupportedType = [](VectorType vectorType) {
709 // TODO: the FP outer product instruction variants are predicated on
710 // different features [1]:
711 //
712 // * FMOPA (non-widening)
713 // * half-precision - +sme2p1,+sme-f16f16
714 // * single-precision - +sme
715 // * double-precision - +sme-f64f64
716 // * BFMOPA
717 // * half-precision - +sme2p1,+b16b16
718 //
719 // It should be possible to control lowering based on target features.
720 // [1]
721 // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
722 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
723 return false;
724
725 auto elementType = vectorType.getElementType();
726
727 if (!elementType.isF16() && !elementType.isBF16() &&
728 !elementType.isF32() && !elementType.isF64())
729 return false;
730
731 unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
732 vectorType.getElementTypeBitWidth();
733 return vectorType.getShape() ==
734 ArrayRef<int64_t>({minNumElts, minNumElts});
735 };
736
737 // TODO: Support CombiningKind::Sub for outer products.
738 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
739 return outerProductOp.emitError("unsupported kind");
740
741 auto resultVectorType = outerProductOp.getResultType();
742 if (!isSupportedType(resultVectorType))
743 return outerProductOp.emitError("unsupported type");
744
745 auto loc = outerProductOp.getLoc();
746
747 Value acc = outerProductOp.getAcc();
748 if (!acc) {
749 // Initalize accumulator with zero.
750 auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType);
751 zero.setTileId(tileId);
752 acc = zero;
753 }
754
755 Value lhsMask = outerProductOp.getLhsMask();
756 Value rhsMask = outerProductOp.getRhsMask();
757
758 if (!lhsMask || !rhsMask) {
759 auto predTy =
760 outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
761 Value allActiveMask = arith::ConstantOp::create(
762 rewriter, loc, DenseElementsAttr::get(predTy, true));
763 lhsMask = allActiveMask;
764 rhsMask = allActiveMask;
765 }
766
767 // Create 'arm_sme.intr.mopa' outer product intrinsic.
768 arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask,
769 outerProductOp.getLhs(),
770 outerProductOp.getRhs());
771
772 // The outerproduct intrinsics have no result, replace
773 // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
774 rewriter.replaceOp(outerProductOp, acc);
775
776 return success();
777 }
778};
779
780/// Lower 2-way and 4-way widening outer products to intrinsics.
781template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
782struct OuterProductWideningOpConversion
783 : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
784 using ConvertArmSMEOpToLLVMPattern<
785 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
786
787 LogicalResult
788 matchAndRewrite(OuterProductWideningOp op,
789 typename OuterProductWideningOp::Adaptor adaptor,
790 ConversionPatternRewriter &rewriter) const override {
791 auto tileId = getTileIdOrError(op);
792 if (!tileId)
793 return failure();
794
795 auto loc = op.getLoc();
796 Value acc = op.getAcc();
797 if (!acc) {
798 // Initalize accumulator with zero.
799 auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType());
800 zero.setTileId(tileId);
801 acc = zero;
802 }
803
804 Value lhsMask = op.getLhsMask();
805 Value rhsMask = op.getRhsMask();
806 if (!lhsMask || !rhsMask) {
807 auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
808 Value allActiveMask = arith::ConstantOp::create(
809 rewriter, loc, DenseElementsAttr::get(predTy, true));
810 lhsMask = allActiveMask;
811 rhsMask = allActiveMask;
812 }
813
814 OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask,
815 adaptor.getLhs(), adaptor.getRhs());
816
817 // The outerproduct intrinsics have no result, replace
818 // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
819 rewriter.replaceOp(op, acc);
820
821 return success();
822 }
823};
824
825/// Lower `arm_sme.streaming_vl` to SME CNTSD intrinsic.
826///
827/// Example:
828///
829/// %0 = arm_sme.streaming_vl <half>
830///
831/// is converted to:
832///
833/// %cnt = "arm_sme.intr.cntsd"() : () -> i64
834/// %scale = arith.constant 4 : index
835/// %cntIndex = arith.index_cast %cnt : i64 to index
836/// %0 = arith.muli %cntIndex, %scale : index
837///
838struct StreamingVLOpConversion
839 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
840 RequiresSpillsAndFills::No> {
841 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
842
843 LogicalResult
844 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
845 arm_sme::StreamingVLOp::Adaptor adaptor,
846 ConversionPatternRewriter &rewriter) const override {
847 auto loc = streamingVlOp.getLoc();
848 auto i64Type = rewriter.getI64Type();
849 auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
850 auto cntsdIdx = arith::IndexCastOp::create(rewriter, loc,
851 rewriter.getIndexType(), cntsd);
853 rewriter, loc,
854 8 / arm_sme::getSizeInBytes(streamingVlOp.getTypeSize()));
855 rewriter.replaceOpWithNewOp<arith::MulIOp>(streamingVlOp, cntsdIdx, scale);
856 return success();
857 }
858};
859
860/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
861/// or-ing the zero masks. Note: In future the backend _should_ handle this.
862static void mergeConsecutiveTileZerosInBlock(Block *block) {
863 uint32_t mergedZeroMask = 0;
865 auto replaceMergedZeroOps = [&] {
866 auto cleanup = llvm::make_scope_exit([&] {
867 mergedZeroMask = 0;
868 zeroOpsToMerge.clear();
869 });
870 if (zeroOpsToMerge.size() <= 1)
871 return;
872 IRRewriter rewriter(zeroOpsToMerge.front());
873 arm_sme::aarch64_sme_zero::create(
874 rewriter, zeroOpsToMerge.front().getLoc(),
875 rewriter.getI32IntegerAttr(mergedZeroMask));
876 for (auto zeroOp : zeroOpsToMerge)
877 rewriter.eraseOp(zeroOp);
878 };
879 for (Operation &op : *block) {
880 if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
881 mergedZeroMask |= zeroOp.getTileMask();
882 zeroOpsToMerge.push_back(zeroOp);
883 } else {
884 replaceMergedZeroOps();
885 }
886 }
887 replaceMergedZeroOps();
888}
889
890} // namespace
891
892namespace {
893
894struct ConvertArmSMEToLLVMPass
895 : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
896 ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
897 this->dumpTileLiveRanges = dumpTileLiveRanges;
898 }
899 void runOnOperation() override {
900 auto function = getOperation();
901
902 if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
903 return signalPassFailure();
904
905 LLVMConversionTarget target(getContext());
906 RewritePatternSet patterns(&getContext());
907 LLVMTypeConverter converter(&getContext());
910
911 if (failed(applyPartialConversion(function, target, std::move(patterns))))
912 signalPassFailure();
913
914 function->walk(mergeConsecutiveTileZerosInBlock);
915
916 // Walk the function and fail if there are unexpected operations on SME
917 // tile types after conversion.
918 function->walk([&](Operation *op) {
919 // These ops are legal post conversion, skip these.
920 if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
921 !op->isRegistered())
922 return;
923 auto isSMETileType = [](Type type) {
925 };
926 if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
927 llvm::any_of(op->getOperandTypes(), isSMETileType)) {
928 op->emitOpError("unexpected operation with SME tile type after "
929 "conversion to LLVM");
930 signalPassFailure();
931 }
932 });
933 }
934};
935
936} // namespace
937
939 target.addIllegalDialect<arm_sme::ArmSMEDialect>();
940 target.addLegalOp<
941 arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
942 arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
943 arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
944 arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
945 arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
946 arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
947 arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
948 arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
949 arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
950 arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
951 arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
952 arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
953 arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
954 arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
955 arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
956 arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
957 arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
958 arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
959 arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
960 arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
961 arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsd>();
962 target.addLegalDialect<arith::ArithDialect,
963 /* The following are used to lower tile spills/fills */
964 vector::VectorDialect, scf::SCFDialect,
965 memref::MemRefDialect>();
966 // Pseudo operations. These cannot be code-generated but may exist in the
967 // input IR, or be generated during the conversion. They need to be eliminated
968 // before the final conversion to LLVM IR (and likely will be due to DCE).
969 target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
970 UnrealizedConversionCastOp>();
971}
972
975 converter.addConversion([&](VectorType type) -> std::optional<Type> {
976 // There's no LLVM type for SME tiles, but after lowering to intrinsics all
977 // SME vector types should be eliminated.
979 return type;
980 return std::nullopt;
981 });
982
983 addArmSMEConversionPatterns<
984 LoadTileSliceConversion, ExtractTileSliceConversion,
985 InsertTileSliceConversion, StoreTileSliceConversion,
986 StreamingVLOpConversion, OuterProductOpConversion,
987 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
988 arm_sme::aarch64_sme_mopa_wide>,
989 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
990 arm_sme::aarch64_sme_mops_wide>,
991 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
992 arm_sme::aarch64_sme_smopa_za32>,
993 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
994 arm_sme::aarch64_sme_smops_za32>,
995 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
996 arm_sme::aarch64_sme_umopa_za32>,
997 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
998 arm_sme::aarch64_sme_umops_za32>,
999 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1000 arm_sme::aarch64_sme_smopa_wide>,
1001 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1002 arm_sme::aarch64_sme_smops_wide>,
1003 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1004 arm_sme::aarch64_sme_umopa_wide>,
1005 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1006 arm_sme::aarch64_sme_umops_wide>,
1007 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1008 arm_sme::aarch64_sme_sumopa_wide>,
1009 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1010 arm_sme::aarch64_sme_sumops_wide>,
1011 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1012 arm_sme::aarch64_sme_usmopa_wide>,
1013 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1014 arm_sme::aarch64_sme_usmops_wide>,
1015 ZeroOpConversion>(patterns, converter);
1016}
1017
1018std::unique_ptr<Pass>
1019mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
1020 return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
1021}
return success()
b getContext())
Block represents an ordered list of Operations.
Definition Block.h:33
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getI32Type()
Definition Builders.cpp:63
IntegerType getI1Type()
Definition Builders.cpp:53
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:216
Base class for operation conversions targeting the LLVM IR dialect.
Definition Pattern.h:95
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
Definition Operation.h:129
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:471
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition Utils.cpp:58
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition Utils.cpp:32
unsigned getSizeInBytes(TypeSize type)
Return the size represented by arm_sme::TypeSize in bytes.
Definition Utils.cpp:17
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition Utils.cpp:43
constexpr unsigned MinStreamingVectorLengthInBits
Definition Utils.h:33
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges=false)
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.