MLIR 22.0.0git
ArmSMEToLLVM.cpp
Go to the documentation of this file.
1//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of ArmSME operations to LLVM intrinsics.
10//
11//===----------------------------------------------------------------------===//
12
14
25#include "mlir/Pass/Pass.h"
27#include "llvm/ADT/ScopeExit.h"
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
/// Name of the discardable attribute used to tag entry-block allocas that act
/// as the in-memory backing store for a spilled (in-memory) SME tile.
static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
39
40/// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
41static Operation *createLoadTileSliceIntrinsic(
42 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
43 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
44 IntegerAttr tileId, Value tileSliceI32) {
45 if (layout == arm_sme::TileSliceLayout::Horizontal) {
46 switch (type) {
47 case arm_sme::ArmSMETileType::ZAB:
48 return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr,
49 tileId, tileSliceI32);
50 case arm_sme::ArmSMETileType::ZAH:
51 return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr,
52 tileId, tileSliceI32);
53 case arm_sme::ArmSMETileType::ZAS:
54 return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr,
55 tileId, tileSliceI32);
56 case arm_sme::ArmSMETileType::ZAD:
57 return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr,
58 tileId, tileSliceI32);
59 case arm_sme::ArmSMETileType::ZAQ:
60 return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr,
61 tileId, tileSliceI32);
62 }
63 } else {
64 switch (type) {
65 case arm_sme::ArmSMETileType::ZAB:
66 return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr,
67 tileId, tileSliceI32);
68 case arm_sme::ArmSMETileType::ZAH:
69 return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr,
70 tileId, tileSliceI32);
71 case arm_sme::ArmSMETileType::ZAS:
72 return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr,
73 tileId, tileSliceI32);
74 case arm_sme::ArmSMETileType::ZAD:
75 return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr,
76 tileId, tileSliceI32);
77 case arm_sme::ArmSMETileType::ZAQ:
78 return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr,
79 tileId, tileSliceI32);
80 break;
81 }
82 }
83 llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
84}
85
86/// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
87static Operation *createStoreTileSliceIntrinsic(
88 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
89 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
90 IntegerAttr tileId, Value tileSliceI32) {
91 if (layout == arm_sme::TileSliceLayout::Horizontal) {
92 switch (type) {
93 case arm_sme::ArmSMETileType::ZAB:
94 return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr,
95 tileId, tileSliceI32);
96 case arm_sme::ArmSMETileType::ZAH:
97 return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr,
98 tileId, tileSliceI32);
99 case arm_sme::ArmSMETileType::ZAS:
100 return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr,
101 tileId, tileSliceI32);
102 case arm_sme::ArmSMETileType::ZAD:
103 return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr,
104 tileId, tileSliceI32);
105 case arm_sme::ArmSMETileType::ZAQ:
106 return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr,
107 tileId, tileSliceI32);
108 }
109 } else {
110 switch (type) {
111 case arm_sme::ArmSMETileType::ZAB:
112 return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr,
113 tileId, tileSliceI32);
114 case arm_sme::ArmSMETileType::ZAH:
115 return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr,
116 tileId, tileSliceI32);
117 case arm_sme::ArmSMETileType::ZAS:
118 return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr,
119 tileId, tileSliceI32);
120 case arm_sme::ArmSMETileType::ZAD:
121 return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr,
122 tileId, tileSliceI32);
123 case arm_sme::ArmSMETileType::ZAQ:
124 return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr,
125 tileId, tileSliceI32);
126 }
127 }
128 llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
129}
130
131IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
132 auto tileId = op.getTileId();
133 if (!tileId)
134 op.emitOpError(
135 "expected tile ID to be allocated before conversion to LLVM");
136 return tileId;
137}
138
139/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
140/// placed in the first block of the function.
141static memref::AllocaOp
142createAllocaForTile(RewriterBase &rewriter, Location loc,
143 FunctionOpInterface func,
146 // Move to the first operation in the function.
147 rewriter.setInsertionPointToStart(&func.getBlocks().front());
148 // Create an alloca matching the tile size of the `tileOp`.
149 auto vscale = vector::VectorScaleOp::create(rewriter, loc);
150 auto tileElementType = tileOp.getTileType().getElementType();
151 auto memrefType = MemRefType::get(
152 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
153 unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
154 auto minElementsOp =
155 arith::ConstantIndexOp::create(rewriter, loc, minElements);
156 auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp);
157 auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType,
158 ValueRange{vectorLen, vectorLen});
159 return alloca;
160}
161
162/// Finds or creates an alloca for a spill of a tile.
163static memref::AllocaOp getOrCreateAllocaForTile(
164 RewriterBase &rewriter, Location loc, FunctionOpInterface func,
165 arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
166 // Find an alloca at the top of the function tagged with a
167 // 'arm_sme.in_memory_tile_id' that matches `tileId`.
168 for (auto &op : func.getBlocks().front()) {
169 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
170 if (!alloca)
171 continue;
172 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
173 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
174 if (!inMemoryTileId)
175 continue;
176 if (inMemoryTileId.getInt() == tileId)
177 return alloca;
178 }
179 // Otherwise, create a new alloca:
180 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
181 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
182 rewriter.getI32IntegerAttr(tileId));
183 return alloca;
184}
185
186/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
187/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
188/// the op to tile 0, then emitting a full tile swap between ZA and memory
189/// before + after the tile op.
190///
191/// Example:
192///
193/// // Note: <IN MEMORY TILE> = tile ID >= 16.
194/// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
195///
196/// is converted to:
197/// // At function entry:
198/// %spill = memref.alloca ... : memref<?x?xty>
199///
200/// // Around op:
201/// scf.for %slice_idx {
202/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
203/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
204/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
205/// }
206/// arm_sme.tile_op { tile_id = 0 }
207/// scf.for %slice_idx {
208/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
209/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
210/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
211/// }
212///
/// Note that these spills/fills are not inserted earlier because the concept
/// of a register, and the need to swap its contents, can't really be
/// represented correctly at a high level in MLIR.
216///
217/// TODO: Reduce the spills/reloads to single slices where possible (and omit
218/// redundant reloads). This could be done via a method on the
219/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
220///
221/// `tileOp.getZaUsage()` could return:
222///
223/// struct ArmSMEOpZAUsage {
224/// enum class Kind {
225/// TileRead, // Omit store after tile operation.
226/// TileWrite, // Omit load before tile operation.
227/// TileReadWrite, // Needs both tile load and store.
228/// SliceRead, // Spill single slice and omit store after operation.
229/// SliceWrite, // Spill single slice and omit load before operation.
230/// SliceReadWrite // Spill single slice.
231/// };
232/// Value sliceIndex {};
233/// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
234/// };
235///
236struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
237
238 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
239 const LLVMTypeConverter &typeConverter,
240 PatternBenefit benefit)
241 : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
242 typeConverter, benefit) {}
243
244 LogicalResult
245 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
246 ConversionPatternRewriter &rewriter) const override {
247 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
248 // Tile has a real (hardware) tile. No spills/reloads required.
249 if (!tileOp.isInMemoryTile())
250 return failure();
251
252 tileOp->emitWarning(
253 "failed to allocate SME virtual tile to operation, tile value will go "
254 "through memory, expect degraded performance");
255
256 // Step 1. Create an alloca for the tile at the top of the function (if one
257 // does not already exist).
258 auto loc = tileOp.getLoc();
259 auto func = tileOp->getParentOfType<FunctionOpInterface>();
260 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
261 tileOp.getTileId().getInt());
262
263 // Step 2. Assign the op a real tile ID.
264 // For simplicity, we always use tile 0 (which always exists).
265 auto zeroTileId = rewriter.getI32IntegerAttr(0);
266 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
267
268 VectorType tileVectorType = tileOp.getTileType();
269 auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
270 auto swapInMemoryTileWithSMETileZero = [&] {
271 emitFullTileSwap(rewriter, loc, tileAlloca,
272 *arm_sme::getSMETileType(tileVectorType), sliceType,
273 zeroTileId);
274 };
275
276 // Step 3. Emit tile swaps before and after the op.
277 // TODO: Reduce the amount spilled to the amount of data the `tileOp`
278 // touches (i.e. a single tile slice).
279 {
280 rewriter.setInsertionPoint(op);
281 // Swap the contents of ZA and the in-memory tile before the op.
282 swapInMemoryTileWithSMETileZero();
283 rewriter.setInsertionPointAfter(op);
284 // Swap the tile back out to memory again after the op.
285 swapInMemoryTileWithSMETileZero();
286 }
287
288 return success();
289 }
290
291 /// Extracts a pointer to a slice of an in-memory tile.
292 Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
293 Value tileMemory, Value sliceIndex) const {
294 auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
295 auto descriptor =
296 UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory);
297 auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64);
298 auto sliceIndexI64 = arith::IndexCastOp::create(
299 rewriter, loc, rewriter.getI64Type(), sliceIndex);
301 static_cast<ConversionPatternRewriter &>(rewriter), loc,
302 llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
303 {sliceIndexI64, zero});
304 }
305
306 /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
307 /// tile-sized memref (`tileAlloca`).
308 void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
309 arm_sme::ArmSMETileType tileType, VectorType sliceType,
310 IntegerAttr tileId, Value sliceIndex) const {
311 // Cast the slice index to an i32.
312 auto sliceIndexI32 = arith::IndexCastOp::create(
313 rewriter, loc, rewriter.getI32Type(), sliceIndex);
314 // Create an all-true predicate for the slice.
315 auto predicateType = sliceType.clone(rewriter.getI1Type());
316 auto allTruePredicate = arith::ConstantOp::create(
317 rewriter, loc, DenseElementsAttr::get(predicateType, true));
318 // Create padding vector (never used due to all-true predicate).
319 auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType);
320 // Get a pointer to the current slice.
321 auto slicePtr =
322 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
323 // Read the value of the current slice from ZA.
324 auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(
325 rewriter, loc, sliceType, padVector, allTruePredicate, tileId,
326 sliceIndexI32);
327 // Load the new tile slice back from memory into ZA.
328 createLoadTileSliceIntrinsic(
329 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
330 allTruePredicate, slicePtr, tileId, sliceIndexI32);
331 // Store the current tile slice to memory.
332 auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
333 vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca,
334 ValueRange{sliceIndex, zero});
335 }
336
337 /// Emits a full in-place swap of the contents of a tile in ZA and a
338 /// tile-sized memref (`tileAlloca`).
339 void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
340 arm_sme::ArmSMETileType tileType, VectorType sliceType,
341 IntegerAttr tileId) const {
342 RewriterBase::InsertionGuard guard(rewriter);
343 // Create an scf.for over all tile slices.
344 auto minNumElts =
345 arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0));
346 auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
347 auto upperBound =
348 arith::MulIOp::create(rewriter, loc, minNumElts,
349 vector::VectorScaleOp::create(rewriter, loc));
350 auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
351 auto forOp =
352 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
353 // Emit a swap for each tile slice.
354 rewriter.setInsertionPointToStart(forOp.getBody());
355 auto sliceIndex = forOp.getInductionVar();
356 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
357 sliceIndex);
358 }
359};
360
/// Controls whether spill/fill support is also registered when adding a
/// conversion pattern for an op (see `addArmSMEConversionPattern`).
enum class RequiresSpillsAndFills { Yes, No };
362
/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  // Exposes the converted op type so registration helpers can query it.
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// Returns true if spill/fill handling must be registered alongside this
  /// pattern (queried at compile time by `addArmSMEConversionPattern`).
  static constexpr bool requiresSpillsAndFillsConversion() {
    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
  }
};
377
378template <typename Pattern>
379static void addArmSMEConversionPattern(RewritePatternSet &patterns,
380 LLVMTypeConverter const &typeConverter) {
381 // Register spills/fills for ops that implement the
382 // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
383 // `RequiresSpillsAndFills::Yes`.
384 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
386 typename Pattern::ArmSMEOp>,
387 typename Pattern::ArmSMEOp>) {
388 // Add spill/fill conversions with a very high benefit to ensure
389 // they are lowered first.
390 patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
391 Pattern::ArmSMEOp::getOperationName(), typeConverter,
392 /*benefit=*/1337);
393 }
394 patterns.add<Pattern>(typeConverter);
395}
396
397/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
398template <typename... Patterns>
399static void
400addArmSMEConversionPatterns(RewritePatternSet &patterns,
401 LLVMTypeConverter const &typeConverter) {
402 (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
403}
404
405/// Lower 'arm_sme.zero' to SME intrinsics.
406///
407/// BEFORE:
408/// ```mlir
409/// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
410/// ```
411///
412/// AFTER:
413/// ```mlir
414/// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
415/// %v = arm_sme.get_tile : vector<[4]x[4]xi32>
416/// ```
417///
418/// The 'arm_sme.get_tile' (which models the return) will fold away once all
419/// ArmSME ops have been converted to LLVM intrinsics.
420struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
421 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
422
423 LogicalResult
424 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter) const override {
426 auto loc = zero.getLoc();
427
428 auto tileId = getTileIdOrError(zero);
429 if (!tileId)
430 return failure();
431
432 // Get the base mask for tile based on the element size.
433 // The base mask is just the mask to zero the first tile (of a size).
434 // These masks are derived from:
435 // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
436 arm_sme::ArmSMETileType tileType =
437 *arm_sme::getSMETileType(zero.getTileType());
438 auto baseMaskForSize = [&] {
439 switch (tileType) {
440 case arm_sme::ArmSMETileType::ZAB:
441 // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
442 // 64-bit element tiles named ZA0.D to ZA7.D.
443 return 0b1111'1111;
444 case arm_sme::ArmSMETileType::ZAH:
445 // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
446 // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
447 // once for ZA1.H.
448 return 0b0101'0101;
449 case arm_sme::ArmSMETileType::ZAS:
450 // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
451 // element tiles named ZA0.D and ZA4.D.
452 // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
453 return 0b0001'0001;
454 case arm_sme::ArmSMETileType::ZAD:
455 // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
456 // setting the bit for that tile.
457 return 0b0000'0001;
458 default:
459 llvm_unreachable("bad element size");
460 }
461 }();
462
463 // The actual mask is just the base mask shifted by the tile ID.
464 // This will be folded to a constant after tile allocation.
465 //
466 // The shift is just derived from the layout of the tiles, and that the tile
467 // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
468 // tiles:
469 //
470 // ZA0.S = ZA0.D and ZA4.D
471 // * Tile ID -> 0
472 // * Mask -> 00010001 = (00010001 << 0)
473 // ZA1.S = ZA1.D and ZA5.D
474 // * Tile ID -> 1
475 // * Mask -> 00100010 = (00010001 << 1)
476 // ZA2.S = ZA2.D and ZA6.D
477 // * Tile ID -> 2
478 // * Mask -> 01000100 = (00010001 << 2)
479 // ZA3.S = ZA3.D and ZA7.D
480 // * Tile ID -> 3
481 // * Mask -> 10001000 = (00010001 << 3)
482 //
483 // This holds for all tile sizes.
484 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
485 arm_sme::aarch64_sme_zero::create(rewriter, loc,
486 rewriter.getI32IntegerAttr(zeroMask));
487
488 // Create a placeholder op to preserve dataflow.
489 // Note: Place the `get_tile` op at the start of the block. This ensures
490 // that if there are multiple `zero` ops the intrinsics will be consecutive.
491 rewriter.setInsertionPointToStart(zero->getBlock());
492 rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
493
494 return success();
495 }
496};
497
498/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
499struct LoadTileSliceConversion
500 : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
501 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
502
503 LogicalResult
504 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
505 arm_sme::LoadTileSliceOp::Adaptor adaptor,
506 ConversionPatternRewriter &rewriter) const override {
507 auto loc = loadTileSliceOp.getLoc();
508 auto tileId = getTileIdOrError(loadTileSliceOp);
509 if (!tileId)
510 return failure();
511
512 Value ptr = this->getStridedElementPtr(
513 rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
514 adaptor.getIndices());
515
516 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
517
518 // Cast tile slice to i32 for intrinsic.
519 auto tileSliceI32 = arith::IndexCastUIOp::create(
520 rewriter, loc, rewriter.getI32Type(), tileSlice);
521
522 // Create all active predicate mask.
523 auto maskOp = loadTileSliceOp.getMask();
524
525 auto tileVectorType = loadTileSliceOp.getVectorType();
526 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
527 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
528
529 // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
530 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
531 tileId, tileSliceI32);
532
533 // The load intrinsics have no result, replace 'arm_sme.tile_load' with
534 // the input tile to preserve dataflow.
535 rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
536
537 return success();
538 }
539};
540
541/// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
542struct StoreTileSliceConversion
543 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
544 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
545
546 LogicalResult
547 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
548 arm_sme::StoreTileSliceOp::Adaptor adaptor,
549 ConversionPatternRewriter &rewriter) const override {
550 auto loc = storeTileSliceOp.getLoc();
551 auto tileVectorType = storeTileSliceOp.getVectorType();
552
553 auto tileId = getTileIdOrError(storeTileSliceOp);
554 if (!tileId)
555 return failure();
556
557 // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
558 Value ptr = this->getStridedElementPtr(
559 rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
560 adaptor.getIndices());
561
562 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
563
564 // Cast tile slice to i32 for intrinsic.
565 auto tileSliceI32 = arith::IndexCastUIOp::create(
566 rewriter, loc, rewriter.getI32Type(), tileSlice);
567
568 auto maskOp = storeTileSliceOp.getMask();
569
570 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
571 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
572
573 rewriter.replaceOp(storeTileSliceOp,
574 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
575 layout, maskOp, ptr,
576 tileId, tileSliceI32));
577
578 return success();
579 }
580};
581
582/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
583struct InsertTileSliceConversion
584 : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
585 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
586
587 LogicalResult
588 matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
589 arm_sme::InsertTileSliceOp::Adaptor adaptor,
590 ConversionPatternRewriter &rewriter) const override {
591 auto loc = insertTileSliceOp.getLoc();
592 auto tileType = insertTileSliceOp.getTileType();
593
594 auto tileId = getTileIdOrError(insertTileSliceOp);
595 if (!tileId)
596 return failure();
597
598 auto tileSlice = insertTileSliceOp.getTileSliceIndex();
599
600 // Cast tile slice from index to i32 for intrinsic.
601 auto tileSliceI32 = arith::IndexCastUIOp::create(
602 rewriter, loc, rewriter.getI32Type(), tileSlice);
603
604 // Create all active predicate mask.
605 auto one = arith::ConstantOp::create(
606 rewriter, loc, rewriter.getI1Type(),
607 rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
608 auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
609 /*scalableDims=*/{true});
610 auto allActiveMask =
611 vector::BroadcastOp::create(rewriter, loc, predTy, one);
612
613 // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
614 switch (insertTileSliceOp.getLayout()) {
615 case arm_sme::TileSliceLayout::Horizontal:
616 arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId,
617 tileSliceI32, allActiveMask,
618 insertTileSliceOp.getVector());
619 break;
620 case arm_sme::TileSliceLayout::Vertical:
621 arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId,
622 tileSliceI32, allActiveMask,
623 insertTileSliceOp.getVector());
624 break;
625 }
626
627 // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
628 // the input tile to preserve dataflow.
629 rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
630
631 return success();
632 }
633};
634
635/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
636struct ExtractTileSliceConversion
637 : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
638 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
639
640 LogicalResult
641 matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
642 ConversionPatternRewriter &rewriter) const override {
643 auto loc = extractTileSlice.getLoc();
644 auto sliceType = extractTileSlice.getSliceType();
645 auto sliceIndex = extractTileSlice.getTileSliceIndex();
646
647 auto tileId = getTileIdOrError(extractTileSlice);
648 if (!tileId)
649 return failure();
650
651 // Create an 'all true' predicate for the tile slice.
652 auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
653 auto allTruePredicate = arith::ConstantOp::create(
654 rewriter, loc, DenseElementsAttr::get(predicateType, true));
655
656 // Zero destination/fallback for tile slice extraction.
657 auto zeroVector = arith::ConstantOp::create(
658 rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType));
659
660 // Cast tile slice from index to i32 for intrinsic.
661 auto sliceIndexI32 = arith::IndexCastOp::create(
662 rewriter, loc, rewriter.getI32Type(), sliceIndex);
663
664 // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
665 switch (extractTileSlice.getLayout()) {
666 case arm_sme::TileSliceLayout::Horizontal:
667 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
668 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
669 sliceIndexI32);
670 break;
671 case arm_sme::TileSliceLayout::Vertical:
672 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
673 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
674 sliceIndexI32);
675 break;
676 }
677
678 return success();
679 }
680};
681
/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                  arm_sme::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(outerProductOp);
    if (!tileId)
      return failure();

    // Returns true for 2-D all-scalable FP vectors whose shape exactly matches
    // one SME tile for the element size (minNumElts x minNumElts).
    auto isSupportedType = [](VectorType vectorType) {
      // TODO: the FP outer product instruction variants are predicated on
      // different features [1]:
      //
      // * FMOPA (non-widening)
      //   * half-precision   - +sme2p1,+sme-f16f16
      //   * single-precision - +sme
      //   * double-precision - +sme-f64f64
      // * BFMOPA
      //   * half-precision   - +sme2p1,+b16b16
      //
      // It should be possible to control lowering based on target features.
      // [1]
      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
        return false;

      auto elementType = vectorType.getElementType();

      if (!elementType.isF16() && !elementType.isBF16() &&
          !elementType.isF32() && !elementType.isF64())
        return false;

      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      return vectorType.getShape() ==
             ArrayRef<int64_t>({minNumElts, minNumElts});
    };

    // TODO: Support CombiningKind::Sub for outer products.
    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
      return outerProductOp.emitError("unsupported kind");

    auto resultVectorType = outerProductOp.getResultType();
    if (!isSupportedType(resultVectorType))
      return outerProductOp.emitError("unsupported type");

    auto loc = outerProductOp.getLoc();

    Value acc = outerProductOp.getAcc();
    if (!acc) {
      // Initialize accumulator with zero (MOPA accumulates into the tile).
      auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType);
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = outerProductOp.getLhsMask();
    Value rhsMask = outerProductOp.getRhsMask();

    // When unmasked, use an all-true predicate for both operands.
    if (!lhsMask || !rhsMask) {
      auto predTy =
          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    // Create 'arm_sme.intr.mopa' outer product intrinsic.
    arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask,
                                      outerProductOp.getLhs(),
                                      outerProductOp.getRhs());

    // The outerproduct intrinsics have no result, replace
    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
    rewriter.replaceOp(outerProductOp, acc);

    return success();
  }
};
779
/// Lower 2-way and 4-way widening outer products to intrinsics.
///
/// `OuterProductWideningOp` is the ArmSME op being lowered;
/// `OuterProductWideningIntrOp` is the matching intrinsic op to emit.
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();
    Value acc = op.getAcc();
    if (!acc) {
      // Initialize accumulator with zero (the intrinsic accumulates into the
      // tile).
      auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType());
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    // When unmasked, use an all-true predicate for both operands.
    if (!lhsMask || !rhsMask) {
      auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask,
                                       adaptor.getLhs(), adaptor.getRhs());

    // The outerproduct intrinsics have no result, replace
    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
    rewriter.replaceOp(op, acc);

    return success();
  }
};
824
825/// Lower `arm_sme.streaming_vl` to SME CNTSD intrinsic.
826///
827/// Example:
828///
829/// %0 = arm_sme.streaming_vl <half>
830///
831/// is converted to:
832///
833/// %cnt = "arm_sme.intr.cntsd"() : () -> i64
834/// %scale = arith.constant 4 : index
835/// %cntIndex = arith.index_cast %cnt : i64 to index
836/// %0 = arith.muli %cntIndex, %scale : index
837///
838struct StreamingVLOpConversion
839 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
840 RequiresSpillsAndFills::No> {
841 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
842
843 LogicalResult
844 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
845 arm_sme::StreamingVLOp::Adaptor adaptor,
846 ConversionPatternRewriter &rewriter) const override {
847 auto loc = streamingVlOp.getLoc();
848 auto i64Type = rewriter.getI64Type();
849 auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
850 auto cntsdIdx = arith::IndexCastOp::create(rewriter, loc,
851 rewriter.getIndexType(), cntsd);
853 rewriter, loc,
854 8 / arm_sme::getSizeInBytes(streamingVlOp.getTypeSize()));
855 rewriter.replaceOpWithNewOp<arith::MulIOp>(streamingVlOp, cntsdIdx, scale);
856 return success();
857 }
858};
859
860/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
861/// or-ing the zero masks. Note: In future the backend _should_ handle this.
862static void mergeConsecutiveTileZerosInBlock(Block *block) {
863 uint32_t mergedZeroMask = 0;
865 auto replaceMergedZeroOps = [&] {
866 auto cleanup = llvm::make_scope_exit([&] {
867 mergedZeroMask = 0;
868 zeroOpsToMerge.clear();
869 });
870 if (zeroOpsToMerge.size() <= 1)
871 return;
872 IRRewriter rewriter(zeroOpsToMerge.front());
873 arm_sme::aarch64_sme_zero::create(
874 rewriter, zeroOpsToMerge.front().getLoc(),
875 rewriter.getI32IntegerAttr(mergedZeroMask));
876 for (auto zeroOp : zeroOpsToMerge)
877 rewriter.eraseOp(zeroOp);
878 };
879 for (Operation &op : *block) {
880 if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
881 mergedZeroMask |= zeroOp.getTileMask();
882 zeroOpsToMerge.push_back(zeroOp);
883 } else {
884 replaceMergedZeroOps();
885 }
886 }
887 replaceMergedZeroOps();
888}
889
890} // namespace
891
892namespace {
893
894struct ConvertArmSMEToLLVMPass
895 : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
896 ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
897 this->dumpTileLiveRanges = dumpTileLiveRanges;
899 void runOnOperation() override {
900 auto function = getOperation();
904
910
911 if (failed(applyPartialConversion(function, target, std::move(patterns))))
913
914 function->walk(mergeConsecutiveTileZerosInBlock);
915
916 // Walk the function and fail if there are unexpected operations on SME
917 // tile types after conversion.
918 function->walk([&](Operation *op) {
919 // These ops are legal post conversion, skip these.
920 if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
921 !op->isRegistered())
922 return;
923 auto isSMETileType = [](Type type) {
925 };
926 if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
927 llvm::any_of(op->getOperandTypes(), isSMETileType)) {
928 op->emitOpError("unexpected operation with SME tile type after "
929 "conversion to LLVM");
931 }
932 });
933 }
934};
935
936} // namespace
937
939 target.addIllegalDialect<arm_sme::ArmSMEDialect>();
940 target.addLegalOp<
941 arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
942 arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
943 arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
944 arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
945 arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
946 arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
947 arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
948 arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
949 arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
950 arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
951 arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
952 arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
953 arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
954 arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
955 arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
956 arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
957 arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
958 arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
959 arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
960 arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
961 arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsd>();
962 target.addLegalDialect<arith::ArithDialect,
963 /* The following are used to lower tile spills/fills */
964 vector::VectorDialect, scf::SCFDialect,
965 memref::MemRefDialect>();
966 // Pseudo operations. These cannot be code-generated but may exist in the
967 // input IR, or be generated during the conversion. They need to be eliminated
968 // before the final conversion to LLVM IR (and likely will be due to DCE).
969 target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
970 UnrealizedConversionCastOp>();
971}
972
975 converter.addConversion([&](VectorType type) -> std::optional<Type> {
976 // There's no LLVM type for SME tiles, but after lowering to intrinsics all
977 // SME vector types should be eliminated.
979 return type;
980 return std::nullopt;
981 });
982
983 addArmSMEConversionPatterns<
984 LoadTileSliceConversion, ExtractTileSliceConversion,
985 InsertTileSliceConversion, StoreTileSliceConversion,
986 StreamingVLOpConversion, OuterProductOpConversion,
987 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
988 arm_sme::aarch64_sme_mopa_wide>,
989 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
990 arm_sme::aarch64_sme_mops_wide>,
991 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
992 arm_sme::aarch64_sme_smopa_za32>,
993 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
994 arm_sme::aarch64_sme_smops_za32>,
995 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
996 arm_sme::aarch64_sme_umopa_za32>,
997 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
998 arm_sme::aarch64_sme_umops_za32>,
999 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1000 arm_sme::aarch64_sme_smopa_wide>,
1001 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1002 arm_sme::aarch64_sme_smops_wide>,
1003 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1004 arm_sme::aarch64_sme_umopa_wide>,
1005 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1006 arm_sme::aarch64_sme_umops_wide>,
1007 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1008 arm_sme::aarch64_sme_sumopa_wide>,
1009 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1010 arm_sme::aarch64_sme_sumops_wide>,
1011 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1012 arm_sme::aarch64_sme_usmopa_wide>,
1013 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1014 arm_sme::aarch64_sme_usmops_wide>,
1015 ZeroOpConversion>(patterns, converter);
1016}
1017
1018std::unique_ptr<Pass>
1019mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
1020 return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
1021}
return success()
b getContext())
Block represents an ordered list of Operations.
Definition Block.h:33
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getI32Type()
Definition Builders.cpp:63
IntegerType getI1Type()
Definition Builders.cpp:53
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:207
Base class for operation conversions targeting the LLVM IR dialect.
Definition Pattern.h:86
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
FunctionOpInterface getOperation()
Definition Pass.h:452
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
Definition Operation.h:129
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
void signalPassFailure()
Signal that some invariant was broken when running.
Definition Pass.h:225
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
VectorType getTileType()
Returns the VectorType of the tile used by this operation.
mlir::IntegerAttr getTileId()
Returns the tile ID assigned to this operation.
void setTileId(mlir::IntegerAttr tileId)
Sets the tile ID for this operation.
::mlir::Pass::Option< bool > dumpTileLiveRanges
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:471
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition Utils.cpp:58
LogicalResult allocateSMETiles(FunctionOpInterface function, bool dumpRanges=false)
Allocate tile IDs to all ArmSME operations in a function.
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition Utils.cpp:32
unsigned getSizeInBytes(TypeSize type)
Return the size represented by arm_sme::TypeSize in bytes.
Definition Utils.cpp:17
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition Utils.cpp:43
constexpr unsigned MinStreamingVectorLengthInBits
Definition Utils.h:33
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges=false)
Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target)
Configure target to convert from the ArmSME dialect to LLVM intrinsics.
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from the ArmSME dialect to LLVM intrinsics.