MLIR 23.0.0git
AMDGPUToROCDL.cpp
Go to the documentation of this file.
1//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Matchers.h"
26#include "mlir/Pass/Pass.h"
27
29
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/AMDGPUAddrSpace.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/ErrorHandling.h"
35#include <cstdint>
36#include <optional>
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44using namespace mlir::amdgpu;
45
46// Define commonly used chipsets versions for convenience.
47constexpr Chipset kGfx908 = Chipset(9, 0, 8);
48constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
49constexpr Chipset kGfx942 = Chipset(9, 4, 2);
50constexpr Chipset kGfx950 = Chipset(9, 5, 0);
51constexpr Chipset kGfx1200 = Chipset(12, 0, 0);
52constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
53
54// Predicates mirroring the LLVM AMDGPU `HasDot{N}Insts` features that gate
55// the `v_dot*` instructions consumed by the `amdgpu.dot` lowering.
56static bool hasDot1Insts(const Chipset &chipset) {
57 if (chipset.majorVersion == 9)
58 return chipset >= Chipset(9, 0, 6);
59 if (chipset.majorVersion == 10) {
60 if (chipset.minorVersion == 1)
61 return chipset.steppingVersion == 1u || chipset.steppingVersion == 2u;
62 return chipset.minorVersion >= 3u;
63 }
64 return false;
65}
66
67static bool hasDot2Insts(const Chipset &chipset) {
68 return hasDot1Insts(chipset);
69}
70
71static bool hasDot7Insts(const Chipset &chipset) {
72 return chipset.majorVersion >= 11 || hasDot1Insts(chipset);
73}
74
75static bool hasDot8Insts(const Chipset &chipset) {
76 return chipset.majorVersion >= 11;
77}
78
79static bool hasDot9Insts(const Chipset &chipset) {
80 if (chipset.majorVersion == 11)
81 return true;
82 return chipset.majorVersion == 12 && chipset.minorVersion == 0;
83}
84
85static bool hasDot10Insts(const Chipset &chipset) {
86 if (chipset.majorVersion == 11)
87 return true;
88 if (chipset.majorVersion == 12)
89 return chipset.minorVersion == 0;
90 return hasDot1Insts(chipset);
91}
92
93static bool hasDot11Insts(const Chipset &chipset) {
94 if (chipset.majorVersion == 11)
95 return chipset.minorVersion == 7u;
96 return chipset.majorVersion == 12 && chipset.minorVersion == 0;
97}
98
99static bool hasDot12Insts(const Chipset &chipset) {
100 if (chipset == Chipset(9, 5, 0))
101 return true;
102 if (chipset.majorVersion == 11)
103 return true;
104 return chipset.majorVersion == 12 && chipset.minorVersion == 0;
105}
106
107/// Convert an unsigned number `val` to i32.
108static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
109 Location loc, Value val) {
110 IntegerType i32 = rewriter.getI32Type();
111 // Force check that `val` is of int type.
112 auto valTy = cast<IntegerType>(val.getType());
113 if (i32 == valTy)
114 return val;
115 return valTy.getWidth() > 32
116 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
117 : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
118}
119
120static Value createI32Constant(ConversionPatternRewriter &rewriter,
121 Location loc, int32_t value) {
122 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
123}
124
125/// Convert an unsigned number `val` to i64.
126static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
127 Location loc, Value val) {
128 IntegerType i64 = rewriter.getI64Type();
129 // Force check that `val` is of int type.
130 auto valTy = cast<IntegerType>(val.getType());
131 if (i64 == valTy)
132 return val;
133 return valTy.getWidth() > 64
134 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
135 : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
136}
137
138static Value createI64Constant(ConversionPatternRewriter &rewriter,
139 Location loc, int64_t value) {
140 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
141}
142
143/// Returns the linear index used to access an element in the memref.
144static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
145 Location loc, MemRefDescriptor &memRefDescriptor,
147 IntegerType i32 = rewriter.getI32Type();
148 Value index;
149 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
150 if (stride != 1) { // Skip if stride is 1.
151 Value strideValue =
152 ShapedType::isDynamic(stride)
153 ? convertUnsignedToI32(rewriter, loc,
154 memRefDescriptor.stride(rewriter, loc, i))
155 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
156 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
157 }
158 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
159 : increment;
160 }
161 return index ? index : createI32Constant(rewriter, loc, 0);
162}
163
164/// Compute the contents of the `num_records` field for a given memref
165/// descriptor - that is, the number of bytes that's one element past the
166/// greatest possible valid index into the memref.
167static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
168 MemRefType memrefType,
169 MemRefDescriptor &memrefDescriptor,
170 ArrayRef<int64_t> strides, int64_t elementByteWidth,
171 amdgpu::Chipset chipset, bool boundsCheck) {
172 if (chipset >= kGfx1250 && !boundsCheck) {
173 constexpr int64_t first45bits = (1ll << 45) - 1;
174 return createI64Constant(rewriter, loc, first45bits);
175 }
176 if (memrefType.hasStaticShape() &&
177 !llvm::any_of(strides, ShapedType::isDynamic)) {
178 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
179 ArrayRef<int64_t> shape = memrefType.getShape();
180 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
181 size = std::max(shape[i] * strides[i], size);
182 size = size * elementByteWidth;
183 return createI64Constant(rewriter, loc, size);
184 }
185 Value maxIndex;
186 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
187 Value size = memrefDescriptor.size(rewriter, loc, i);
188 Value stride = memrefDescriptor.stride(rewriter, loc, i);
189 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
190 maxIndex = maxIndex
191 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
192 : maxThisDim;
193 }
194 Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
195 Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
196 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
197}
198
199static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
200 Value basePointer, Value numRecords,
201 bool boundsCheck, amdgpu::Chipset chipset,
202 Value cacheSwizzleStride = nullptr,
203 unsigned addressSpace = 8) {
204 // The stride value is generally 0. However, on MI-300 and onward, you can
205 // enable a cache swizzling mode by setting bit 14 of the stride field
206 // and setting that stride to a cache stride.
207 Type i16 = rewriter.getI16Type();
208 Value stride;
209 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
210 Value cacheStrideZext =
211 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
212 Value swizzleBit = LLVM::ConstantOp::create(
213 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
214 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
215 /*isDisjoint=*/true);
216 } else {
217 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
218 rewriter.getI16IntegerAttr(0));
219 }
220
221 uint32_t flags = 0;
222 if (chipset >= kGfx1250) {
223 // Flag word:
224 // bit 0: swizzle
225 // bit 1: 0 means (total_offset + payload > numRecords)
226 // 1 means ((total_offset + payload >) numRecords) || ((offset +
227 // payload) > stride) only applied when swizzle_enable = 0. keep at
228 // zero.
229 // whether oob is done depends on numRecords.
230 // bits 2-3: Type (must be 0)
231 } else {
232 // Get the number of elements.
233 // Flag word:
234 // bits 0-11: dst sel, ignored by these intrinsics
235 // bits 12-14: data format (ignored, must be nonzero, 7=float)
236 // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
237 // bit 19: In nested heap (0 here)
238 // bit 20: Behavior on unmap (0 means "return 0 / ignore")
239 // bits 21-22: Index stride for swizzles (N/A)
240 // bit 23: Add thread ID (0)
241 // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
242 // bits 25-26: Reserved (0)
243 // bit 27: Buffer is non-volatile (CDNA only)
244 // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
245 // none, 3 = either swizzles or testing against offset field) RDNA only
246 // bits 30-31: Type (must be 0)
247 flags |= (7 << 12) | (4 << 15);
248 if (chipset.majorVersion >= 10) {
249 flags |= (1 << 24);
250 uint32_t oob = boundsCheck ? 3 : 2;
251 flags |= (oob << 28);
252 }
253 }
254 Value flagsConst = createI32Constant(rewriter, loc, flags);
255 Type rsrcType =
256 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
257 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
258 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
259 return resource;
260}
261
262namespace {
263struct FatRawBufferCastLowering
264 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
265 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
266 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
267 chipset(chipset) {}
268
269 Chipset chipset;
270
271 LogicalResult
272 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter) const override {
274 Location loc = op.getLoc();
275 Value memRef = adaptor.getSource();
276 Value unconvertedMemref = op.getSource();
277 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
278 MemRefDescriptor descriptor(memRef);
279
280 DataLayout dataLayout = DataLayout::closest(op);
281 int64_t elementByteWidth =
282 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
283
284 int64_t unusedOffset = 0;
285 SmallVector<int64_t, 5> strideVals;
286 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
287 return op.emitOpError("Can't lower non-stride-offset memrefs");
288
289 Value numRecords = adaptor.getValidBytes();
290 if (!numRecords)
291 numRecords =
292 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
293 elementByteWidth, chipset, adaptor.getBoundsCheck());
294
295 Value basePointer =
296 adaptor.getResetOffset()
297 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
298 memrefType)
299 : descriptor.alignedPtr(rewriter, loc);
300
301 Value offset = adaptor.getResetOffset()
302 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
303 rewriter.getIndexAttr(0))
304 : descriptor.offset(rewriter, loc);
305
306 bool hasSizes = memrefType.getRank() > 0;
307 // No need to unpack() and pack() all the individual sizes and strides,
308 // so we'll just extract the arrays.
309 Value sizes = hasSizes
310 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
312 : Value{};
313 Value strides =
314 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
316 : Value{};
317
318 Value fatPtr = makeBufferRsrc(
319 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
320 chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
321
322 Value result = MemRefDescriptor::poison(
323 rewriter, loc,
324 getTypeConverter()->convertType(op.getResult().getType()));
325 SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
326 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
327 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
329 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
331 if (hasSizes) {
332 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
334 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
336 }
337 rewriter.replaceOp(op, result);
338 return success();
339 }
340};
341
342/// Define lowering patterns for raw buffer ops
343template <typename GpuOp, typename Intrinsic>
344struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
345 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
346 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
347
348 Chipset chipset;
349 static constexpr uint32_t maxVectorOpWidth = 128;
350
351 LogicalResult
352 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
353 ConversionPatternRewriter &rewriter) const override {
354 Location loc = gpuOp.getLoc();
355 Value memref = adaptor.getMemref();
356 Value unconvertedMemref = gpuOp.getMemref();
357 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
358
359 if (chipset.majorVersion < 9)
360 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
361
362 Value storeData = adaptor.getODSOperands(0)[0];
363 if (storeData == memref) // no write component to this op
364 storeData = Value();
365 Type wantedDataType;
366 if (storeData)
367 wantedDataType = storeData.getType();
368 else
369 wantedDataType = gpuOp.getODSResults(0)[0].getType();
370
371 Value atomicCmpData = Value();
372 // Operand index 1 of a load is the indices, trying to read them can crash.
373 if (storeData) {
374 Value maybeCmpData = adaptor.getODSOperands(1)[0];
375 if (maybeCmpData != memref)
376 atomicCmpData = maybeCmpData;
377 }
378
379 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
380
381 Type i32 = rewriter.getI32Type();
382
383 // Get the type size in bytes.
384 DataLayout dataLayout = DataLayout::closest(gpuOp);
385 int64_t elementByteWidth =
386 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
387 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
388
389 // If we want to load a vector<NxT> with total size <= 32
390 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
391 // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
392 // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
393 // so bitcast any floats to integers.
394 Type llvmBufferValType = llvmWantedDataType;
395 if (atomicCmpData) {
396 if (auto floatType = dyn_cast<FloatType>(wantedDataType))
397 llvmBufferValType = this->getTypeConverter()->convertType(
398 rewriter.getIntegerType(floatType.getWidth()));
399 }
400 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
401 uint32_t vecLen = dataVector.getNumElements();
402 uint32_t elemBits =
403 dataLayout.getTypeSizeInBits(dataVector.getElementType());
404 uint32_t totalBits = elemBits * vecLen;
405 bool usePackedFp16 =
406 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
407 if (totalBits > maxVectorOpWidth)
408 return gpuOp.emitOpError(
409 "Total width of loads or stores must be no more than " +
410 Twine(maxVectorOpWidth) + " bits, but we call for " +
411 Twine(totalBits) +
412 " bits. This should've been caught in validation");
413 if (!usePackedFp16 && elemBits < 32) {
414 if (totalBits > 32) {
415 if (totalBits % 32 != 0)
416 return gpuOp.emitOpError("Load or store of more than 32-bits that "
417 "doesn't fit into words. Can't happen\n");
418 llvmBufferValType = this->typeConverter->convertType(
419 VectorType::get(totalBits / 32, i32));
420 } else {
421 llvmBufferValType = this->typeConverter->convertType(
422 rewriter.getIntegerType(totalBits));
423 }
424 }
425 }
426 if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
427 // Buffer intrinsics doesn't support 1-element vectors, cast them to
428 // scalars.
429 if (vecType.getNumElements() == 1)
430 llvmBufferValType = vecType.getElementType();
431 }
432
433 SmallVector<Value, 6> args;
434 if (storeData) {
435 if (llvmBufferValType != llvmWantedDataType) {
436 Value castForStore = LLVM::BitcastOp::create(
437 rewriter, loc, llvmBufferValType, storeData);
438 args.push_back(castForStore);
439 } else {
440 args.push_back(storeData);
441 }
442 }
443
444 if (atomicCmpData) {
445 if (llvmBufferValType != llvmWantedDataType) {
446 Value castForCmp = LLVM::BitcastOp::create(
447 rewriter, loc, llvmBufferValType, atomicCmpData);
448 args.push_back(castForCmp);
449 } else {
450 args.push_back(atomicCmpData);
451 }
452 }
453
454 // Construct buffer descriptor from memref, attributes
455 int64_t offset = 0;
456 SmallVector<int64_t, 5> strides;
457 if (failed(memrefType.getStridesAndOffset(strides, offset)))
458 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
459
460 MemRefDescriptor memrefDescriptor(memref);
461
462 Value ptr = memrefDescriptor.bufferPtr(
463 rewriter, loc, *this->getTypeConverter(), memrefType);
464 Value numRecords =
465 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
466 elementByteWidth, chipset, adaptor.getBoundsCheck());
467 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
468 adaptor.getBoundsCheck(), chipset);
469 args.push_back(resource);
470
471 // Indexing (voffset)
472 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
473 adaptor.getIndices(), strides);
474 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
475 indexOffset && *indexOffset > 0) {
476 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
477 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
478 extraOffsetConst)
479 : extraOffsetConst;
480 }
481 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
482 args.push_back(voffset);
483
484 // SGPR offset.
485 Value sgprOffset = adaptor.getSgprOffset();
486 if (!sgprOffset)
487 sgprOffset = createI32Constant(rewriter, loc, 0);
488 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
489 args.push_back(sgprOffset);
490
491 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
492 llvmBufferValType);
493 typename Intrinsic::Properties properties;
494 properties.aux = rewriter.getI32IntegerAttr(0);
495 Operation *lowered =
496 Intrinsic::create(rewriter, loc, resultTypes, args, properties);
497 if (lowered->getNumResults() == 1) {
498 Value replacement = lowered->getResult(0);
499 if (llvmBufferValType != llvmWantedDataType) {
500 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
502 }
503 rewriter.replaceOp(gpuOp, replacement);
504 } else {
505 rewriter.eraseOp(gpuOp);
506 }
507 return success();
508 }
509};
510
511// TODO: AMDGPU backend already have all this bitpacking logic, we should move
512// it to some common place.
513/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
514/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
515/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
516/// Vmcnt = Waitcnt[15:10] (gfx11)
517/// Expcnt = Waitcnt[6:4] (pre-gfx11)
518/// Expcnt = Waitcnt[2:0] (gfx11)
519/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
520/// Lgkmcnt = Waitcnt[13:8] (gfx10)
521/// Lgkmcnt = Waitcnt[9:4] (gfx11)
522static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
523 unsigned expcnt, unsigned lgkmcnt) {
524 if (chipset.majorVersion < 9) {
525 vmcnt = std::min(15u, vmcnt);
526 expcnt = std::min(7u, expcnt);
527 lgkmcnt = std::min(15u, lgkmcnt);
528 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
529 }
530 if (chipset.majorVersion == 9) {
531 vmcnt = std::min(63u, vmcnt);
532 expcnt = std::min(7u, expcnt);
533 lgkmcnt = std::min(15u, lgkmcnt);
534 unsigned lowBits = vmcnt & 0xF;
535 unsigned highBits = (vmcnt >> 4) << 14;
536 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
537 return lowBits | highBits | otherCnts;
538 }
539 if (chipset.majorVersion == 10) {
540 vmcnt = std::min(63u, vmcnt);
541 expcnt = std::min(7u, expcnt);
542 lgkmcnt = std::min(63u, lgkmcnt);
543 unsigned lowBits = vmcnt & 0xF;
544 unsigned highBits = (vmcnt >> 4) << 14;
545 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
546 return lowBits | highBits | otherCnts;
547 }
548 if (chipset.majorVersion == 11) {
549 vmcnt = std::min(63u, vmcnt);
550 expcnt = std::min(7u, expcnt);
551 lgkmcnt = std::min(63u, lgkmcnt);
552 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
553 }
554 return failure();
555}
556
557struct MemoryCounterWaitOpLowering
558 : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
559 MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
560 Chipset chipset)
561 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
562 chipset(chipset) {}
563
564 Chipset chipset;
565
566 LogicalResult
567 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter) const override {
569 if (chipset.majorVersion >= 12) {
570 Location loc = op.getLoc();
571 if (std::optional<int> ds = adaptor.getDs())
572 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
573
574 if (std::optional<int> load = adaptor.getLoad())
575 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
576
577 if (std::optional<int> store = adaptor.getStore())
578 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
579
580 if (std::optional<int> exp = adaptor.getExp())
581 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
582
583 if (std::optional<int> tensor = adaptor.getTensor())
584 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
585
586 rewriter.eraseOp(op);
587 return success();
588 }
589
590 if (adaptor.getTensor())
591 return op.emitOpError("unsupported chipset");
592
593 auto getVal = [](Attribute attr) -> unsigned {
594 if (attr)
595 return cast<IntegerAttr>(attr).getInt();
596
597 // This value will be clamped to the maximum value for the chipset.
598 return 1024;
599 };
600 unsigned ds = getVal(adaptor.getDsAttr());
601 unsigned exp = getVal(adaptor.getExpAttr());
602
603 unsigned vmcnt = 1024;
604 Attribute load = adaptor.getLoadAttr();
605 Attribute store = adaptor.getStoreAttr();
606 if (load && store) {
607 vmcnt = getVal(load) + getVal(store);
608 } else if (load) {
609 vmcnt = getVal(load);
610 } else if (store) {
611 vmcnt = getVal(store);
612 }
613
614 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
615 if (failed(waitcnt))
616 return op.emitOpError("unsupported chipset");
617
618 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
619 return success();
620 }
621};
622
623struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
624 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
625 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
626
627 Chipset chipset;
628
629 LogicalResult
630 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
631 ConversionPatternRewriter &rewriter) const override {
632 Location loc = op.getLoc();
633 // This ensures that waits on global memory aren't introduced on
634 // chips that don't have the BackOffBarrier feature enabled in LLVM.
635 bool requiresInlineAsm = chipset < kGfx90a;
636
637 Attribute mmra =
638 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
639 // Note: while there *is* a workgroup-one-as scope, this, when combined with
640 // the MMRA, will lead to the fence having no effect. This is because the
641 // codepaths for an atomic load or store will observe that a
642 // one-address-space atomic to LDS requires no synchronization because
643 // operations on LDS are totally ordered with respect to each other, and so
644 // will not emit the correct waitcnt operations that these fences are
645 // intended to produce. Therefore, we use a broader type of fence and rely
646 // on the MMRA to relax it to the semantics we want.
647 StringRef scope = "workgroup";
648
649 auto relFence = LLVM::FenceOp::create(rewriter, loc,
650 LLVM::AtomicOrdering::release, scope);
651 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
652 if (requiresInlineAsm) {
653 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
654 LLVM::AsmDialect::AD_ATT);
655 const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
656 const char *constraints = "";
657 LLVM::InlineAsmOp::create(
658 rewriter, loc,
659 /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
660 /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
661 /*is_align_stack=*/false, LLVM::TailCallKind::None,
662 /*asm_dialect=*/asmDialectAttr,
663 /*operand_attrs=*/ArrayAttr());
664 } else if (chipset.majorVersion < 12) {
665 ROCDL::SBarrierOp::create(rewriter, loc);
666 } else {
667 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
668 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
669 }
670
671 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
672 LLVM::AtomicOrdering::acquire, scope);
673 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
674 rewriter.replaceOp(op, acqFence);
675 return success();
676 }
677};
678
679struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
680 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
681 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
682
683 Chipset chipset;
684
685 LogicalResult
686 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
687 ConversionPatternRewriter &rewriter) const override {
688 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op, op.getOptsAttr());
689 return success();
690 }
691};
692
693} // namespace
694
695/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
696/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
697///
698/// Specifically:
699/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
700/// allows bf16. Newer MFMAs support bf16 types on operand, check
701/// IntrinsicsAMDGPU.td file for reference.
702/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
703/// instead, which is what the f8f6f4 intrinsics use.
704/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
705/// integer.
706///
707/// Note that the type of `input` has already been LLVM type converted:
708/// therefore 8-bit and smaller floats are represented as their corresponding
709/// `iN` integers.
710static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
711 Location loc, Value input,
712 bool allowBf16 = true) {
713 Type inputType = input.getType();
714 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
715 if (vectorType.getElementType().isBF16() && !allowBf16)
716 return LLVM::BitcastOp::create(
717 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
718 if (vectorType.getElementType().isInteger(8) &&
719 vectorType.getNumElements() <= 8)
720 return LLVM::BitcastOp::create(
721 rewriter, loc,
722 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
723 if (isa<IntegerType>(vectorType.getElementType()) &&
724 vectorType.getElementTypeBitWidth() <= 8) {
725 int64_t numWords = llvm::divideCeil(
726 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
727 32);
728 return LLVM::BitcastOp::create(
729 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
730 input);
731 }
732 }
733 return input;
734}
735
736/// Converts packed vector operands to the expected ROCDL types.
737static Value convertPackedVectorOperand(ConversionPatternRewriter &rewriter,
738 Location loc, Value input,
739 bool allowBf16 = true) {
740 Type inputType = input.getType();
741 auto vectorType = cast<VectorType>(inputType);
742 // bf16 -> i16 when not allowed (pre-gfx950).
743 if (vectorType.getElementType().isBF16() && !allowBf16)
744 return LLVM::BitcastOp::create(
745 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
746 // i8/fp8 vectors -> vector<Nxi32>.
747 if (isa<IntegerType>(vectorType.getElementType()) &&
748 vectorType.getElementTypeBitWidth() <= 8) {
749 int64_t numWords = llvm::divideCeil(
750 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
751 Type castType = (numWords > 1)
752 ? Type{VectorType::get(numWords, rewriter.getI32Type())}
753 : rewriter.getI32Type();
754 return LLVM::BitcastOp::create(rewriter, loc, castType, input);
755 }
756 return input;
757}
758
759/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
760/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
761///
762/// Specifically:
763/// 1. If `input` is a i8 value, zero extend it to i32
764/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
765///
766/// Note that the type of `input` has already been LLVM type converted:
767/// therefore 8-bit and smaller floats are represented as their corresponding
768/// `iN` integers.
769static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
770 Value input) {
771 return TypeSwitch<Type, Value>(input.getType())
772 .Case([&](IntegerType) {
773 // Handle scalar i8: zero extend to i32.
774 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
775 input);
776 })
777 .Case([&](VectorType vectorType) {
778 // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
779 int64_t numElements = vectorType.getNumElements();
780 assert((numElements == 4 || numElements == 8) &&
781 "scale operand must be a vector of length 4 or 8");
782 IntegerType outputType =
783 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
784 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
785 })
786 .DefaultUnreachable("unexpected input type for scale operand");
787}
788
789/// Maps f8 scale element types to WMMA scale format codes.
// NOTE(review): the line(s) carrying this function's name, parameter list and
// the head of the type-switch expression are not present in this listing;
// confirm against the original file before editing the code below.
// Visible mapping: Float8E8M0FNUType -> WMMAMatrixScaleFormat::e8,
// Float8E4M3FNType -> WMMAMatrixScaleFormat::e4m3, anything else ->
// std::nullopt.
790static std::optional<ROCDL::WMMAMatrixScaleFormat>
793 .Case([](Float8E8M0FNUType) { return ROCDL::WMMAMatrixScaleFormat::e8; })
794 .Case([](Float8E4M3FNType) { return ROCDL::WMMAMatrixScaleFormat::e4m3; })
795 .Default(std::nullopt);
796}
797
798/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
799/// and scale block size (16 or 32).
// NOTE(review): the line carrying this function's name and parameter list is
// not present in this listing; the body reads `m`, `n`, `k` and `isScale16`,
// which are presumably its parameters - confirm against the original file.
800static std::optional<StringRef>
// 16x16x128 tile: f8f6f4 intrinsics, scale16 or plain scale variant.
802 if (m == 16 && n == 16 && k == 128)
803 return isScale16
804 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
805 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
806
// 32x16x128 tile: f4 intrinsics, scale16 or plain scale variant.
807 if (m == 32 && n == 16 && k == 128)
808 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
809 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
810
// Every other (m, n, k) combination has no matching intrinsic here.
811 return std::nullopt;
812}
813
814/// Push an input operand. If it is a float type, nothing to do. If it is
815/// an integer type, then we need to also push its signdness (1 for signed, 0
816/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
817/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
818/// We also need to convert bfloat inputs to i16 to account for the bfloat
819/// intrinsics having been defined before the AMD backend supported bfloat. We
820/// similarly need to pack 8-bit float types into integers as if they were i8
821/// (which they are for the backend's purposes).
// NOTE(review): the line carrying this function's return type and name is not
// present in this listing; confirm against the original file.
//
// Appends exactly one value to `operands`. When the unconverted element type
// is an integer, a boolean attribute named `attrName` is also appended to
// `attrs`; its value is true for signed inputs.
823 ConversionPatternRewriter &rewriter, Location loc,
824 const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
825 Value mlirInput, SmallVectorImpl<Value> &operands,
826 SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
827 Type inputType = llvmInput.getType();
828 auto vectorType = dyn_cast<VectorType>(inputType);
// Non-vector inputs are forwarded unchanged.
829 if (!vectorType) {
830 operands.push_back(llvmInput);
831 return;
832 }
833 Type elemType = vectorType.getElementType();
// Vectors whose elements are wider than 8 bits are forwarded unchanged.
834 if (elemType.getIntOrFloatBitWidth() > 8) {
835 operands.push_back(llvmInput);
836 return;
837 }
838
839 // We need to check the type of the input before conversion to properly test
840 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
841 // fp8/int8 information is lost during the conversion process.
842 auto mlirInputType = cast<VectorType>(mlirInput.getType());
843 bool isInputInteger = mlirInputType.getElementType().isInteger();
844 if (isInputInteger) {
845 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
// NOTE(review): `elemType` is taken from the LLVM-converted input; if that
// conversion yields signless integers, neither branch below fires and
// `isUnsigned` is used as-is - confirm this is intended.
846 bool localIsUnsigned = isUnsigned;
847 if (elemType.isUnsignedInteger()) {
848 localIsUnsigned = true;
849 } else if (elemType.isSignedInteger()) {
850 localIsUnsigned = false;
851 }
852 attrs.push_back(
853 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
854 }
855
// Pack the whole vector into one integer (at most 32 bits) or into a vector
// of i32 words (more than 32 bits).
856 int64_t numBits =
857 vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
858 Type i32 = rewriter.getI32Type();
859 Type intrinsicInType = numBits <= 32
860 ? (Type)rewriter.getIntegerType(numBits)
861 : (Type)VectorType::get(numBits / 32, i32);
862 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
863 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
864 loc, llvmIntrinsicInType, llvmInput);
865 // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
866 // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
867 // Add in the zeros here.
868 if (numBits < 32)
869 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
870 operands.push_back(castInput);
871}
872
873/// Push the output operand. For many cases this is only pushing the output in
874/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
875/// since the same numbers of VGPRs is used, we need to decide if to store the
876/// result in the upper 16 bits of the VGPRs or in the lower part. To store the
877/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
878/// be stored in the upper part. The subwordOffset must not be set for gfx12,
879/// as the instructions have been changed to return fewer registers instead.
///
/// In this body: `output` is always appended to `operands`. For f16, bf16 and
/// i16 element types a boolean "opsel" attribute derived from `subwordOffset`
/// is appended to `attrs`; for i32 element types a boolean "clamp" attribute
/// carrying `clamp` is appended instead. Other element types add no attribute.
// NOTE(review): the end of the parameter list (presumably the `attrs`
// parameter) is absent from this listing. `loc` and `typeConverter` are not
// used in the visible body.
880static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
881 Location loc,
882 const TypeConverter *typeConverter,
883 Value output, int32_t subwordOffset,
884 bool clamp, SmallVectorImpl<Value> &operands,
886 Type inputType = output.getType();
// NOTE(review): the dyn_cast result is used on the next line without a null
// check, so a non-vector `output` is not handled here.
887 auto vectorType = dyn_cast<VectorType>(inputType);
888 Type elemType = vectorType.getElementType();
889 operands.push_back(output);
890 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
// Any non-zero subwordOffset becomes `true` through the int -> bool conversion.
891 attrs.push_back(
892 NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
893 } else if (elemType.isInteger(32)) {
894 attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
895 }
896}
897
898/// Return true if `type` is the E5M2 variant of an 8-bit float that is
899/// supported by the `_bf8` instructions on the given `chipset`.
900static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
901 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
902 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
903}
904
905/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
906/// supported by the `_fp8` instructions on the given `chipset`.
907static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
908 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
909 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
910}
911
912/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
913/// if one exists. This includes checking to ensure the intrinsic is supported
914/// on the architecture you are compiling for.
915static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
916 Chipset chipset) {
917 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
918 b = mfma.getBlocks();
919 Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
920 Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
921
922 if (sourceElem.isF32() && destElem.isF32()) {
923 if (mfma.getReducePrecision() && chipset >= kGfx942) {
924 if (m == 32 && n == 32 && k == 4 && b == 1)
925 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
926 if (m == 16 && n == 16 && k == 8 && b == 1)
927 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
928 }
929 if (m == 32 && n == 32 && k == 1 && b == 2)
930 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
931 if (m == 16 && n == 16 && k == 1 && b == 4)
932 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
933 if (m == 4 && n == 4 && k == 1 && b == 16)
934 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
935 if (m == 32 && n == 32 && k == 2 && b == 1)
936 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
937 if (m == 16 && n == 16 && k == 4 && b == 1)
938 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
939 }
940
941 if (sourceElem.isF16() && destElem.isF32()) {
942 if (chipset >= kGfx950) {
943 if (m == 32 && n == 32 && k == 16 && b == 1)
944 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
945 if (m == 16 && n == 16 && k == 32 && b == 1)
946 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
947 }
948 if (m == 32 && n == 32 && k == 4 && b == 2)
949 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
950 if (m == 16 && n == 16 && k == 4 && b == 4)
951 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
952 if (m == 4 && n == 4 && k == 4 && b == 16)
953 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
954 if (m == 32 && n == 32 && k == 8 && b == 1)
955 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
956 if (m == 16 && n == 16 && k == 16 && b == 1)
957 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
958 }
959
960 if (sourceElem.isBF16() && destElem.isF32()) {
961 if (chipset >= kGfx950) {
962 if (m == 32 && n == 32 && k == 16 && b == 1)
963 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
964 if (m == 16 && n == 16 && k == 32 && b == 1)
965 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
966 }
967 if (chipset >= kGfx90a) {
968 if (m == 32 && n == 32 && k == 4 && b == 2)
969 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
970 if (m == 16 && n == 16 && k == 4 && b == 4)
971 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
972 if (m == 4 && n == 4 && k == 4 && b == 16)
973 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
974 if (m == 32 && n == 32 && k == 8 && b == 1)
975 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
976 if (m == 16 && n == 16 && k == 16 && b == 1)
977 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
978 }
979 if (m == 32 && n == 32 && k == 2 && b == 2)
980 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
981 if (m == 16 && n == 16 && k == 2 && b == 4)
982 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
983 if (m == 4 && n == 4 && k == 2 && b == 16)
984 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
985 if (m == 32 && n == 32 && k == 4 && b == 1)
986 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
987 if (m == 16 && n == 16 && k == 8 && b == 1)
988 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
989 }
990
991 if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
992 if (chipset >= kGfx950) {
993 if (m == 32 && n == 32 && k == 32 && b == 1)
994 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
995 if (m == 16 && n == 16 && k == 64 && b == 1)
996 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
997 }
998 if (m == 32 && n == 32 && k == 4 && b == 2)
999 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
1000 if (m == 16 && n == 16 && k == 4 && b == 4)
1001 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
1002 if (m == 4 && n == 4 && k == 4 && b == 16)
1003 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
1004 if (m == 32 && n == 32 && k == 8 && b == 1)
1005 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
1006 if (m == 16 && n == 16 && k == 16 && b == 1)
1007 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
1008 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
1009 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
1010 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
1011 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
1012 }
1013
1014 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
1015 if (m == 16 && n == 16 && k == 4 && b == 1)
1016 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
1017 if (m == 4 && n == 4 && k == 4 && b == 4)
1018 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
1019 }
1020
1021 if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
1022 // Known to be correct because there are no scalar f8 instructions and
1023 // because a length mismatch will have been caught by the verifier.
1024 Type sourceBElem =
1025 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
1026 if (m == 16 && n == 16 && k == 32 && b == 1) {
1027 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
1028 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
1029 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
1030 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
1031 }
1032 if (m == 32 && n == 32 && k == 16 && b == 1) {
1033 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
1034 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
1035 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
1036 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
1037 }
1038 }
1039
1040 if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
1041 Type sourceBElem =
1042 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
1043 if (m == 16 && n == 16 && k == 32 && b == 1) {
1044 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
1045 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
1046 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
1047 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
1048 }
1049 if (m == 32 && n == 32 && k == 16 && b == 1) {
1050 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
1051 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
1052 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
1053 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
1054 }
1055 }
1056
1057 return std::nullopt;
1058}
1059
// Maps a small-float element type to the matching ROCDL::MatrixFormat code:
// the two 8-bit (E4M3FN, E5M2), two 6-bit (E2M3FN, E3M2FN) and one 4-bit
// (E2M1FN) types are handled; any other type yields std::nullopt.
// NOTE(review): the lines carrying the function name, its parameter list and
// the head of the type-switch expression are absent from this listing;
// `mlirElemType` is presumably the sole parameter -- confirm against the
// original file.
1060static std::optional<ROCDL::MatrixFormat>
1063 mlirElemType)
1064 .Case([](Float8E4M3FNType) { return ROCDL::MatrixFormat::fp8_e4m3; })
1065 .Case([](Float8E5M2Type) { return ROCDL::MatrixFormat::fp8_e5m2; })
1066 .Case([](Float6E2M3FNType) { return ROCDL::MatrixFormat::fp6_e2m3; })
1067 .Case([](Float6E3M2FNType) { return ROCDL::MatrixFormat::fp6_e3m2; })
1068 .Case([](Float4E2M1FNType) { return ROCDL::MatrixFormat::fp4_e2m1; })
1069 .Default(std::nullopt);
1070}
1071
1072/// If there is a scaled MFMA instruction for the input element types `aType`
1073/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
1074/// blocks) on the given `chipset`, return a tuple consisting of the
1075/// OperationName of the intrinsic and the type codes that need to be passed to
1076/// that intrinsic. Note that this is also used to implement some un-scaled
1077/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
1078/// MFMA with a scale of 0.
// NOTE(review): the head of the alias declaration below (presumably
// `using ScaledMFMAIntrinsic =`) is absent from this listing.
1080 std::tuple<StringRef, ROCDL::MatrixFormat, ROCDL::MatrixFormat>;
1081
// Type-based lookup shared by the op-based overloads. Returns std::nullopt
// for chipsets below gfx950, for non-f32 destinations, when either source
// element type has no type code, and for any shape other than 32x32x64 or
// 16x16x128 with a single block.
1082static std::optional<ScaledMFMAIntrinsic>
1083mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
1084 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
// Vector types are reduced to their element type; scalars are kept as-is.
1085 aType = getElementTypeOrSelf(aType);
1086 bType = getElementTypeOrSelf(bType);
1087 destType = getElementTypeOrSelf(destType);
1088
1089 if (chipset < kGfx950)
1090 return std::nullopt;
1091 if (!isa<Float32Type>(destType))
1092 return std::nullopt;
1093
// NOTE(review): the initializer expressions of the two type codes below are
// absent from this listing; they are presumably derived from `aType` and
// `bType` respectively -- confirm against the original file.
1094 std::optional<ROCDL::MatrixFormat> aTypeCode =
1096 std::optional<ROCDL::MatrixFormat> bTypeCode =
1098 if (!aTypeCode || !bTypeCode)
1099 return std::nullopt;
1100
1101 if (m == 32 && n == 32 && k == 64 && b == 1)
1102 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1103 *aTypeCode, *bTypeCode};
1104 if (m == 16 && n == 16 && k == 128 && b == 1)
1105 return std::tuple{
1106 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1107 *bTypeCode};
1108
1109 return std::nullopt;
1110}
1111
// MFMAOp overload of the scaled-intrinsic lookup: supplies the operand and
// result types, the M/N/K sizes and the block count of `mfma`, plus `chipset`,
// as the arguments of a call.
// NOTE(review): the line that opens the return statement (the callee name) is
// absent from this listing; the argument list matches the type-based
// `mfmaOpToScaledIntrinsic` overload -- confirm against the original file.
1112static std::optional<ScaledMFMAIntrinsic>
1113mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
1115 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1116 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1117 mfma.getBlocks(), chipset);
1118}
1119
1120static std::optional<ScaledMFMAIntrinsic>
1121mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
1122 return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
1123 smfma.getSourceB().getType(),
1124 smfma.getDestC().getType(), smfma.getM(),
1125 smfma.getN(), smfma.getK(), 1u, chipset);
1126}
1127
1128/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1129/// for RDNA3/4 architectures.
1130static std::optional<StringRef>
1131wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
1132 Type elemDestType, uint32_t k, bool isRDNA3) {
1133 using fp8 = Float8E4M3FNType;
1134 using bf8 = Float8E5M2Type;
1135
1136 // Handle k == 16 for RDNA3/4.
1137 if (k == 16) {
1138 // Common patterns for RDNA3 and RDNA4.
1139 if (elemSourceType.isF16() && elemDestType.isF32())
1140 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1141 if (elemSourceType.isBF16() && elemDestType.isF32())
1142 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1143 if (elemSourceType.isF16() && elemDestType.isF16())
1144 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1145 if (elemSourceType.isBF16() && elemDestType.isBF16())
1146 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1147 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1148 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1149
1150 // RDNA3 specific patterns.
1151 if (isRDNA3) {
1152 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1153 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1154 return std::nullopt;
1155 }
1156
1157 // RDNA4 specific patterns (fp8/bf8).
1158 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1159 elemDestType.isF32())
1160 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1161 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1162 elemDestType.isF32())
1163 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1164 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1165 elemDestType.isF32())
1166 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1167 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1168 elemDestType.isF32())
1169 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1170 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1171 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1172
1173 return std::nullopt;
1174 }
1175
1176 // Handle k == 32 for RDNA4.
1177 if (k == 32 && !isRDNA3) {
1178 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1179 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1180 }
1181
1182 return std::nullopt;
1183}
1184
1185/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1186/// for the gfx1250 architecture.
1187static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
1188 Type elemBSourceType,
1189 Type elemDestType,
1190 uint32_t k) {
1191 using fp8 = Float8E4M3FNType;
1192 using bf8 = Float8E5M2Type;
1193
1194 if (k == 4) {
1195 if (elemSourceType.isF32() && elemDestType.isF32())
1196 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1197
1198 return std::nullopt;
1199 }
1200
1201 if (k == 32) {
1202 if (elemSourceType.isF16() && elemDestType.isF32())
1203 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1204 if (elemSourceType.isBF16() && elemDestType.isF32())
1205 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1206 if (elemSourceType.isF16() && elemDestType.isF16())
1207 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1208 if (elemSourceType.isBF16() && elemDestType.isBF16())
1209 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1210
1211 return std::nullopt;
1212 }
1213
1214 if (k == 64) {
1215 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1216 if (elemDestType.isF32())
1217 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1218 if (elemDestType.isF16())
1219 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1220 }
1221 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1222 if (elemDestType.isF32())
1223 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1224 if (elemDestType.isF16())
1225 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1226 }
1227 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1228 if (elemDestType.isF32())
1229 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1230 if (elemDestType.isF16())
1231 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1232 }
1233 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1234 if (elemDestType.isF32())
1235 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1236 if (elemDestType.isF16())
1237 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1238 }
1239 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1240 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1241
1242 return std::nullopt;
1243 }
1244
1245 if (k == 128) {
1246 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1247 if (elemDestType.isF32())
1248 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1249 if (elemDestType.isF16())
1250 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1251 }
1252 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1253 if (elemDestType.isF32())
1254 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1255 if (elemDestType.isF16())
1256 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1257 }
1258 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1259 if (elemDestType.isF32())
1260 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1261 if (elemDestType.isF16())
1262 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1263 }
1264 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1265 if (elemDestType.isF32())
1266 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1267 if (elemDestType.isF16())
1268 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1269 }
1270
1271 return std::nullopt;
1272 }
1273
1274 return std::nullopt;
1275}
1276
1277/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
1278/// operation if one exists. This includes checking to ensure the intrinsic is
1279/// supported on the architecture you are compiling for.
1280static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
1281 Chipset chipset) {
1282 bool isGfx950 = chipset >= kGfx950;
1283 auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
1284 auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
1285
1286 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1287 Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
1288 Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
1289 Type destElem = getElementTypeOrSelf(op.getDestC().getType());
1290
1291 if (m == 16 && n == 16 && k == 32) {
1292 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1293 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1294 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1295 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1296 }
1297
1298 if (m == 16 && n == 16 && k == 64) {
1299 if (isGfx950) {
1300 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1301 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1302 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1303 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1304 }
1305 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1306 destElem.isInteger(32))
1307 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1308 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1309 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1310 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1311 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1312 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1313 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1314 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1315 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1316 }
1317
1318 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1319 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1320 destElem.isInteger(32))
1321 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1322 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1323 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1324 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1325 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1326 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1327 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1328 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1329 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1330 }
1331
1332 if (m == 32 && n == 32 && k == 16) {
1333 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1334 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1335 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1336 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1337 }
1338
1339 if (m == 32 && n == 32 && k == 32) {
1340 if (isGfx950) {
1341 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1342 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1343 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1344 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1345 }
1346 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1347 destElem.isInteger(32))
1348 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1349 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1350 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1351 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1352 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1353 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1354 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1355 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1356 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1357 }
1358
1359 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1360 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1361 destElem.isInteger(32))
1362 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1363 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1364 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1365 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1366 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1367 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1368 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1369 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1370 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1371 }
1372
1373 return std::nullopt;
1374}
1375
1376/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1377/// if one exists. This includes checking to ensure the intrinsic is supported
1378/// on the architecture you are compiling for.
1379static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1380 Chipset chipset) {
1381 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1382 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1383 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1384 Type elemSourceType = sourceVectorType.getElementType();
1385 Type elemBSourceType = sourceBVectorType.getElementType();
1386 Type elemDestType = destVectorType.getElementType();
1387
1388 const uint32_t k = wmma.getK();
1389 const bool isRDNA3 = chipset.majorVersion == 11;
1390 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1391
1392 // Handle RDNA3 and RDNA4.
1393 if (isRDNA3 || isRDNA4)
1394 return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1395 k, isRDNA3);
1396
1397 // Handle gfx1250.
1398 if (chipset == kGfx1250)
1399 return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1400 elemDestType, k);
1401
1402 return std::nullopt;
1403}
1404
1405/// Returns the `rocdl` intrinsic corresponding to a SparseWMMA operation
1406/// `swmmac` if one exists. This includes checking to ensure the intrinsic is
1407/// supported on the architecture you are compiling for.
// NOTE(review): the paragraph above describes the lookup function that
// follows this record rather than the record itself. The opening line of the
// record and the members after `name` are absent from this listing; the
// lookup below initialises the record with an operation name followed by
// three boolean values -- confirm the member names against the original file.
1409 StringRef name;
1413};
1414
1415static std::optional<SparseWMMAOpInfo>
1416sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset) {
1417 Type sourceAElem = getElementTypeOrSelf(swmmac.getSourceA().getType());
1418 Type sourceBElem = getElementTypeOrSelf(swmmac.getSourceB().getType());
1419 Type destElem = getElementTypeOrSelf(swmmac.getDestC().getType());
1420
1421 uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
1422
1423 if ((m != 16) || (n != 16))
1424 return std::nullopt;
1425
1426 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1427 if (isRDNA4) {
1428 if (k == 32) {
1429 if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
1430 return SparseWMMAOpInfo{
1431 ROCDL::swmmac_f32_16x16x32_f16::getOperationName(), false, false,
1432 false};
1433 if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
1434 return SparseWMMAOpInfo{
1435 ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(), false, false,
1436 false};
1437 if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
1438 return SparseWMMAOpInfo{
1439 ROCDL::swmmac_f16_16x16x32_f16::getOperationName(), false, false,
1440 false};
1441 if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
1442 return SparseWMMAOpInfo{
1443 ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(), false, false,
1444 false};
1445 if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
1446 sourceBElem.isInteger(8))
1447 return SparseWMMAOpInfo{
1448 ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(), true, false,
1449 true};
1450 if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
1451 sourceBElem.isInteger(4))
1452 return SparseWMMAOpInfo{
1453 ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(), true, false,
1454 true};
1455 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1456 sourceBElem.isF8E4M3FN())
1457 return SparseWMMAOpInfo{
1458 ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(), false,
1459 false, false};
1460 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1461 sourceBElem.isF8E5M2())
1462 return SparseWMMAOpInfo{
1463 ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(), false,
1464 false, false};
1465 if (destElem.isF32() && sourceAElem.isF8E5M2() &&
1466 sourceBElem.isF8E4M3FN())
1467 return SparseWMMAOpInfo{
1468 ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(), false,
1469 false, false};
1470 if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1471 return SparseWMMAOpInfo{
1472 ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(), false,
1473 false, false};
1474 }
1475 if (k == 64) {
1476 if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
1477 sourceBElem.isInteger(4))
1478 return SparseWMMAOpInfo{
1479 ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(), true, false,
1480 true};
1481 }
1482 }
1483
1484 const bool isGFX1250 = chipset == kGfx1250;
1485 const bool isWavesize64 = swmmac.getWave64();
1486 if (isGFX1250 && !isWavesize64) {
1487 if (k == 64) {
1488 if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
1489 return SparseWMMAOpInfo{
1490 ROCDL::swmmac_f32_16x16x64_f16::getOperationName(), true, true,
1491 false};
1492 if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
1493 return SparseWMMAOpInfo{
1494 ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(), true, true,
1495 false};
1496 if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
1497 return SparseWMMAOpInfo{
1498 ROCDL::swmmac_f16_16x16x64_f16::getOperationName(), true, true,
1499 false};
1500 if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
1501 return SparseWMMAOpInfo{
1502 ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(), true, true,
1503 false};
1504 }
1505 if (k == 128) {
1506 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1507 sourceBElem.isF8E4M3FN())
1508 return SparseWMMAOpInfo{
1509 ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(), false,
1510 true, false};
1511 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1512 sourceBElem.isF8E5M2())
1513 return SparseWMMAOpInfo{
1514 ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(), false,
1515 true, false};
1516 if (destElem.isF32() && sourceAElem.isF8E5M2() &&
1517 sourceBElem.isF8E4M3FN())
1518 return SparseWMMAOpInfo{
1519 ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(), false,
1520 true, false};
1521 if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1522 return SparseWMMAOpInfo{
1523 ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(), false,
1524 true, false};
1525 if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
1526 sourceBElem.isF8E4M3FN())
1527 return SparseWMMAOpInfo{
1528 ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(), false,
1529 true, false};
1530 if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
1531 sourceBElem.isF8E5M2())
1532 return SparseWMMAOpInfo{
1533 ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(), false,
1534 true, false};
1535 if (destElem.isF16() && sourceAElem.isF8E5M2() &&
1536 sourceBElem.isF8E4M3FN())
1537 return SparseWMMAOpInfo{
1538 ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(), false,
1539 true, false};
1540 if (destElem.isF16() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1541 return SparseWMMAOpInfo{
1542 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
1543 true, false};
1544 if (destElem.isF16() && sourceAElem.isInteger(8) &&
1545 sourceBElem.isInteger(8))
1546 return SparseWMMAOpInfo{
1547 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
1548 true, false};
1549 if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
1550 sourceBElem.isInteger(8))
1551 return SparseWMMAOpInfo{
1552 ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(), true, true,
1553 true};
1554 }
1555 }
1556
1557 return std::nullopt;
1558}
1559
1560namespace {
1561struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
1562 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1563 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1564
1565 Chipset chipset;
1566
1567 LogicalResult
1568 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1569 ConversionPatternRewriter &rewriter) const override {
1570 Location loc = op.getLoc();
1571 Type destElem = getElementTypeOrSelf(op.getDestD().getType());
1572 Type outType = typeConverter->convertType(op.getDestD().getType());
1573 Type intrinsicOutType = outType;
1574 if (auto outVecType = dyn_cast<VectorType>(outType))
1575 if (outVecType.getElementType().isBF16())
1576 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1577
1578 if (chipset.majorVersion != 9 || chipset < kGfx908)
1579 return op->emitOpError("MFMA only supported on gfx908+");
1580 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
1581 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1582 if (chipset < kGfx942)
1583 return op.emitOpError("negation unsupported on older than gfx942");
1584 getBlgpField |=
1585 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1586 }
1587 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
1588 std::optional<ScaledMFMAIntrinsic> maybeScaledIntrinsic =
1589 mfmaOpToScaledIntrinsic(op, chipset);
1590 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1591 return op.emitOpError("no intrinsic matching MFMA size on given chipset");
1592
1593 bool isScaled =
1594 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1595 if (isScaled &&
1596 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1597 return op.emitOpError(
1598 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1599 "be scaled as those fields are used for type information");
1600 }
1601
1602 StringRef intrinsicName =
1603 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1604 // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
1605 // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
1606 bool allowBf16 = [&]() {
1607 if (chipset < kGfx950)
1608 return false;
1609 if (isScaled)
1610 return true;
1611 return intrinsicName.contains("16x16x32.bf16") ||
1612 intrinsicName.contains("32x32x16.bf16");
1613 }();
1614 OperationState loweredOp(loc, intrinsicName);
1615 loweredOp.addTypes(intrinsicOutType);
1616 loweredOp.addOperands({packSmallFloatVectorOperand(
1617 rewriter, loc, adaptor.getSourceA(), allowBf16),
1619 rewriter, loc, adaptor.getSourceB(), allowBf16),
1620 adaptor.getDestC()});
1621 if (isScaled) {
1622 Value zero = createI32Constant(rewriter, loc, 0);
1623 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1624 loweredOp.addOperands({/*scale A=*/zero, /*scale B=*/zero});
1625 loweredOp.addAttributes(
1626 {{"cbsz",
1627 ROCDL::MatrixFormatAttr::get(rewriter.getContext(), aTypeCode)},
1628 {"blgp",
1629 ROCDL::MatrixFormatAttr::get(rewriter.getContext(), bTypeCode)},
1630 {"opselA", rewriter.getI32IntegerAttr(0)},
1631 {"opselB", rewriter.getI32IntegerAttr(0)}});
1632 } else {
1633 Attribute blgpAttr =
1634 destElem.isF64()
1635 ? Attribute(ROCDL::MFMANegModifierAttr::get(
1636 rewriter.getContext(),
1637 static_cast<ROCDL::MFMANegModifier>(getBlgpField)))
1638 : Attribute(ROCDL::MFMAPermBAttr::get(
1639 rewriter.getContext(),
1640 static_cast<ROCDL::MFMAPermB>(getBlgpField)));
1641 loweredOp.addAttributes(
1642 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1643 {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1644 {"blgp", blgpAttr}});
1645 };
1646 Value lowered = rewriter.create(loweredOp)->getResult(0);
1647 if (outType != intrinsicOutType)
1648 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1649 rewriter.replaceOp(op, lowered);
1650 return success();
1651 }
1652};
1653
1654struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1655 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1656 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1657
1658 Chipset chipset;
1659
1660 LogicalResult
1661 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1662 ConversionPatternRewriter &rewriter) const override {
1663 Location loc = op.getLoc();
1664 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1665
1666 if (chipset.majorVersion != 9 || chipset < kGfx950)
1667 return op->emitOpError("scaled MFMA only supported on gfx908+");
1668 std::optional<ScaledMFMAIntrinsic> maybeScaledIntrinsic =
1669 mfmaOpToScaledIntrinsic(op, chipset);
1670 if (!maybeScaledIntrinsic.has_value())
1671 return op.emitOpError(
1672 "no intrinsic matching scaled MFMA size on given chipset");
1673
1674 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1675 OperationState loweredOp(loc, intrinsicName);
1676 loweredOp.addTypes(intrinsicOutType);
1677 loweredOp.addOperands(
1678 {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
1679 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
1680 adaptor.getDestC()});
1681 loweredOp.addOperands(
1682 {/*scales A*/
1683 castScaleOperand(rewriter, loc, adaptor.getScalesA()),
1684 /*scales B*/
1685 castScaleOperand(rewriter, loc, adaptor.getScalesB())});
1686 loweredOp.addAttributes(
1687 {{"cbsz",
1688 ROCDL::MatrixFormatAttr::get(rewriter.getContext(), aTypeCode)},
1689 {"blgp",
1690 ROCDL::MatrixFormatAttr::get(rewriter.getContext(), bTypeCode)},
1691 {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1692 {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1693
1694 Value lowered = rewriter.create(loweredOp)->getResult(0);
1695 rewriter.replaceOp(op, lowered);
1696 return success();
1697 }
1698};
1699
1700struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
1701 SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1702 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1703
1704 Chipset chipset;
1705
1706 LogicalResult
1707 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1708 ConversionPatternRewriter &rewriter) const override {
1709 Location loc = op.getLoc();
1710 auto outType =
1711 typeConverter->convertType<VectorType>(op.getDestC().getType());
1712 if (!outType)
1713 return rewriter.notifyMatchFailure(op, "type conversion failed");
1714
1715 // smfmac is supported on gfx942 and gfx950.
1716 if (chipset.majorVersion != 9 || chipset < kGfx942)
1717 return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
1718
1719 std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
1720 if (!maybeIntrinsic.has_value())
1721 return op.emitOpError(
1722 "no intrinsic matching sparse MFMA on the given chipset");
1723 bool isGfx942BF16 =
1724 (*maybeIntrinsic ==
1725 ROCDL::smfmac_f32_16x16x32_bf16::getOperationName() ||
1726 *maybeIntrinsic ==
1727 ROCDL::smfmac_f32_32x32x16_bf16::getOperationName());
1728 bool isGfx950 = (chipset >= kGfx950) && !isGfx942BF16;
1729
1730 Value a = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceA(),
1731 isGfx950);
1732 Value b = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceB(),
1733 isGfx950);
1734 Value c = adaptor.getDestC();
1735
1736 // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
1737 // gfx950 8-bit variants already carry the index as i32; skip the bitcast.
1738 Value sparseIdx = adaptor.getSparseIdx();
1739 Type i32Type = rewriter.getI32Type();
1740 if (sparseIdx.getType() != i32Type)
1741 sparseIdx = LLVM::BitcastOp::create(rewriter, loc, i32Type, sparseIdx);
1742
1743 OperationState loweredOp(loc, maybeIntrinsic.value());
1744 loweredOp.addTypes(outType);
1745 loweredOp.addOperands({a, b, c, sparseIdx});
1746 loweredOp.addAttributes(
1747 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1748 {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1749 Value lowered = rewriter.create(loweredOp)->getResult(0);
1750 rewriter.replaceOp(op, lowered);
1751 return success();
1752 }
1753};
1754
1755struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
1756 WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1757 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1758
1759 Chipset chipset;
1760
1761 LogicalResult
1762 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1763 ConversionPatternRewriter &rewriter) const override {
1764 Location loc = op.getLoc();
1765 auto outType =
1766 typeConverter->convertType<VectorType>(op.getDestD().getType());
1767 if (!outType)
1768 return rewriter.notifyMatchFailure(op, "type conversion failed");
1769
1770 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1771 return op->emitOpError("WMMA only supported on gfx11 and gfx12");
1772
1773 bool isGFX1250 = chipset >= kGfx1250;
1774
1775 // The WMMA operations represent vectors of bf16s as vectors of i16s
1776 // (except on gfx1250), so we need to bitcast bfloats to i16 and then
1777 // bitcast them back.
1778 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1779 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1780 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1781 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1782 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1783 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1784 bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
1785 VectorType rawOutType = outType;
1786 if (castOutToI16)
1787 rawOutType = outType.clone(rewriter.getI16Type());
1788 Value a = adaptor.getSourceA();
1789 if (castAToI16)
1790 a = LLVM::BitcastOp::create(rewriter, loc,
1791 aType.clone(rewriter.getI16Type()), a);
1792 Value b = adaptor.getSourceB();
1793 if (castBToI16)
1794 b = LLVM::BitcastOp::create(rewriter, loc,
1795 bType.clone(rewriter.getI16Type()), b);
1796 Value destC = adaptor.getDestC();
1797 if (castDestCToI16)
1798 destC = LLVM::BitcastOp::create(
1799 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1800
1801 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
1802
1803 if (!maybeIntrinsic.has_value())
1804 return op.emitOpError("no intrinsic matching WMMA on the given chipset");
1805
1806 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1807 return op.emitOpError("subwordOffset not supported on gfx12+");
1808
1809 SmallVector<Value, 4> operands;
1810 SmallVector<NamedAttribute, 4> attrs;
1811 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
1812 op.getSourceA(), operands, attrs, "signA");
1813 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
1814 op.getSourceB(), operands, attrs, "signB");
1815 wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
1816 op.getSubwordOffset(), op.getClamp(), operands,
1817 attrs);
1818
1819 OperationState loweredOp(loc, *maybeIntrinsic);
1820 loweredOp.addTypes(rawOutType);
1821 loweredOp.addOperands(operands);
1822 loweredOp.addAttributes(attrs);
1823 Operation *lowered = rewriter.create(loweredOp);
1824
1825 Operation *maybeCastBack = lowered;
1826 if (rawOutType != outType)
1827 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1828 lowered->getResult(0));
1829 rewriter.replaceOp(op, maybeCastBack->getResults());
1830
1831 return success();
1832 }
1833};
1834
/// Attribute layout of the ROCDL dot intrinsic families targeted by the dot
/// lowering.
enum class DotFamily {
  /// ROCDL_Dot_IntrOp: carries only a `clamp` attribute.
  Clamp,
  /// ROCDL_Dot_NoClamp_IntrOp: carries no attributes.
  NoClamp,
  /// ROCDL_Sudot_IntrOp: carries `signA`, `signB`, and `clamp` attributes.
  Sudot,
};
1843
1844static std::optional<std::pair<StringRef, DotFamily>>
1845dotOpToIntrinsic(DotOp op, Chipset chipset) {
1846 Type aElem = cast<VectorType>(op.getSourceA().getType()).getElementType();
1847 Type bElem = cast<VectorType>(op.getSourceB().getType()).getElementType();
1848 Type dest = op.getDestC().getType();
1849 bool uA = op.getUnsignedA();
1850 bool uB = op.getUnsignedB();
1851
1852 // f16 x f16 -> f32 / f16.
1853 if (aElem.isF16() && bElem.isF16()) {
1854 if (dest.isF32() && hasDot10Insts(chipset))
1855 return {{ROCDL::fdot2::getOperationName(), DotFamily::Clamp}};
1856 if (dest.isF16() && hasDot9Insts(chipset))
1857 return {{ROCDL::fdot2_f16_f16::getOperationName(), DotFamily::NoClamp}};
1858 return std::nullopt;
1859 }
1860
1861 // bf16 x bf16 -> f32 / bf16.
1862 if (aElem.isBF16() && bElem.isBF16()) {
1863 if (dest.isF32() && hasDot12Insts(chipset))
1864 return {{ROCDL::fdot2_f32_bf16::getOperationName(), DotFamily::Clamp}};
1865 if (dest.isBF16() && hasDot9Insts(chipset))
1866 return {{ROCDL::fdot2_bf16_bf16::getOperationName(), DotFamily::NoClamp}};
1867 return std::nullopt;
1868 }
1869
1870 // Integer sources -> i32.
1871 if (isa<IntegerType>(aElem) && isa<IntegerType>(bElem) &&
1872 dest.isInteger(32)) {
1873 bool mixedSign = (uA != uB);
1874 unsigned elemWidth = aElem.getIntOrFloatBitWidth();
1875
1876 if (mixedSign) {
1877 if (!hasDot8Insts(chipset))
1878 return std::nullopt;
1879 StringRef name;
1880 switch (elemWidth) {
1881 case 8:
1882 name = ROCDL::sudot4::getOperationName();
1883 break;
1884 case 4:
1885 name = ROCDL::sudot8::getOperationName();
1886 break;
1887 default:
1888 return std::nullopt;
1889 }
1890 return {{name, DotFamily::Sudot}};
1891 }
1892
1893 StringRef name;
1894 bool supported = false;
1895 switch (elemWidth) {
1896 case 16:
1897 supported = hasDot2Insts(chipset);
1898 name = uA ? ROCDL::udot2::getOperationName()
1899 : ROCDL::sdot2::getOperationName();
1900 break;
1901 case 8:
1902 supported = uA ? hasDot7Insts(chipset)
1903 : hasDot1Insts(chipset) || hasDot8Insts(chipset);
1904 name = uA ? ROCDL::udot4::getOperationName()
1905 : ROCDL::sdot4::getOperationName();
1906 break;
1907 case 4:
1908 supported = uA ? hasDot7Insts(chipset)
1909 : hasDot1Insts(chipset) || hasDot8Insts(chipset);
1910 name = uA ? ROCDL::udot8::getOperationName()
1911 : ROCDL::sdot8::getOperationName();
1912 break;
1913 default:
1914 return std::nullopt;
1915 }
1916 if (!supported)
1917 return std::nullopt;
1918 return {{name, DotFamily::Clamp}};
1919 }
1920
1921 // fp8/bf8 x fp8/bf8 -> f32.
1922 bool aIsFp8 = isa<Float8E4M3FNType>(aElem);
1923 bool aIsBf8 = isa<Float8E5M2Type>(aElem);
1924 bool bIsFp8 = isa<Float8E4M3FNType>(bElem);
1925 bool bIsBf8 = isa<Float8E5M2Type>(bElem);
1926 if ((aIsFp8 || aIsBf8) && (bIsFp8 || bIsBf8) && dest.isF32()) {
1927 if (!hasDot11Insts(chipset))
1928 return std::nullopt;
1929 StringRef name;
1930 if (aIsFp8 && bIsFp8)
1931 name = ROCDL::dot4_f32_fp8_fp8::getOperationName();
1932 else if (aIsFp8 && bIsBf8)
1933 name = ROCDL::dot4_f32_fp8_bf8::getOperationName();
1934 else if (aIsBf8 && bIsFp8)
1935 name = ROCDL::dot4_f32_bf8_fp8::getOperationName();
1936 else
1937 name = ROCDL::dot4_f32_bf8_bf8::getOperationName();
1938 return {{name, DotFamily::NoClamp}};
1939 }
1940
1941 return std::nullopt;
1942}
1943
1944struct DotOpLowering : public ConvertOpToLLVMPattern<DotOp> {
1945 DotOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1946 : ConvertOpToLLVMPattern<DotOp>(converter), chipset(chipset) {}
1947
1948 Chipset chipset;
1949
1950 LogicalResult
1951 matchAndRewrite(DotOp op, DotOpAdaptor adaptor,
1952 ConversionPatternRewriter &rewriter) const override {
1953 Location loc = op.getLoc();
1954
1955 std::optional<std::pair<StringRef, DotFamily>> maybeIntrinsic =
1956 dotOpToIntrinsic(op, chipset);
1957 if (!maybeIntrinsic)
1958 return op.emitOpError("no intrinsic matching dot on the given chipset: ")
1959 << op.getSourceA().getType() << " * " << op.getSourceB().getType()
1960 << " + " << op.getDestC().getType();
1961
1962 auto [intrinsicName, family] = maybeIntrinsic.value();
1963
1964 Value a = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceA());
1965 Value b = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceB());
1966 Value c = adaptor.getDestC();
1967
1968 SmallVector<NamedAttribute, 3> attrs;
1969 if (family == DotFamily::Sudot) {
1970 attrs.push_back(rewriter.getNamedAttr(
1971 "signA", rewriter.getBoolAttr(!op.getUnsignedA())));
1972 attrs.push_back(rewriter.getNamedAttr(
1973 "signB", rewriter.getBoolAttr(!op.getUnsignedB())));
1974 }
1975
1976 if (family != DotFamily::NoClamp && op.getClamp())
1977 attrs.push_back(
1978 rewriter.getNamedAttr("clamp", rewriter.getBoolAttr(true)));
1979
1980 Type resultType = typeConverter->convertType(op.getDestD().getType());
1981
1982 OperationState loweredOp(loc, intrinsicName);
1983 loweredOp.addTypes(resultType);
1984 loweredOp.addOperands({a, b, c});
1985 loweredOp.addAttributes(attrs);
1986 Operation *lowered = rewriter.create(loweredOp);
1987 rewriter.replaceOp(op, lowered->getResults());
1988 return success();
1989 }
1990};
1991
1992struct SparseWMMAOpLowering : public ConvertOpToLLVMPattern<SparseWMMAOp> {
1993 SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1994 : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
1995
1996 Chipset chipset;
1997
1998 LogicalResult
1999 matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
2000 ConversionPatternRewriter &rewriter) const override {
2001 Location loc = op.getLoc();
2002 auto outType =
2003 typeConverter->convertType<VectorType>(op.getDestD().getType());
2004 if (!outType)
2005 return rewriter.notifyMatchFailure(op, "type conversion failed");
2006
2007 std::optional<SparseWMMAOpInfo> maybeIntrinsic =
2008 sparseWMMAOpToIntrinsic(op, chipset);
2009
2010 if (!maybeIntrinsic.has_value())
2011 return op.emitOpError(
2012 "no intrinsic matching Sparse WMMA on the given chipset");
2013 SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();
2014
2015 SmallVector<NamedAttribute> attrs;
2016
2017 if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.useSign)
2018 return op->emitOpError("intrinsic doesn't support unsign");
2019 if (intrinsic.useSign) {
2020 if (auto attr = op.getUnsignedAAttr())
2021 attrs.push_back({"signA", attr});
2022 if (auto attr = op.getUnsignedBAttr())
2023 attrs.push_back({"signB", attr});
2024 }
2025
2026 if ((op.getReuseA() || op.getReuseB()) && !intrinsic.useReuse)
2027 return op->emitOpError("intrinsic doesn't support reuse");
2028 if (intrinsic.useReuse) {
2029 if (auto attr = op.getReuseAAttr())
2030 attrs.push_back({"reuseA", attr});
2031 if (auto attr = op.getReuseBAttr())
2032 attrs.push_back({"reuseB", attr});
2033 }
2034
2035 if (op.getClamp() && !intrinsic.useClamp)
2036 return op->emitOpError("intrinsic doesn't support clamp");
2037 if (intrinsic.useClamp && op.getClampAttr())
2038 attrs.push_back({"clamp", op.getClampAttr()});
2039
2040 const bool isGFX1250orHigher =
2041 chipset.majorVersion == 12 && chipset.minorVersion >= 5;
2042 Value a = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceA(),
2043 isGFX1250orHigher);
2044 Value b = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceB(),
2045 isGFX1250orHigher);
2046 Value c = adaptor.getDestC();
2047 VectorType rawOutType = outType;
2048 if (!isGFX1250orHigher) {
2049 c = convertPackedVectorOperand(rewriter, loc, adaptor.getDestC(), false);
2050 rawOutType = cast<VectorType>(c.getType());
2051 }
2052
2053 // Bitcast sparse indices from vector<4xi8> to i32.
2054 Value sparseIdx = LLVM::BitcastOp::create(
2055 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
2056
2057 OperationState loweredOp(loc, intrinsic.name);
2058 loweredOp.addTypes(rawOutType);
2059 loweredOp.addOperands({a, b, c, sparseIdx});
2060 loweredOp.addAttributes(attrs);
2061 Operation *lowered = rewriter.create(loweredOp);
2062
2063 Operation *maybeCastBack = lowered;
2064 if (rawOutType != outType)
2065 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
2066 lowered->getResult(0));
2067 rewriter.replaceOp(op, maybeCastBack->getResults());
2068
2069 return success();
2070 }
2071};
2072
/// Conversion pattern that rewrites a `ScaledWMMAOp` into a ROCDL scaled-WMMA
/// intrinsic op. The intrinsic name comes from `getScaledWmmaIntrinsicName`
/// (keyed on M/N/K and on whether the scale vectors have 8 elements); matrix
/// and scale formats are attached as attributes.
2073struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
 2074 ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
 2075 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
 2076
 // Chipset checked against kGfx1250 below.
 2077 Chipset chipset;
 2078
 2079 LogicalResult
 2080 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
 2081 ConversionPatternRewriter &rewriter) const override {
 2082 Location loc = op.getLoc();
 2083 auto outType =
 2084 typeConverter->convertType<VectorType>(op.getDestD().getType());
 2085 if (!outType)
 2086 return rewriter.notifyMatchFailure(op, "type conversion failed");
 2087
 2088 if (chipset < kGfx1250)
 2089 return op->emitOpError("WMMA scale only supported on gfx1250+");
 2090
 2091 int64_t m = op.getM();
 2092 int64_t n = op.getN();
 2093 int64_t k = op.getK();
 2094
 2095 Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
 2096 Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
 2097
 // NOTE(review): the initializer expressions of `aFmtCode` and `bFmtCode`
 // (listing lines 2099 and 2101) are absent from this listing; presumably
 // each maps the corresponding element type to a ROCDL::MatrixFormat --
 // restore them from the upstream file before building.
 2098 std::optional<ROCDL::MatrixFormat> aFmtCode =
 2100 std::optional<ROCDL::MatrixFormat> bFmtCode =
 2102
 2103 if (!aFmtCode || !bFmtCode)
 2104 return op.emitOpError("unsupported element types for scaled_wmma");
 2105
 2106 // Get scale vector types and determine variant (scale vs scale16).
 2107 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
 2108 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
 2109
 2110 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
 2111 return op.emitOpError("scaleA and scaleB must have equal vector length");
 2112
 2113 // Extract scale format from element types.
 2114 Type scaleAElemType = scaleAVecType.getElementType();
 2115 Type scaleBElemType = scaleBVecType.getElementType();
 2116
 2117 std::optional<ROCDL::WMMAMatrixScaleFormat> scaleAFmt =
 2118 getWmmaScaleFormat(scaleAElemType);
 2119 std::optional<ROCDL::WMMAMatrixScaleFormat> scaleBFmt =
 2120 getWmmaScaleFormat(scaleBElemType);
 2121
 2122 if (!scaleAFmt || !scaleBFmt)
 2123 return op.emitOpError("unsupported scale element types");
 2124
 2125 // Determine which intrinsic to use based on dimensions.
 // An 8-element scale vector selects the scale16 variant; the lengths of
 // scaleA and scaleB were checked to be equal above.
 2126 bool isScale16 = (scaleAVecType.getNumElements() == 8);
 2127 std::optional<StringRef> intrinsicName =
 2128 getScaledWmmaIntrinsicName(m, n, k, isScale16);
 2129 if (!intrinsicName)
 2130 return op.emitOpError("unsupported scaled_wmma dimensions: ")
 2131 << m << "x" << n << "x" << k;
 2132
 2133 SmallVector<NamedAttribute, 8> attrs;
 2134
 2135 // The f4 variant does not have fmtA and fmtB attributes.
 // The 32x16x128 shape is the one treated as that variant here.
 2136 bool is32x16 = (m == 32 && n == 16 && k == 128);
 2137 if (!is32x16) {
 2138 attrs.emplace_back("fmtA", ROCDL::MatrixFormatAttr::get(
 2139 rewriter.getContext(), *aFmtCode));
 2140 attrs.emplace_back("fmtB", ROCDL::MatrixFormatAttr::get(
 2141 rewriter.getContext(), *bFmtCode));
 2142 }
 2143
 2144 // modC uses default value of 0.
 2145 attrs.emplace_back(
 2146 "modC", ROCDL::WMMACModifierAttr::get(rewriter.getContext(),
 2147 ROCDL::WMMACModifier::none));
 2148
 2149 // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
 2150 // half of the wave that is being selected (0 or 1).
 2151 attrs.emplace_back("scaleAType", ROCDL::WMMAMatrixScaleAttr::get(
 2152 rewriter.getContext(),
 2153 static_cast<ROCDL::WMMAMatrixScale>(
 2154 op.getAFirstScaleLane() / 16)));
 2155 attrs.emplace_back("fmtScaleA", ROCDL::WMMAMatrixScaleFormatAttr::get(
 2156 rewriter.getContext(), *scaleAFmt));
 2157 attrs.emplace_back("scaleBType", ROCDL::WMMAMatrixScaleAttr::get(
 2158 rewriter.getContext(),
 2159 static_cast<ROCDL::WMMAMatrixScale>(
 2160 op.getBFirstScaleLane() / 16)));
 2161 attrs.emplace_back("fmtScaleB", ROCDL::WMMAMatrixScaleFormatAttr::get(
 2162 rewriter.getContext(), *scaleBFmt));
 2163
 2164 // Reuse flags use default value of false.
 2165 attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
 2166 attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
 2167
 2168 // Convert typed float vectors to packed format.
 2169 Value sourceA =
 2170 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
 2171 Value sourceB =
 2172 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
 2173
 2174 // Pack scale vectors into i32/i64.
 2175 Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
 2176 Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());
 2177
 2178 // Create the intrinsic call.
 // Operand order: A, B, accumulator, scale A, scale B.
 2179 OperationState loweredOp(loc, *intrinsicName);
 2180 loweredOp.addTypes(outType);
 2181 loweredOp.addOperands(
 2182 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
 2183 loweredOp.addAttributes(attrs);
 2184
 2185 Operation *lowered = rewriter.create(loweredOp);
 2186 rewriter.replaceOp(op, lowered->getResults());
 2187
 2188 return success();
 2189 }
 2190};
2191
2192struct TransposeLoadOpLowering
2193 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
2194 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2195 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
2196
2197 Chipset chipset;
2198
2199 LogicalResult
2200 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
2201 ConversionPatternRewriter &rewriter) const override {
2202 if (chipset != kGfx950 && chipset < kGfx1250)
2203 return op.emitOpError(
2204 "transpose_load is only supported on gfx950 and gfx1250+");
2205
2206 Location loc = op.getLoc();
2207 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2208
2209 // Elements in subbyte memrefs are stored non-contiguously,
2210 // reject if source is sub-byte memref. Use emulated memrefs instead.
2211 size_t srcElementSize =
2212 srcMemRefType.getElementType().getIntOrFloatBitWidth();
2213 if (srcElementSize < 8)
2214 return op.emitOpError("Expect source memref to have at least 8 bits "
2215 "element size, got ")
2216 << srcElementSize;
2217
2218 auto resultType = cast<VectorType>(op.getResult().getType());
2219 Value srcPtr =
2220 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
2221 (adaptor.getSrcIndices()));
2222
2223 size_t numElements = resultType.getNumElements();
2224 size_t elementTypeSize =
2225 resultType.getElementType().getIntOrFloatBitWidth();
2226
2227 Type llvmResultType = typeConverter->convertType(resultType);
2228 // ROCDL transpose load intrinsics return vectors of 32-bit integers for
2229 // sub-16-bit element types, and otherwise return the converted result type.
2230 Type rocdlResultType =
2231 elementTypeSize < 16
2232 ? VectorType::get((numElements * elementTypeSize) / 32,
2233 rewriter.getIntegerType(32))
2234 : llvmResultType;
2235
2236 auto emitNumElementsError = [&](size_t expected, StringRef chipsetName) {
2237 return op.emitOpError()
2238 << elementTypeSize << "-bit transpose_load requires " << expected
2239 << " elements on " << chipsetName;
2240 };
2241
2242 Value intrinsic;
2243 if (chipset >= kGfx1250) {
2244 switch (elementTypeSize) {
2245 case 4: {
2246 if (numElements != 16)
2247 return emitNumElementsError(16, "gfx1250+");
2248 intrinsic =
2249 ROCDL::DsLoadTr4_B64::create(rewriter, loc, rocdlResultType, srcPtr)
2250 .getResult();
2251 break;
2252 }
2253 case 6: {
2254 if (numElements != 16)
2255 return emitNumElementsError(16, "gfx1250+");
2256 intrinsic =
2257 ROCDL::DsLoadTr6_B96::create(rewriter, loc, rocdlResultType, srcPtr)
2258 .getResult();
2259 break;
2260 }
2261 case 8: {
2262 if (numElements != 8)
2263 return emitNumElementsError(8, "gfx1250+");
2264 intrinsic =
2265 ROCDL::DsLoadTr8_B64::create(rewriter, loc, rocdlResultType, srcPtr)
2266 .getResult();
2267 break;
2268 }
2269 case 16: {
2270 if (numElements != 8)
2271 return emitNumElementsError(8, "gfx1250+");
2272 intrinsic = ROCDL::DsLoadTr16_B128::create(rewriter, loc,
2273 rocdlResultType, srcPtr)
2274 .getResult();
2275 break;
2276 }
2277 default:
2278 return op.emitOpError("Unsupported element size for transpose load");
2279 }
2280 } else {
2281 switch (elementTypeSize) {
2282 case 4: {
2283 if (numElements != 16)
2284 return emitNumElementsError(16, "gfx950");
2285 intrinsic = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
2286 rocdlResultType, srcPtr)
2287 .getResult();
2288 break;
2289 }
2290 case 6: {
2291 if (numElements != 16)
2292 return emitNumElementsError(16, "gfx950");
2293 intrinsic = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
2294 rocdlResultType, srcPtr)
2295 .getResult();
2296 break;
2297 }
2298 case 8: {
2299 if (numElements != 8)
2300 return emitNumElementsError(8, "gfx950");
2301 intrinsic = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
2302 rocdlResultType, srcPtr)
2303 .getResult();
2304 break;
2305 }
2306 case 16: {
2307 if (numElements != 4)
2308 return emitNumElementsError(4, "gfx950");
2309 intrinsic = ROCDL::ds_read_tr16_b64::create(rewriter, loc,
2310 rocdlResultType, srcPtr)
2311 .getResult();
2312 break;
2313 }
2314 default:
2315 return op.emitOpError("Unsupported element size for transpose load");
2316 }
2317 }
2318
2319 assert(intrinsic && "expected ROCDL transpose load intrinsic");
2320 if (intrinsic.getType() == llvmResultType) {
2321 rewriter.replaceOp(op, intrinsic);
2322 return success();
2323 }
2324 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, intrinsic);
2325 return success();
2326 }
2327};
2328
2329struct GlobalTransposeLoadOpLowering
2330 : public ConvertOpToLLVMPattern<GlobalTransposeLoadOp> {
2331 GlobalTransposeLoadOpLowering(const LLVMTypeConverter &converter,
2332 Chipset chipset)
2333 : ConvertOpToLLVMPattern<GlobalTransposeLoadOp>(converter),
2334 chipset(chipset) {}
2335
2336 Chipset chipset;
2337
2338 LogicalResult
2339 matchAndRewrite(GlobalTransposeLoadOp op,
2340 GlobalTransposeLoadOpAdaptor adaptor,
2341 ConversionPatternRewriter &rewriter) const override {
2342 if (chipset < kGfx1200)
2343 return op.emitOpError(
2344 "global_transpose_load is only supported on gfx1200+");
2345
2346 Location loc = op.getLoc();
2347 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2348 auto resultType = cast<VectorType>(op.getResult().getType());
2349
2350 Value srcPtr = getStridedElementPtr(
2351 rewriter, loc, srcMemRefType, adaptor.getSrc(), adaptor.getSrcIndices(),
2352 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
2353
2354 size_t numElements = resultType.getNumElements();
2355 size_t elementTypeSize =
2356 resultType.getElementType().getIntOrFloatBitWidth();
2357
2358 // ROCDL global transpose load intrinsics return vectors of i32 for
2359 // sub-16-bit elements, matching the LDS lowering convention.
2360 Type rocdlResultType =
2361 elementTypeSize < 16
2362 ? VectorType::get((numElements * elementTypeSize) / 32,
2363 rewriter.getIntegerType(32))
2364 : typeConverter->convertType(resultType);
2365 Type llvmResultType = typeConverter->convertType(resultType);
2366
2367 switch (elementTypeSize) {
2368 case 4: {
2369 assert(numElements == 16);
2370 if (chipset < kGfx1250)
2371 return op.emitOpError("4-bit global_transpose_load requires gfx1250+");
2372 auto rocdlOp = ROCDL::GlobalLoadTr4_B64::create(rewriter, loc,
2373 rocdlResultType, srcPtr);
2374 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2375 break;
2376 }
2377 case 6: {
2378 assert(numElements == 16);
2379 if (chipset < kGfx1250)
2380 return op.emitOpError("6-bit global_transpose_load requires gfx1250+");
2381 auto rocdlOp = ROCDL::GlobalLoadTr6_B96::create(rewriter, loc,
2382 rocdlResultType, srcPtr);
2383 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2384 break;
2385 }
2386 case 8: {
2387 assert(numElements == 8);
2388 auto rocdlOp = ROCDL::GlobalLoadTr8_B64::create(rewriter, loc,
2389 rocdlResultType, srcPtr);
2390 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2391 break;
2392 }
2393 case 16: {
2394 assert(numElements == 8);
2395 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadTr8_B128>(op, llvmResultType,
2396 srcPtr);
2397 break;
2398 }
2399 default:
2400 return op.emitOpError(
2401 "unsupported element size for global transpose load");
2402 }
2403 return success();
2404 }
2405};
2406
2407struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
2408 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2409 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
2410
2411 Chipset chipset;
2412
2413 LogicalResult
2414 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
2415 ConversionPatternRewriter &rewriter) const override {
2416 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
2417 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
2418
2419 Location loc = op.getLoc();
2420
2421 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2422 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2423
2424 // TODO: instead of only transfering one element per thread, we could
2425 // augment it to transfer multiple elements per thread by issuing multiple
2426 // `global_load_lds` instructions.
2427 Type transferType = op.getTransferType();
2428 int loadWidth = [&]() -> int {
2429 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
2430 return (transferVectorType.getNumElements() *
2431 transferVectorType.getElementTypeBitWidth()) /
2432 8;
2433 }
2434 return transferType.getIntOrFloatBitWidth() / 8;
2435 }();
2436
2437 // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
2438 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
2439 return op.emitOpError("chipset unsupported element size");
2440
2441 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
2442 return op.emitOpError("Gather to LDS instructions with 12-byte and "
2443 "16-byte load widths are only supported on gfx950");
2444
2445 Value srcPtr =
2446 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
2447 (adaptor.getSrcIndices()));
2448 Value dstPtr =
2449 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
2450 (adaptor.getDstIndices()));
2451
2452 if (op.getAsync()) {
2453 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
2454 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2455 /*offset=*/rewriter.getI32IntegerAttr(0),
2456 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
2457 ArrayAttr{});
2458 } else {
2459 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
2460 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2461 /*offset=*/rewriter.getI32IntegerAttr(0),
2462 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
2463 ArrayAttr{});
2464 }
2465
2466 return success();
2467 }
2468};
2469
2470struct GlobalLoadAsyncToLDSOpLowering
2471 : public ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp> {
2472 GlobalLoadAsyncToLDSOpLowering(const LLVMTypeConverter &converter,
2473 Chipset chipset)
2474 : ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp>(converter),
2475 chipset(chipset) {}
2476
2477 Chipset chipset;
2478
2479 LogicalResult
2480 matchAndRewrite(GlobalLoadAsyncToLDSOp op,
2481 GlobalLoadAsyncToLDSOpAdaptor adaptor,
2482 ConversionPatternRewriter &rewriter) const override {
2483 if (chipset < kGfx1250)
2484 return op.emitOpError(
2485 "global_load_async_to_lds is only supported on gfx1250+");
2486
2487 Location loc = op.getLoc();
2488 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2489 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2490
2491 Type transferType = op.getTransferType();
2492 int transferBits =
2493 isa<VectorType>(transferType)
2494 ? cast<VectorType>(transferType).getNumElements() *
2495 cast<VectorType>(transferType).getElementTypeBitWidth()
2496 : transferType.getIntOrFloatBitWidth();
2497
2498 Value srcPtr =
2499 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
2500 adaptor.getSrcIndices());
2501 Value dstPtr =
2502 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
2503 adaptor.getDstIndices());
2504
2505 if (op.getMask()) {
2506 Value mask = adaptor.getMask();
2507 int64_t nullptrVal =
2508 llvm::AMDGPU::getNullPointerValue(llvm::AMDGPUAS::LOCAL_ADDRESS);
2509 Value nullInt =
2510 createI32Constant(rewriter, loc, static_cast<int32_t>(nullptrVal));
2511 Value nullPtr =
2512 LLVM::IntToPtrOp::create(rewriter, loc, dstPtr.getType(), nullInt);
2513 dstPtr = LLVM::SelectOp::create(rewriter, loc, mask, dstPtr, nullPtr);
2514 }
2515
2516 auto offset = rewriter.getI32IntegerAttr(0);
2517 Attribute aux = rewriter.getI32IntegerAttr(0);
2518
2519 switch (transferBits) {
2520 case 8:
2521 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB8Op>(
2522 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2523 ArrayAttr{});
2524 break;
2525 case 32:
2526 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB32Op>(
2527 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2528 ArrayAttr{});
2529 break;
2530 case 64:
2531 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB64Op>(
2532 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2533 ArrayAttr{});
2534 break;
2535 case 128:
2536 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB128Op>(
2537 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2538 ArrayAttr{});
2539 break;
2540 default:
2541 return op.emitOpError("unsupported transfer width");
2542 }
2543 return success();
2544 }
2545};
2546
2547namespace {
2548struct ExtPackedFp8OpLowering final
2549 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
2550 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2551 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
2552 chipset(chipset) {}
2553 Chipset chipset;
2554
2555 LogicalResult
2556 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2557 ConversionPatternRewriter &rewriter) const override;
2558};
2559
2560struct ScaledExtPackedMatrixOpLowering final
2561 : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
2562 ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
2563 Chipset chipset)
2564 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
2565 chipset(chipset) {}
2566 Chipset chipset;
2567
2568 LogicalResult
2569 matchAndRewrite(ScaledExtPackedMatrixOp op,
2570 ScaledExtPackedMatrixOpAdaptor adaptor,
2571 ConversionPatternRewriter &rewriter) const override;
2572};
2573
2574struct PackedTrunc2xFp8OpLowering final
2575 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
2576 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
2577 Chipset chipset)
2578 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
2579 chipset(chipset) {}
2580 Chipset chipset;
2581
2582 LogicalResult
2583 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2584 ConversionPatternRewriter &rewriter) const override;
2585};
2586
2587struct PackedStochRoundFp8OpLowering final
2588 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
2589 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
2590 Chipset chipset)
2591 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
2592 chipset(chipset) {}
2593 Chipset chipset;
2594
2595 LogicalResult
2596 matchAndRewrite(PackedStochRoundFp8Op op,
2597 PackedStochRoundFp8OpAdaptor adaptor,
2598 ConversionPatternRewriter &rewriter) const override;
2599};
2600
2601struct ScaledExtPackedOpLowering final
2602 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
2603 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2604 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
2605 chipset(chipset) {}
2606 Chipset chipset;
2607
2608 LogicalResult
2609 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2610 ConversionPatternRewriter &rewriter) const override;
2611};
2612
2613struct PackedScaledTruncOpLowering final
2614 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
2615 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
2616 Chipset chipset)
2617 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
2618 chipset(chipset) {}
2619 Chipset chipset;
2620
2621 LogicalResult
2622 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2623 ConversionPatternRewriter &rewriter) const override;
2624};
2625
2626} // end namespace
2627
2628LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
2629 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2630 ConversionPatternRewriter &rewriter) const {
2631 Location loc = op.getLoc();
2632 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2633 return rewriter.notifyMatchFailure(
2634 loc, "Fp8 conversion instructions are not available on target "
2635 "architecture and their emulation is not implemented");
2636 Type v4i8 =
2637 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
2638 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2639 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
2640
2641 Value source = adaptor.getSource();
2642 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
2643 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
2644 Type sourceElemType = getElementTypeOrSelf(op.getSource());
2645 // Extend to a v4i8
2646 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
2647 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
2648 if (!sourceVecType) {
2649 longVec = LLVM::InsertElementOp::create(
2650 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2651 } else {
2652 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2653 Value idx = createI32Constant(rewriter, loc, i);
2654 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2655 longVec =
2656 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2657 }
2658 }
2659 source = longVec;
2660 }
2661 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2662 if (resultVecType) {
2663 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
2664 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
2665 op.getIndex());
2666 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
2667 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
2668 op.getIndex());
2669 }
2670 } else {
2671 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
2672 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
2673 op.getIndex());
2674 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
2675 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
2676 op.getIndex());
2677 }
2678 }
2679 return success();
2680}
2681
2682int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
2683 int32_t firstScaleByte) {
2684 // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
2685 // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
2686 // firstScaleByte are merged into a single attribute scaleSel. This is how
2687 // those values are merged together. (Note: scaleWaveHalf isn't a high-level
2688 // attribute but is derifed from firstScaleLane).
2689 assert(llvm::is_contained({16, 32}, blockSize));
2690 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
2691
2692 const bool isFp8 = bitWidth == 8;
2693 const bool isBlock16 = blockSize == 16;
2694
2695 if (!isFp8) {
2696 int32_t bit0 = isBlock16;
2697 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
2698 int32_t bit1 = (firstScaleByte == 2) << 1;
2699 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2700 int32_t bit2 = scaleWaveHalf << 2;
2701 return bit2 | bit1 | bit0;
2702 }
2703
2704 int32_t bit0 = isBlock16;
2705 // firstScaleByte is guaranteed to be defined by two bits.
2706 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
2707 int32_t bits2and1 = firstScaleByte << 1;
2708 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2709 int32_t bit3 = scaleWaveHalf << 3;
2710 int32_t bits = bit3 | bits2and1 | bit0;
2711 // These are invalid cases.
2712 assert(!llvm::is_contained(
2713 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
2714 return bits;
2715}
2716
2717static std::optional<StringRef>
2718scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2719 using fp4 = Float4E2M1FNType;
2720 using fp8 = Float8E4M3FNType;
2721 using bf8 = Float8E5M2Type;
2722 using fp6 = Float6E2M3FNType;
2723 using bf6 = Float6E3M2FNType;
2724 if (isa<fp4>(srcElemType)) {
2725 if (destElemType.isF16())
2726 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2727 if (destElemType.isBF16())
2728 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2729 if (destElemType.isF32())
2730 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2731 return std::nullopt;
2732 }
2733 if (isa<fp8>(srcElemType)) {
2734 if (destElemType.isF16())
2735 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2736 if (destElemType.isBF16())
2737 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2738 if (destElemType.isF32())
2739 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2740 return std::nullopt;
2741 }
2742 if (isa<bf8>(srcElemType)) {
2743 if (destElemType.isF16())
2744 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2745 if (destElemType.isBF16())
2746 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2747 if (destElemType.isF32())
2748 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2749 return std::nullopt;
2750 }
2751 if (isa<fp6>(srcElemType)) {
2752 if (destElemType.isF16())
2753 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2754 if (destElemType.isBF16())
2755 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2756 if (destElemType.isF32())
2757 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2758 return std::nullopt;
2759 }
2760 if (isa<bf6>(srcElemType)) {
2761 if (destElemType.isF16())
2762 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2763 if (destElemType.isBF16())
2764 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2765 if (destElemType.isF32())
2766 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2767 return std::nullopt;
2768 }
2769 llvm_unreachable("invalid combination of element types for packed conversion "
2770 "instructions");
2771}
2772
2773LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2774 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2775 ConversionPatternRewriter &rewriter) const {
2776 using fp4 = Float4E2M1FNType;
2777 using fp8 = Float8E4M3FNType;
2778 using bf8 = Float8E5M2Type;
2779 using fp6 = Float6E2M3FNType;
2780 using bf6 = Float6E3M2FNType;
2781 Location loc = op.getLoc();
2782 if (chipset != kGfx1250) {
2783 return rewriter.notifyMatchFailure(
2784 loc,
2785 "Scaled fp packed conversion instructions are not available on target "
2786 "architecture and their emulation is not implemented");
2787 }
2788 // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
2789 // is being selected.
2790 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2791 int32_t firstScaleByte = op.getFirstScaleByte();
2792 int32_t blockSize = op.getBlockSize();
2793 auto sourceType = cast<VectorType>(op.getSource().getType());
2794 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2795 unsigned bitWidth = srcElemType.getWidth();
2796
2797 auto targetType = cast<VectorType>(op.getResult().getType());
2798 auto destElemType = cast<FloatType>(targetType.getElementType());
2799
2800 IntegerType i32 = rewriter.getI32Type();
2801 Value source = adaptor.getSource();
2802 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2803 Type packedType = nullptr;
2804 if (isa<fp4>(srcElemType)) {
2805 packedType = i32;
2806 packedType = getTypeConverter()->convertType(packedType);
2807 } else if (isa<fp8, bf8>(srcElemType)) {
2808 packedType = VectorType::get(2, i32);
2809 packedType = getTypeConverter()->convertType(packedType);
2810 } else if (isa<fp6, bf6>(srcElemType)) {
2811 packedType = VectorType::get(3, i32);
2812 packedType = getTypeConverter()->convertType(packedType);
2813 } else {
2814 llvm_unreachable("invalid element type for packed scaled ext");
2815 }
2816
2817 if (!packedType || !llvmResultType) {
2818 return rewriter.notifyMatchFailure(op, "type conversion failed");
2819 }
2820
2821 std::optional<StringRef> maybeIntrinsic =
2822 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2823 if (!maybeIntrinsic.has_value())
2824 return op.emitOpError(
2825 "no intrinsic matching packed scaled conversion on the given chipset");
2826
2827 int32_t scaleSel =
2828 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2829 Value castedScale =
2830 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2831 Value castedSource =
2832 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2833
2834 OperationState loweredOp(loc, *maybeIntrinsic);
2835 loweredOp.addTypes({llvmResultType});
2836 loweredOp.addOperands({castedSource, castedScale});
2837
2838 SmallVector<NamedAttribute, 1> attrs;
2839 attrs.push_back(
2840 NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2841
2842 loweredOp.addAttributes(attrs);
2843 Operation *lowered = rewriter.create(loweredOp);
2844 rewriter.replaceOp(op, lowered);
2845
2846 return success();
2847}
2848
2849LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2850 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2851 ConversionPatternRewriter &rewriter) const {
2852 Location loc = op.getLoc();
2853 if (chipset != kGfx950)
2854 return rewriter.notifyMatchFailure(
2855 loc, "Scaled fp conversion instructions are not available on target "
2856 "architecture and their emulation is not implemented");
2857 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2858
2859 Value source = adaptor.getSource();
2860 Value scale = adaptor.getScale();
2861
2862 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2863 Type sourceElemType = sourceVecType.getElementType();
2864 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2865 Type destElemType = destVecType.getElementType();
2866
2867 VectorType packedVecType;
2868 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2869 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2870 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2871 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
2872 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2873 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2874 } else {
2875 llvm_unreachable("invalid element type for scaled ext");
2876 }
2877
2878 // Extend to a packedVectorType
2879 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2880 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2881 if (!sourceVecType) {
2882 longVec = LLVM::InsertElementOp::create(
2883 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2884 } else {
2885 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2886 Value idx = createI32Constant(rewriter, loc, i);
2887 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2888 longVec =
2889 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2890 }
2891 }
2892 source = longVec;
2893 }
2894 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2895
2896 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
2897 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2898 op, destVecType, i32Source, scale, op.getIndex());
2899 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
2900 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2901 op, destVecType, i32Source, scale, op.getIndex());
2902 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
2903 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2904 op, destVecType, i32Source, scale, op.getIndex());
2905 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
2906 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2907 op, destVecType, i32Source, scale, op.getIndex());
2908 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
2909 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2910 op, destVecType, i32Source, scale, op.getIndex());
2911 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
2912 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2913 op, destVecType, i32Source, scale, op.getIndex());
2914 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
2915 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2916 op, destVecType, i32Source, scale, op.getIndex());
2917 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
2918 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2919 op, destVecType, i32Source, scale, op.getIndex());
2920 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
2921 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2922 op, destVecType, i32Source, scale, op.getIndex());
2923 else
2924 return failure();
2925
2926 return success();
2927}
2928
2929LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2930 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2931 ConversionPatternRewriter &rewriter) const {
2932 Location loc = op.getLoc();
2933 if (chipset != kGfx950)
2934 return rewriter.notifyMatchFailure(
2935 loc, "Scaled fp conversion instructions are not available on target "
2936 "architecture and their emulation is not implemented");
2937 Type v2i16 = getTypeConverter()->convertType(
2938 VectorType::get(2, rewriter.getI16Type()));
2939 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2940
2941 Type resultType = op.getResult().getType();
2942 Type resultElemType = getElementTypeOrSelf(resultType);
2943 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2944 Type sourceElemType = sourceVecType.getElementType();
2945
2946 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2947
2948 Value source = adaptor.getSource();
2949 Value scale = adaptor.getScale();
2950 Value existing = adaptor.getExisting();
2951 if (existing)
2952 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2953 else
2954 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
2955
2956 if (sourceVecType.getNumElements() < 2) {
2957 Value c0 = createI32Constant(rewriter, loc, 0);
2958 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2959 VectorType v2 = VectorType::get(2, sourceElemType);
2960 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2961 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
2962 }
2963
2964 Value sourceA, sourceB;
2965 if (sourceElemType.isF32()) {
2966 Value c0 = createI32Constant(rewriter, loc, 0);
2967 Value c1 = createI32Constant(rewriter, loc, 1);
2968 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2969 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
2970 }
2971
2972 Value result;
2973 if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
2974 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2975 existing, sourceA, sourceB,
2976 scale, op.getIndex());
2977 else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
2978 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2979 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2980 else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
2981 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2982 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2983 else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
2984 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2985 existing, sourceA, sourceB,
2986 scale, op.getIndex());
2987 else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
2988 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2989 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2990 else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
2991 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2992 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2993 else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
2994 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2995 existing, sourceA, sourceB,
2996 scale, op.getIndex());
2997 else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
2998 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2999 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
3000 else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
3001 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
3002 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
3003 else
3004 return failure();
3005
3006 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
3007 op, getTypeConverter()->convertType(resultType), result);
3008 return success();
3009}
3010
3011LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
3012 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
3013 ConversionPatternRewriter &rewriter) const {
3014 Location loc = op.getLoc();
3015 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
3016 return rewriter.notifyMatchFailure(
3017 loc, "Fp8 conversion instructions are not available on target "
3018 "architecture and their emulation is not implemented");
3019 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
3020
3021 Type resultType = op.getResult().getType();
3022 Type resultElemType = getElementTypeOrSelf(resultType);
3023
3024 Value sourceA = adaptor.getSourceA();
3025 Value sourceB = adaptor.getSourceB();
3026 if (!sourceB)
3027 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
3028 Value existing = adaptor.getExisting();
3029 if (existing)
3030 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
3031 else
3032 existing = LLVM::UndefOp::create(rewriter, loc, i32);
3033
3034 Value result;
3035 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
3036 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
3037 existing, op.getWordIndex());
3038 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
3039 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
3040 existing, op.getWordIndex());
3041
3042 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
3043 op, getTypeConverter()->convertType(resultType), result);
3044 return success();
3045}
3046
3047LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
3048 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
3049 ConversionPatternRewriter &rewriter) const {
3050 Location loc = op.getLoc();
3051 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
3052 return rewriter.notifyMatchFailure(
3053 loc, "Fp8 conversion instructions are not available on target "
3054 "architecture and their emulation is not implemented");
3055 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
3056
3057 Type resultType = op.getResult().getType();
3058 Type resultElemType = getElementTypeOrSelf(resultType);
3059
3060 Value source = adaptor.getSource();
3061 Value stoch = adaptor.getStochiasticParam();
3062 Value existing = adaptor.getExisting();
3063 if (existing)
3064 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
3065 else
3066 existing = LLVM::UndefOp::create(rewriter, loc, i32);
3067
3068 Value result;
3069 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
3070 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
3071 existing, op.getStoreIndex());
3072 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
3073 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
3074 existing, op.getStoreIndex());
3075
3076 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
3077 op, getTypeConverter()->convertType(resultType), result);
3078 return success();
3079}
3080
3081// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
3082// operation into the corresponding ROCDL instructions.
3083struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
3084 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
3085 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
3086 Chipset chipset;
3087
3088 LogicalResult
3089 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
3090 ConversionPatternRewriter &rewriter) const override {
3091
3092 // Convert the source operand to the corresponding LLVM type
3093 Location loc = DppOp.getLoc();
3094 Value src = adaptor.getSrc();
3095 Value old = adaptor.getOld();
3096 Type srcType = src.getType();
3097 Type oldType = old.getType();
3098 Type llvmType = nullptr;
3099 if (srcType.getIntOrFloatBitWidth() < 32) {
3100 llvmType = rewriter.getI32Type();
3101 } else if (isa<FloatType>(srcType)) {
3102 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
3103 ? rewriter.getF32Type()
3104 : rewriter.getF64Type();
3105 } else if (isa<IntegerType>(srcType)) {
3106 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
3107 ? rewriter.getI32Type()
3108 : rewriter.getI64Type();
3109 }
3110 auto llvmSrcIntType = typeConverter->convertType(
3111 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
3112
3113 // If the source type is less of 32, use bitcast to convert it to i32.
3114 auto convertOperand = [&](Value operand, Type operandType) {
3115 if (operandType.getIntOrFloatBitWidth() <= 16) {
3116 if (llvm::isa<FloatType>(operandType)) {
3117 operand =
3118 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
3119 }
3120 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
3121 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
3122 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
3123 operand =
3124 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
3125 createI32Constant(rewriter, loc, 0));
3126 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
3127 }
3128 return operand;
3129 };
3130
3131 src = convertOperand(src, srcType);
3132 old = convertOperand(old, oldType);
3133
3134 // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
3135 enum DppCtrl : unsigned {
3136 ROW_SHL0 = 0x100,
3137 ROW_SHR0 = 0x110,
3138 ROW_ROR0 = 0x120,
3139 WAVE_SHL1 = 0x130,
3140 WAVE_ROL1 = 0x134,
3141 WAVE_SHR1 = 0x138,
3142 WAVE_ROR1 = 0x13C,
3143 ROW_MIRROR = 0x140,
3144 ROW_HALF_MIRROR = 0x141,
3145 BCAST15 = 0x142,
3146 BCAST31 = 0x143,
3147 };
3148
3149 auto kind = DppOp.getKind();
3150 auto permArgument = DppOp.getPermArgument();
3151 uint32_t DppCtrl = 0;
3152
3153 switch (kind) {
3154
3155 case DPPPerm::quad_perm: {
3156 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
3157 int32_t i = 0;
3158 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
3159 uint32_t num = elem.getInt();
3160 DppCtrl |= num << (i * 2);
3161 i++;
3162 }
3163 break;
3164 }
3165 case DPPPerm::row_shl: {
3166 auto intAttr = cast<IntegerAttr>(*permArgument);
3167 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
3168 break;
3169 }
3170 case DPPPerm::row_shr: {
3171 auto intAttr = cast<IntegerAttr>(*permArgument);
3172 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
3173 break;
3174 }
3175 case DPPPerm::row_ror: {
3176 auto intAttr = cast<IntegerAttr>(*permArgument);
3177 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
3178 break;
3179 }
3180 case DPPPerm::wave_shl:
3181 DppCtrl = DppCtrl::WAVE_SHL1;
3182 break;
3183 case DPPPerm::wave_shr:
3184 DppCtrl = DppCtrl::WAVE_SHR1;
3185 break;
3186 case DPPPerm::wave_rol:
3187 DppCtrl = DppCtrl::WAVE_ROL1;
3188 break;
3189 case DPPPerm::wave_ror:
3190 DppCtrl = DppCtrl::WAVE_ROR1;
3191 break;
3192 case DPPPerm::row_mirror:
3193 DppCtrl = DppCtrl::ROW_MIRROR;
3194 break;
3195 case DPPPerm::row_half_mirror:
3196 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
3197 break;
3198 case DPPPerm::row_bcast_15:
3199 DppCtrl = DppCtrl::BCAST15;
3200 break;
3201 case DPPPerm::row_bcast_31:
3202 DppCtrl = DppCtrl::BCAST31;
3203 break;
3204 }
3205
3206 // Check for row_mask, bank_mask, bound_ctrl if they exist and create
3207 // constants
3208 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
3209 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
3210 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
3211
3212 // create a ROCDL_DPPMovOp instruction with the appropriate attributes
3213 auto dppMovOp =
3214 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
3215 rowMask, bankMask, boundCtrl);
3216
3217 Value result = dppMovOp.getRes();
3218 if (srcType.getIntOrFloatBitWidth() < 32) {
3219 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
3220 if (!llvm::isa<IntegerType>(srcType)) {
3221 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
3222 }
3223 }
3224
3225 // We are replacing the AMDGPU_DPPOp instruction with the new
3226 // ROCDL_DPPMovOp instruction
3227 rewriter.replaceOp(DppOp, ValueRange(result));
3228 return success();
3229 }
3230};
3231
3232struct AMDGPUSwizzleBitModeLowering
3233 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
3235
3236 LogicalResult
3237 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
3238 ConversionPatternRewriter &rewriter) const override {
3239 Location loc = op.getLoc();
3240 Type i32 = rewriter.getI32Type();
3241 Value src = adaptor.getSrc();
3242 SmallVector<Value> decomposed;
3243 if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3244 return rewriter.notifyMatchFailure(op,
3245 "failed to decompose value to i32");
3246 unsigned andMask = op.getAndMask();
3247 unsigned orMask = op.getOrMask();
3248 unsigned xorMask = op.getXorMask();
3249
3250 // bit 15 is 0 for the BitMode swizzle.
3251 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
3252 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
3253 Value maskValue = createI32Constant(rewriter, loc, mask);
3254 SmallVector<Value> swizzled;
3255 for (Value v : decomposed) {
3256 Value res =
3257 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
3258 swizzled.emplace_back(res);
3259 }
3260
3261 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
3262 rewriter.replaceOp(op, result);
3263 return success();
3264 }
3265};
3266
3267struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
3269
3270 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
3271 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
3272 Chipset chipset;
3273
3274 LogicalResult
3275 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
3276 ConversionPatternRewriter &rewriter) const override {
3277 if (chipset < kGfx950)
3278 return op->emitOpError("permlane_swap is only supported on gfx950+");
3279
3280 Location loc = op.getLoc();
3281 Type i32 = rewriter.getI32Type();
3282 Value src = adaptor.getSrc();
3283 unsigned rowLength = op.getRowLength();
3284 bool fi = op.getFetchInactive();
3285 bool boundctrl = op.getBoundCtrl();
3286
3287 SmallVector<Value> decomposed;
3288 if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3289 return rewriter.notifyMatchFailure(op,
3290 "failed to decompose value to i32");
3291
3292 SmallVector<Value> permuted;
3293 for (Value v : decomposed) {
3294 Value res;
3295 Type i32pair = LLVM::LLVMStructType::getLiteral(
3296 rewriter.getContext(), {v.getType(), v.getType()});
3297
3298 if (rowLength == 16)
3299 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
3300 boundctrl);
3301 else if (rowLength == 32)
3302 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
3303 boundctrl);
3304 else
3305 llvm_unreachable("unsupported row length");
3306
3307 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
3308 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
3309
3310 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
3311 LLVM::ICmpPredicate::eq, vdst0, v);
3312
3313 // Per `permlane(16|32)` semantics: if the first extracted element equals
3314 // 'v', the result is the second element; otherwise it is the first.
3315 Value vdstNew =
3316 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
3317 permuted.emplace_back(vdstNew);
3318 }
3319
3320 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
3321 rewriter.replaceOp(op, result);
3322 return success();
3323 }
3324};
3325
3326struct AMDGPUPermlaneVarLowering
3327 : public ConvertOpToLLVMPattern<PermlaneVarOp> {
3329
3330 AMDGPUPermlaneVarLowering(const LLVMTypeConverter &converter, Chipset chipset)
3331 : ConvertOpToLLVMPattern<PermlaneVarOp>(converter), chipset(chipset) {}
3332 Chipset chipset;
3333
3334 LogicalResult
3335 matchAndRewrite(PermlaneVarOp op, OpAdaptor adaptor,
3336 ConversionPatternRewriter &rewriter) const override {
3337 if (chipset < kGfx1200)
3338 return op->emitOpError("permlane_var is only supported on GFX12+");
3339
3340 Location loc = op.getLoc();
3341 Type i32 = rewriter.getI32Type();
3342 Value src = adaptor.getSrc();
3343 Value selector = adaptor.getSelector();
3344 bool cross = op.getCross();
3345 bool fi = op.getFetchInactive();
3346 bool boundCtrl = op.getBoundCtrl();
3347
3348 SmallVector<Value> decomposed;
3349 if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3350 return rewriter.notifyMatchFailure(op,
3351 "failed to decompose value to i32");
3352
3353 SmallVector<Value> permuted;
3354 for (Value v : decomposed) {
3355 Value res;
3356 if (cross)
3357 res = ROCDL::PermlaneX16VarOp::create(rewriter, loc, i32, v, v,
3358 selector, fi, boundCtrl);
3359 else
3360 res = ROCDL::Permlane16VarOp::create(rewriter, loc, i32, v, v, selector,
3361 fi, boundCtrl);
3362 permuted.emplace_back(res);
3363 }
3364
3365 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
3366 rewriter.replaceOp(op, result);
3367 return success();
3368 }
3369};
3370
3371//===----------------------------------------------------------------------===//
3372// In-LDS Barrier Operations
3373//===----------------------------------------------------------------------===//
3374
// Layout of a ds_barrier_state value, viewed as an i64:
//   bits [63:32]  init count     (32 bits)
//   bits [31:29]  phase          (3 bits)
//   bits [28:0]   pending count  (29 bits)

// Width of the pending-count field.
constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
// The phase field starts directly above the pending count.
constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
// The init count occupies the upper 32 bits.
constexpr int32_t kDsBarrierInitCountPos = 32;
// All-ones mask covering exactly the pending-count field.
constexpr int32_t kDsBarrierPendingCountMask =
    (int32_t{1} << kDsBarrierPendingCountBitWidth) - 1;
3384
3385struct DsBarrierInitOpLowering
3386 : public ConvertOpToLLVMPattern<DsBarrierInitOp> {
3387 Chipset chipset;
3388
3389 DsBarrierInitOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
3390 : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
3391
3392 LogicalResult
3393 matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
3394 ConversionPatternRewriter &rewriter) const override {
3395 if (chipset < kGfx1250)
3396 return op->emitOpError("only supported on gfx1250+");
3397
3398 Location loc = op.getLoc();
3399 Type i64 = rewriter.getI64Type();
3400
3401 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3402 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3403 adaptor.getBase(), adaptor.getIndices());
3404
3405 // Note: We give participants as the number of arrivals that have to occur
3406 // before the phase changes. Hardware changes the phase when updating the
3407 // pending count would underflow, so we subtract 1 to get the behavior we're
3408 // looking for.
3409 Value initCount =
3410 LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
3411 createI32Constant(rewriter, loc, 1));
3412
3413 // Just a bit of paranoia, but this also allows for configurable width if
3414 // that becomes a thing.
3415 Value countMask =
3416 createI32Constant(rewriter, loc, kDsBarrierPendingCountMask);
3417 Value maskedCount32 =
3418 LLVM::AndOp::create(rewriter, loc, initCount, countMask);
3419 Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);
3420
3421 Value initCountShifted = LLVM::ShlOp::create(
3422 rewriter, loc, maskedCount,
3423 createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
3424 Value barrierState =
3425 LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);
3426
3427 LLVM::StoreOp::create(
3428 rewriter, loc, barrierState, ptr, /*alignment=*/8, /*isVolatile=*/false,
3429 /*isNonTemporal=*/false,
3430 /*isInvariantGroup=*/false, LLVM::AtomicOrdering::release,
3431 /*syncscope=*/"workgroup");
3432
3433 rewriter.eraseOp(op);
3434 return success();
3435 }
3436};
3437
3438struct DsBarrierPollStateOpLowering
3439 : public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
3440 Chipset chipset;
3441
3442 DsBarrierPollStateOpLowering(const LLVMTypeConverter &converter,
3443 Chipset chipset)
3444 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
3445 chipset(chipset) {}
3446
3447 LogicalResult
3448 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
3449 ConversionPatternRewriter &rewriter) const override {
3450 if (chipset < kGfx1250)
3451 return op->emitOpError("only supported on gfx1250+");
3452
3453 Location loc = op.getLoc();
3454 Type i64 = rewriter.getI64Type();
3455
3456 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3457 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3458 adaptor.getBase(), adaptor.getIndices());
3459
3460 // Atomic load with workgroup scope and acquire ordering should be what
3461 // we're looking for.
3462 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
3463 op, i64, ptr, /*alignment=*/8, /*volatile_=*/false,
3464 /*nontemporal=*/false, /*invariant=*/false,
3465 /*invariantGroup=*/false, LLVM::AtomicOrdering::acquire,
3466 /*syncscope=*/"workgroup");
3467 return success();
3468 }
3469};
3470
3471struct DsAsyncBarrierArriveOpLowering
3472 : public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
3473 Chipset chipset;
3474
3475 DsAsyncBarrierArriveOpLowering(const LLVMTypeConverter &converter,
3476 Chipset chipset)
3477 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
3478 chipset(chipset) {}
3479
3480 LogicalResult
3481 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
3482 ConversionPatternRewriter &rewriter) const override {
3483 if (chipset < kGfx1250)
3484 return op->emitOpError("only supported on gfx1250+");
3485
3486 Location loc = op.getLoc();
3487
3488 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3489 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3490 adaptor.getBase(), adaptor.getIndices());
3491
3492 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
3493 op, ptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
3494 /*tbaa=*/nullptr);
3495 return success();
3496 }
3497};
3498
3499struct DsBarrierArriveOpLowering
3500 : public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
3501 Chipset chipset;
3502
3503 DsBarrierArriveOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
3504 : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
3505 }
3506
3507 LogicalResult
3508 matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
3509 ConversionPatternRewriter &rewriter) const override {
3510 if (chipset < kGfx1250)
3511 return op->emitOpError("only supported on gfx1250+");
3512
3513 Location loc = op.getLoc();
3514 Type i64 = rewriter.getI64Type();
3515
3516 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3517 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3518 adaptor.getBase(), adaptor.getIndices());
3519
3520 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
3521 op, i64, ptr, adaptor.getCount(), /*alias_scopes=*/nullptr,
3522 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3523 return success();
3524 }
3525};
3526
3527struct DsBarrierStatePhaseOpLowering
3528 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
3530
3531 LogicalResult
3532 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
3533 ConversionPatternRewriter &rewriter) const override {
3534 Location loc = op.getLoc();
3535 Type i32 = rewriter.getI32Type();
3536
3537 Value state = adaptor.getState();
3538
3539 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3540 Value phase = LLVM::LShrOp::create(
3541 rewriter, loc, noInitCount,
3542 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3543
3544 rewriter.replaceOp(op, phase);
3545 return success();
3546 }
3547};
3548
3549struct DsBarrierStatePendingCountOpLowering
3550 : public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
3552
3553 LogicalResult
3554 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
3555 ConversionPatternRewriter &rewriter) const override {
3556 Location loc = op.getLoc();
3557 Type i32 = rewriter.getI32Type();
3558
3559 Value state = adaptor.getState();
3560
3561 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3562 Value pendingCount = LLVM::AndOp::create(
3563 rewriter, loc, noInitCount,
3564 createI32Constant(rewriter, loc,
3565 static_cast<uint32_t>(kDsBarrierPendingCountMask)));
3566
3567 rewriter.replaceOp(op, pendingCount);
3568 return success();
3569 }
3570};
3571
3572struct DsBarrierStateInitCountOpLowering
3573 : public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
3575
3576 LogicalResult
3577 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
3578 ConversionPatternRewriter &rewriter) const override {
3579 Location loc = op.getLoc();
3580 Type i32 = rewriter.getI32Type();
3581
3582 Value state = adaptor.getState();
3583
3584 Value initCountI64 = LLVM::LShrOp::create(
3585 rewriter, loc, state,
3586 createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
3587 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
3588
3589 rewriter.replaceOp(op, initCount);
3590 return success();
3591 }
3592};
3593
3594struct DsBarrierStatePhaseParityLowering
3595 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
3597
3598 LogicalResult
3599 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
3600 ConversionPatternRewriter &rewriter) const override {
3601 Location loc = op.getLoc();
3602 Type i1 = rewriter.getI1Type();
3603
3604 Value state = adaptor.getState();
3605
3606 Value noInitCount =
3607 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
3608 Value phase = LLVM::LShrOp::create(
3609 rewriter, loc, noInitCount,
3610 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3611 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
3612
3613 rewriter.replaceOp(op, parity);
3614 return success();
3615 }
3616};
3617
3618//===----------------------------------------------------------------------===//
3619// Tensor Data Mover (TDM)
3620//===----------------------------------------------------------------------===//
3621
3622static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
3623 Value accumulator, Value value, int64_t shift) {
3624 shift = shift % 32;
3625 Value shiftAmount;
3626 if (shift != 0) {
3627 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
3628 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
3629 }
3630
3631 if (matchPattern(accumulator, mlir::m_Zero()))
3632 return value;
3633
3634 constexpr bool isDisjoint = true;
3635 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
3636}
3637
3638template <typename BaseOp>
3639struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
3640 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3641 using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
3642
3643 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
3644 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3645 Chipset chipset;
3646
3647 LogicalResult
3648 matchAndRewrite(BaseOp op, Adaptor adaptor,
3649 ConversionPatternRewriter &rewriter) const override {
3650 if (chipset < kGfx1250)
3651 return op->emitOpError("make_dma_base is only supported on gfx1250");
3652
3653 Location loc = op.getLoc();
3654
3655 constexpr int32_t constlen = 4;
3656 Value consts[constlen];
3657 for (int64_t i = 0; i < constlen; ++i)
3658 consts[i] = createI32Constant(rewriter, loc, i);
3659
3660 constexpr int32_t sgprslen = constlen;
3661 Value sgprs[sgprslen];
3662 for (int64_t i = 0; i < sgprslen; ++i) {
3663 sgprs[i] = consts[0];
3664 }
3665
3666 sgprs[0] = consts[1];
3667
3668 if constexpr (BaseOp::isGather()) {
3669 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3670
3671 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3672 Type indexType = type.getIndexType();
3673 unsigned indexSize = indexType.getIntOrFloatBitWidth();
3674 assert(llvm::is_contained({16u, 32u}, indexSize) &&
3675 "expected index_size to be 16 or 32");
3676 unsigned idx = (indexSize / 16) - 1;
3677
3678 if (idx)
3679 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
3680 }
3681
3682 ValueRange ldsIndices = adaptor.getLdsIndices();
3683 Value lds = adaptor.getLds();
3684 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3685
3687 rewriter, loc, ldsMemRefType, lds, ldsIndices);
3688
3689 ValueRange globalIndices = adaptor.getGlobalIndices();
3690 Value global = adaptor.getGlobal();
3691 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3692
3694 rewriter, loc, globalMemRefType, global, globalIndices);
3695
3696 Type i32 = rewriter.getI32Type();
3697 Type i64 = rewriter.getI64Type();
3698
3699 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3700 Value castForGlobalAddr =
3701 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
3702
3703 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
3704
3705 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3706 createI64Constant(rewriter, loc, 32));
3707
3708 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
3709
3710 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
3711 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3712
3713 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
3714
3715 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3716 assert(v4i32 && "expected type conversion to succeed");
3717 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3718
3719 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3720 result =
3721 LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
3722
3723 rewriter.replaceOp(op, result);
3724 return success();
3725 }
3726};
3727
3728template <typename DescriptorOp>
3729struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
3730 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
3731 using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;
3732
3733 AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
3734 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
3735 Chipset chipset;
3736
3737 Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
3738
3739 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
3740 ConversionPatternRewriter &rewriter, Location loc,
3741 Value sgpr0) const {
3742 Value mask = op.getWorkgroupMask();
3743 if (!mask)
3744 return sgpr0;
3745
3746 Type i16 = rewriter.getI16Type();
3747 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
3748 Type i32 = rewriter.getI32Type();
3749 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
3750 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
3751 }
3752
3753 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
3754 ConversionPatternRewriter &rewriter, Location loc,
3755 Value sgpr0, ArrayRef<Value> consts) const {
3756 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
3757 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3758 "expected type width to be 8, 16, 32, or 64.");
3759 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3760 Value size = consts[idx];
3761 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
3762 }
3763
3764 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3765 ConversionPatternRewriter &rewriter, Location loc,
3766 Value sgpr0, ArrayRef<Value> consts) const {
3767 if (!adaptor.getAtomicBarrierAddress())
3768 return sgpr0;
3769
3770 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
3771 }
3772
3773 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3774 ConversionPatternRewriter &rewriter, Location loc,
3775 Value sgpr0, ArrayRef<Value> consts) const {
3776 if (!adaptor.getGlobalIncrement())
3777 return sgpr0;
3778
3779 // Value is ignored when in gather mode.
3780 // TODO: emit error earlier?
3781 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
3782 }
3783
3784 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3785 ConversionPatternRewriter &rewriter, Location loc,
3786 Value sgpr0, ArrayRef<Value> consts) const {
3787 if (!op.getPadAmount())
3788 return sgpr0;
3789
3790 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
3791 }
3792
3793 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3794 ConversionPatternRewriter &rewriter, Location loc,
3795 Value sgpr0, ArrayRef<Value> consts) const {
3796 if (!op.getWorkgroupMask())
3797 return sgpr0;
3798
3799 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
3800 }
3801
3802 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3803 ConversionPatternRewriter &rewriter, Location loc,
3804 Value sgpr0, ArrayRef<Value> consts) const {
3805 if (!op.getPadAmount())
3806 return sgpr0;
3807
3808 // pre-condition: padInterval can be a power of two between 2 and 256.
3809 // TODO: Validation if the value breaks the pre-condition.
3810 // If the pre-condition fails, there is a possibility of
3811 // affecting the higher bits. In a following PR implement
3812 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3813 // checked at runtime.
3814 IntegerType i32 = rewriter.getI32Type();
3815 Value padInterval = adaptor.getPadInterval();
3816 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3817 padInterval, false);
3818 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3819 // post-condition: padInterval can be a value between 0 and 7.
3820 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
3821 }
3822
3823 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3824 ConversionPatternRewriter &rewriter, Location loc,
3825 Value sgpr0, ArrayRef<Value> consts) const {
3826 if (!op.getPadAmount())
3827 return sgpr0;
3828
3829 // pre-condition: padAmount is a value between 1-128.
3830 // TODO: Validation if the value breaks the pre-condition.
3831 // If the pre-condition fails, there is a possibility of
3832 // affecting the higher bits. In a following PR implement
3833 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3834 // checked at runtime.
3835 Value padAmount = adaptor.getPadAmount();
3836 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3837 // post-condition: padAmount is a value between 0-127.
3838 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
3839 }
3840
3841 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
3842 ConversionPatternRewriter &rewriter,
3843 Location loc, Value sgpr1,
3844 ArrayRef<Value> consts) const {
3845 if (!adaptor.getAtomicBarrierAddress())
3846 return sgpr1;
3847
3848 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
3849 auto barrierAddressTy =
3850 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
3851 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
3852 atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
3853 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
3854 atomicBarrierIndices);
3855 IntegerType i32 = rewriter.getI32Type();
3856 // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
3857 // that the 3 LSBs are zero.
3858 // TODO: Validation if the value breaks the pre-condition.
3859 // In a following PR implement RuntimeVerifiableOpInterface
3860 // that instruments conditions that need to be checked at runtime.
3861 atomicBarrierAddress =
3862 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
3863 atomicBarrierAddress =
3864 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
3865 Value mask = createI32Constant(rewriter, loc, 0xFFFF);
3866 atomicBarrierAddress =
3867 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
3868 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
3869 }
3870
3871 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3872 ConversionPatternRewriter &rewriter,
3873 Location loc, Value sgpr1, Value sgpr2,
3874 ArrayRef<Value> consts, uint64_t dimX,
3875 uint32_t offset) const {
3876 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3877 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3878 SmallVector<OpFoldResult> mixedGlobalSizes =
3879 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
3880 if (mixedGlobalSizes.size() <= dimX)
3881 return {sgpr1, sgpr2};
3882
3883 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3884 // pre-condition: tensorDimX is less than 2^32-1
3885 // TODO: Validation if the value breaks the pre-condition.
3886 // In a following PR implement RuntimeVerifiableOpInterface that instruments
3887 // conditions that need to be checked at runtime. This could also be fixed
3888 // by saying that mixedGlobalSizes is a DynamicI32List.
3889 Value tensorDimX;
3890 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3891 tensorDimX =
3892 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3893 } else {
3894 IntegerType i32 = rewriter.getI32Type();
3895 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3896 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3897 }
3898
3899 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
3900
3901 Value c16 = createI32Constant(rewriter, loc, 16);
3902 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
3903 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
3904 return {sgpr1, sgpr2};
3905 }
3906
3907 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3908 ConversionPatternRewriter &rewriter,
3909 Location loc, Value sgpr1, Value sgpr2,
3910 ArrayRef<Value> consts) const {
3911 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
3912 48);
3913 }
3914
3915 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3916 ConversionPatternRewriter &rewriter,
3917 Location loc, Value sgpr2, Value sgpr3,
3918 ArrayRef<Value> consts) const {
3919 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
3920 80);
3921 }
3922
3923 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
3924 ConversionPatternRewriter &rewriter, Location loc,
3925 Value sgpr, ArrayRef<Value> consts, size_t dimX,
3926 int64_t offset) const {
3927 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
3928 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
3929 SmallVector<OpFoldResult> mixedSharedSizes =
3930 getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
3931 if (mixedSharedSizes.size() <= dimX)
3932 return sgpr;
3933
3934 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
3935 // pre-condition: tileDimX is less than 2^16-1
3936 // TODO: Validation if the value breaks the pre-condition.
3937 // If the pre-condition fails, there is a possibility of
3938 // affecting the higher bits. In a following PR implement
3939 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3940 // checked at runtime. This could also be fixed by saying that
3941 // mixedSharedSizes is a DynamicI16List.
3942 Value tileDimX;
3943 if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
3944 tileDimX =
3945 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3946 } else {
3947 IntegerType i32 = rewriter.getI32Type();
3948 tileDimX = cast<Value>(tileDimXOpFoldResult);
3949 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
3950 }
3951
3952 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
3953 }
3954
3955 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3956 ConversionPatternRewriter &rewriter, Location loc,
3957 Value sgpr3, ArrayRef<Value> consts) const {
3958 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
3959 }
3960
3961 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3962 ConversionPatternRewriter &rewriter, Location loc,
3963 Value sgpr4, ArrayRef<Value> consts) const {
3964 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
3965 }
3966
3967 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3968 ConversionPatternRewriter &rewriter, Location loc,
3969 Value sgpr4, ArrayRef<Value> consts) const {
3970 auto type = cast<VectorType>(op.getIndices().getType());
3971 ArrayRef<int64_t> shape = type.getShape();
3972 assert(shape.size() == 1 && "expected shape to be of rank 1.");
3973 unsigned length = shape.back();
3974 assert(0 < length && length <= 16 && "expected length to be at most 16.");
3975 Value value = createI32Constant(rewriter, loc, length);
3976 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
3977 }
3978
3979 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3980 ConversionPatternRewriter &rewriter,
3981 Location loc, Value sgpr4,
3982 ArrayRef<Value> consts) const {
3983 if constexpr (DescriptorOp::isGather())
3984 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3985 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
3986 }
3987
3988 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3989 ConversionPatternRewriter &rewriter, Location loc,
3990 Value sgpr4, ArrayRef<Value> consts) const {
3991 // Value is ignored when in gather mode.
3992 if constexpr (DescriptorOp::isGather())
3993 return sgpr4;
3994 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
3995 }
3996
3997 std::pair<Value, Value>
3998 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
3999 ConversionPatternRewriter &rewriter, Location loc,
4000 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
4001 size_t dimX, int64_t offset) const {
4002 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
4003 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
4004 SmallVector<OpFoldResult> mixedGlobalStrides =
4005 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
4006
4007 if (mixedGlobalStrides.size() <= (dimX + 1))
4008 return {sgprY, sgprZ};
4009
4010 OpFoldResult tensorDimXStrideOpFoldResult =
4011 *(mixedGlobalStrides.rbegin() + dimX + 1);
4012 // pre-condition: tensorDimXStride is less than 2^48-1
4013 // TODO: Validation if the value breaks the pre-condition.
4014 // In a following PR implement RuntimeVerifiableOpInterface that instruments
4015 // conditions that need to be checked at runtime.
4016 Value tensorDimXStride;
4017 if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
4018 tensorDimXStride =
4019 createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
4020 else
4021 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
4022
4023 constexpr int64_t first48bits = (1ll << 48) - 1;
4024 Value mask = createI64Constant(rewriter, loc, first48bits);
4025 tensorDimXStride =
4026 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
4027 IntegerType i32 = rewriter.getI32Type();
4028 Value tensorDimXStrideLow =
4029 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
4030 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
4031
4032 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
4033 Value shiftVal = createI64Constant(rewriter, loc, shift);
4034 Value tensorDimXStrideHigh =
4035 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
4036 tensorDimXStrideHigh =
4037 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
4038 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
4039 offset + shift);
4040 return {sgprY, sgprZ};
4041 }
4042
4043 std::pair<Value, Value>
4044 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
4045 ConversionPatternRewriter &rewriter, Location loc,
4046 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
4047 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
4048 0, 160);
4049 }
4050
4051 std::pair<Value, Value>
4052 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
4053 ConversionPatternRewriter &rewriter, Location loc,
4054 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
4055 // Value is ignored when in gather mode.
4056 if constexpr (DescriptorOp::isGather())
4057 return {sgpr5, sgpr6};
4058 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
4059 1, 208);
4060 }
4061
4062 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
4063 ConversionPatternRewriter &rewriter, Location loc,
4064 ArrayRef<Value> consts) const {
4065 Value sgprs[8];
4066 for (int64_t i = 0; i < 8; ++i) {
4067 sgprs[i] = consts[0];
4068 }
4069
4070 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
4071 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
4072 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
4073 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
4074 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
4075 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
4076 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
4077 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
4078
4079 sgprs[1] =
4080 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
4081 std::tie(sgprs[1], sgprs[2]) =
4082 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
4083 std::tie(sgprs[2], sgprs[3]) =
4084 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
4085
4086 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
4087 sgprs[4] =
4088 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
4089 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
4090 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
4091 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
4092 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
4093 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
4094
4095 IntegerType i32 = rewriter.getI32Type();
4096 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
4097 assert(v8i32 && "expected type conversion to succeed");
4098 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
4099
4100 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
4101 dgroup1 =
4102 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
4103 }
4104
4105 return dgroup1;
4106 }
4107
4108 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
4109 ConversionPatternRewriter &rewriter, Location loc,
4110 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
4111 int64_t offset) const {
4112 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
4113 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
4114 SmallVector<OpFoldResult> mixedGlobalSizes =
4115 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
4116 if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
4117 return sgpr0;
4118
4119 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
4120 Value tensorDimX;
4121 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
4122 tensorDimX =
4123 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
4124 } else {
4125 IntegerType i32 = rewriter.getI32Type();
4126 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
4127 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
4128 }
4129
4130 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
4131 }
4132
4133 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
4134 ConversionPatternRewriter &rewriter, Location loc,
4135 Value sgpr0, ArrayRef<Value> consts) const {
4136 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
4137 }
4138
4139 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
4140 Location loc, Value accumulator,
4141 Value value, int64_t shift) const {
4142
4143 IntegerType i32 = rewriter.getI32Type();
4144 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
4145 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
4146 }
4147
4148 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4149 ConversionPatternRewriter &rewriter, Location loc,
4150 Value sgpr1, ArrayRef<Value> consts,
4151 int64_t offset) const {
4152 Value ldsAddrIncrement = adaptor.getLdsIncrement();
4153 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
4154 }
4155
4156 std::pair<Value, Value>
4157 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4158 ConversionPatternRewriter &rewriter, Location loc,
4159 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
4160 int64_t offset) const {
4161 Value globalAddrIncrement = adaptor.getGlobalIncrement();
4162 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
4163 globalAddrIncrement, offset);
4164 Value shift = createI64Constant(rewriter, loc, 32);
4165 globalAddrIncrement =
4166 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
4167 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
4168 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
4169 globalAddrIncrement, offset + 32);
4170 Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
4171 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
4172 return {sgpr2, sgpr3};
4173 }
4174
4175 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4176 ConversionPatternRewriter &rewriter,
4177 Location loc, Value sgpr1,
4178 ArrayRef<Value> consts) const {
4179 Value ldsIncrement = op.getLdsIncrement();
4180 constexpr int64_t dim = 3;
4181 constexpr int64_t offset = 32;
4182 if (!ldsIncrement)
4183 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
4184 offset);
4185 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
4186 offset);
4187 }
4188
4189 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
4190 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
4191 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
4192 Value globalIncrement = op.getGlobalIncrement();
4193 constexpr int32_t dim = 2;
4194 constexpr int32_t offset = 64;
4195 if (!globalIncrement)
4196 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
4197 consts, dim, offset);
4198 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
4199 consts, offset);
4200 }
4201
4202 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
4203 ConversionPatternRewriter &rewriter, Location loc,
4204 Value sgpr3, ArrayRef<Value> consts,
4205 int32_t offset) const {
4206 Value iterationCount = adaptor.getIterationCount();
4207 IntegerType i32 = rewriter.getI32Type();
4208 // pre-condition: iterationCount is in the inclusive interval [1, 256].
4209 // TODO: validation if the value breaks the pre-condition.
4210 // If the pre-condition fails, there is a possibility of
4211 // affecting the higher bits. In a following PR implement
4212 // RuntimeVerifiableOpInterface that instruments conditions that need to be
4213 // checked at runtime.
4214 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
4215 iterationCount =
4216 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
4217 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
4218 }
4219
4220 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
4221 ConversionPatternRewriter &rewriter,
4222 Location loc, Value sgpr3,
4223 ArrayRef<Value> consts) const {
4224 Value iterateCount = op.getIterationCount();
4225 constexpr int32_t dim = 2;
4226 constexpr int32_t offset = 112;
4227 if (!iterateCount)
4228 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
4229 offset);
4230
4231 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
4232 }
4233
4234 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
4235 ConversionPatternRewriter &rewriter, Location loc,
4236 ArrayRef<Value> consts) const {
4237 if constexpr (DescriptorOp::isGather())
4238 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
4239 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
4240 }
4241
4242 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
4243 ConversionPatternRewriter &rewriter, Location loc,
4244 ArrayRef<Value> consts) const {
4245 IntegerType i32 = rewriter.getI32Type();
4246 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4247 assert(v4i32 && "expected type conversion to succeed.");
4248
4249 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4250 if (onlyNeedsTwoDescriptors)
4251 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4252
4253 constexpr int64_t sgprlen = 4;
4254 Value sgprs[sgprlen];
4255 for (int i = 0; i < sgprlen; ++i)
4256 sgprs[i] = consts[0];
4257
4258 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
4259 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
4260 sgprs[1], consts);
4261 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
4262 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
4263 sgprs[3] =
4264 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
4265
4266 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4267 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
4268 dgroup2 =
4269 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
4270
4271 return dgroup2;
4272 }
4273
4274 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
4275 ConversionPatternRewriter &rewriter, Location loc,
4276 ArrayRef<Value> consts, bool firstHalf) const {
4277 IntegerType i32 = rewriter.getI32Type();
4278 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4279 assert(v4i32 && "expected type conversion to succeed.");
4280
4281 Value indices = adaptor.getIndices();
4282 auto vectorType = cast<VectorType>(indices.getType());
4283 unsigned length = vectorType.getShape().back();
4284 Type elementType = vectorType.getElementType();
4285 unsigned maxLength = elementType == i32 ? 4 : 8;
4286 int32_t offset = firstHalf ? 0 : maxLength;
4287 unsigned discountedLength =
4288 std::max(static_cast<int32_t>(length - offset), 0);
4289
4290 unsigned targetSize = std::min(maxLength, discountedLength);
4291
4292 SmallVector<Value> indicesVector;
4293 for (unsigned i = offset; i < targetSize + offset; ++i) {
4294 Value idx;
4295 if (i < consts.size())
4296 idx = consts[i];
4297 else
4298 idx = createI32Constant(rewriter, loc, i);
4299 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
4300 indicesVector.push_back(elem);
4301 }
4302
4303 SmallVector<Value> indicesI32Vector;
4304 if (elementType == i32) {
4305 indicesI32Vector = indicesVector;
4306 } else {
4307 for (unsigned i = 0; i < targetSize; ++i) {
4308 Value index = indicesVector[i];
4309 indicesI32Vector.push_back(
4310 LLVM::ZExtOp::create(rewriter, loc, i32, index));
4311 }
4312 if ((targetSize % 2) != 0)
4313 // Add padding when not divisible by two.
4314 indicesI32Vector.push_back(consts[0]);
4315 }
4316
4317 SmallVector<Value> indicesToInsert;
4318 if (elementType == i32) {
4319 indicesToInsert = indicesI32Vector;
4320 } else {
4321 unsigned size = indicesI32Vector.size() / 2;
4322 for (unsigned i = 0; i < size; ++i) {
4323 Value first = indicesI32Vector[2 * i];
4324 Value second = indicesI32Vector[2 * i + 1];
4325 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
4326 indicesToInsert.push_back(joined);
4327 }
4328 }
4329
4330 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4331 for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
4332 dgroup =
4333 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
4334
4335 return dgroup;
4336 }
4337
4338 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
4339 ConversionPatternRewriter &rewriter, Location loc,
4340 ArrayRef<Value> consts) const {
4341 return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
4342 }
4343
4344 std::pair<Value, Value>
4345 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
4346 ConversionPatternRewriter &rewriter, Location loc,
4347 Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
4348 constexpr int32_t dim = 3;
4349 constexpr int32_t offset = 0;
4350 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
4351 dim, offset);
4352 }
4353
4354 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
4355 ConversionPatternRewriter &rewriter,
4356 Location loc, Value sgpr1, Value sgpr2,
4357 ArrayRef<Value> consts) const {
4358 constexpr int32_t dim = 4;
4359 constexpr int32_t offset = 48;
4360 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
4361 offset);
4362 }
4363
4364 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
4365 ConversionPatternRewriter &rewriter, Location loc,
4366 Value sgpr2, ArrayRef<Value> consts) const {
4367 constexpr int32_t dim = 4;
4368 constexpr int32_t offset = 80;
4369 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
4370 }
4371
4372 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
4373 ConversionPatternRewriter &rewriter, Location loc,
4374 ArrayRef<Value> consts) const {
4375 if constexpr (DescriptorOp::isGather())
4376 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
4377 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
4378 }
4379
4380 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
4381 ConversionPatternRewriter &rewriter, Location loc,
4382 ArrayRef<Value> consts) const {
4383 IntegerType i32 = rewriter.getI32Type();
4384 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4385 assert(v4i32 && "expected type conversion to succeed.");
4386 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4387 if (onlyNeedsTwoDescriptors)
4388 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4389
4390 constexpr int32_t sgprlen = 4;
4391 Value sgprs[sgprlen];
4392 for (int i = 0; i < sgprlen; ++i)
4393 sgprs[i] = consts[0];
4394
4395 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
4396 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
4397 std::tie(sgprs[1], sgprs[2]) =
4398 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
4399 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
4400
4401 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4402 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
4403 dgroup3 =
4404 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
4405
4406 return dgroup3;
4407 }
4408
4409 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
4410 ConversionPatternRewriter &rewriter, Location loc,
4411 ArrayRef<Value> consts) const {
4412 return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
4413 }
4414
4415 LogicalResult
4416 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
4417 ConversionPatternRewriter &rewriter) const override {
4418 if (chipset < kGfx1250)
4419 return op->emitOpError(
4420 "make_dma_descriptor is only supported on gfx1250");
4421
4422 Location loc = op.getLoc();
4423
4424 SmallVector<Value> consts;
4425 for (int64_t i = 0; i < 8; ++i)
4426 consts.push_back(createI32Constant(rewriter, loc, i));
4427
4428 Value dgroup0 = this->getDGroup0(adaptor);
4429 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
4430 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
4431 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
4432 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
4433 rewriter.replaceOpWithMultiple(op, {results});
4434 return success();
4435 }
4436};
4437
4438template <typename SourceOp, typename TargetOp>
4439struct AMDGPUTensorLoadStoreOpLowering
4440 : public ConvertOpToLLVMPattern<SourceOp> {
4441 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
4443 AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
4444 Chipset chipset)
4445 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
4446 Chipset chipset;
4447
4448 LogicalResult
4449 matchAndRewrite(SourceOp op, Adaptor adaptor,
4450 ConversionPatternRewriter &rewriter) const override {
4451 if (chipset < kGfx1250)
4452 return op->emitOpError("is only supported on gfx1250");
4453
4454 ValueRange desc = adaptor.getDesc();
4455 // Create a <v8 x i32> 0 as the fifth argument to match llvm intrinsic. It
4456 // will move into the TDM descriptor once it becomes relevant for future use
4457 auto v8i32 = VectorType::get(8, rewriter.getI32Type());
4458 Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
4459 Attribute cachePolicy = rewriter.getI32IntegerAttr(0);
4460 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
4461 desc[3], dgroup4, cachePolicy,
4462 /*alias_scopes=*/nullptr,
4463 /*noalias_scopes=*/nullptr,
4464 /*tbaa=*/nullptr);
4465 return success();
4466 }
4467};
4468
4469struct GlobalPrefetchOpLowering
4470 : public ConvertOpToLLVMPattern<GlobalPrefetchOp> {
4471 GlobalPrefetchOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
4472 : ConvertOpToLLVMPattern<GlobalPrefetchOp>(converter), chipset(chipset) {}
4473
4474 LogicalResult
4475 matchAndRewrite(GlobalPrefetchOp op, GlobalPrefetchOpAdaptor adaptor,
4476 ConversionPatternRewriter &rewriter) const override {
4477 if (chipset < kGfx1250)
4478 return op->emitOpError("is only supported on gfx1250+");
4479
4480 const bool isSpeculative = op.getSpeculative();
4481 const int32_t immArgValue = getGlobalPrefetchLLVMEncoding(
4482 op.getTemporalHint(), op.getCacheScope(), isSpeculative);
4483 // amdgpu.global_prefetch is gfx1250+, so its policy bits use gfx12
4484 // encoding.
4485 Attribute cachePolicy = ROCDL::Gfx12CachePolicyAttr::get(
4486 rewriter.getContext(),
4487 static_cast<ROCDL::Gfx12CachePolicy>(immArgValue));
4488
4489 ValueRange indices = adaptor.getIndices();
4490 Value memRef = adaptor.getSrc();
4491 MemRefDescriptor descriptor(memRef);
4492 MemRefType memRefType = op.getSrc().getType();
4493 Location loc = op->getLoc();
4494 auto inboundsFlags = isSpeculative ? LLVM::GEPNoWrapFlags::none
4495 : LLVM::GEPNoWrapFlags::inbounds |
4496 LLVM::GEPNoWrapFlags::nuw;
4497 Value prefetchPtr = getStridedElementPtr(
4498 rewriter, loc, memRefType, descriptor, indices, inboundsFlags);
4499
4500 rewriter.replaceOpWithNewOp<ROCDL::GlobalPrefetchOp>(
4501 op, prefetchPtr, cachePolicy, mlir::ArrayAttr{}, mlir::ArrayAttr{},
4502 mlir::ArrayAttr{});
4503 return success();
4504 }
4505
4506private:
4507 Chipset chipset;
4508};
4509
4510struct ConvertAMDGPUToROCDLPass
4511 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
4512 using Base::Base;
4513
4514 void runOnOperation() override {
4515 MLIRContext *ctx = &getContext();
4516 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
4517 if (failed(maybeChipset)) {
4518 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
4519 return signalPassFailure();
4520 }
4521
4522 RewritePatternSet patterns(ctx);
4523 LLVMTypeConverter converter(ctx);
4524
4525 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
4526 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
4527 LLVMConversionTarget target(getContext());
4528 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
4529 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
4530 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
4531 if (failed(applyPartialConversion(getOperation(), target,
4532 std::move(patterns))))
4533 signalPassFailure();
4534 }
4535};
4536} // namespace
4537
4539 TypeConverter &typeConverter) {
4541 typeConverter, [](gpu::AddressSpace space) {
4542 switch (space) {
4543 case gpu::AddressSpace::Global:
4544 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
4545 case gpu::AddressSpace::Workgroup:
4546 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
4547 case gpu::AddressSpace::Private:
4548 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
4549 case gpu::AddressSpace::Constant:
4550 return ROCDL::ROCDLDialect::kConstantMemoryAddressSpace;
4551 }
4552 llvm_unreachable("unknown address space enum value");
4553 });
4554 typeConverter.addConversion([](gpu::NamedBarrierType type) {
4555 return LLVM::LLVMPointerType::get(
4556 type.getContext(), ROCDL::ROCDLDialect::kSharedMemoryAddressSpace);
4557 });
4558}
4559
4561 TypeConverter &typeConverter) {
4562 typeConverter.addTypeAttributeConversion(
4563 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
4564 -> TypeConverter::AttributeConversionResult {
4565 MLIRContext *ctx = as.getContext();
4566 Type i64 = IntegerType::get(ctx, 64);
4567 switch (as.getValue()) {
4568 case amdgpu::AddressSpace::FatRawBuffer:
4569 return IntegerAttr::get(i64, 7);
4570 case amdgpu::AddressSpace::BufferRsrc:
4571 return IntegerAttr::get(i64, 8);
4572 case amdgpu::AddressSpace::FatStructuredBuffer:
4573 return IntegerAttr::get(i64, 9);
4574 }
4575 return TypeConverter::AttributeConversionResult::abort();
4576 });
4577 typeConverter.addConversion([&](DsBarrierStateType type) -> Type {
4578 return IntegerType::get(type.getContext(), 64);
4579 });
4580 typeConverter.addConversion([&](TDMBaseType type) -> Type {
4581 Type i32 = IntegerType::get(type.getContext(), 32);
4582 return typeConverter.convertType(VectorType::get(4, i32));
4583 });
4584 typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
4585 Type i32 = IntegerType::get(type.getContext(), 32);
4586 return typeConverter.convertType(VectorType::get(4, i32));
4587 });
4588 typeConverter.addConversion(
4589 [&](TDMDescriptorType type,
4590 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
4591 Type i32 = IntegerType::get(type.getContext(), 32);
4592 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
4593 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
4594 llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
4595 return success();
4596 });
4597
4598 auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
4599 ValueRange inputs,
4601 // Only create unrealized_conversion_cast for TDMDescriptorType.
4602 // All other types which are not expected, should be
4603 // materialized by other target materialization functions.
4604 if (inputs.size() != 1)
4605 return {};
4606
4607 if (!isa<TDMDescriptorType>(inputs[0].getType()))
4608 return {};
4609
4610 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
4611 return cast.getResults();
4612 };
4613
4614 typeConverter.addTargetMaterialization(addUnrealizedCast);
4615}
4616
4618 RewritePatternSet &patterns,
4619 Chipset chipset) {
4621 patterns
4622 .add<FatRawBufferCastLowering,
4623 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
4624 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
4625 RawBufferOpLowering<RawBufferAtomicFaddOp,
4626 ROCDL::RawPtrBufferAtomicFaddOp>,
4627 RawBufferOpLowering<RawBufferAtomicFmaxOp,
4628 ROCDL::RawPtrBufferAtomicFmaxOp>,
4629 RawBufferOpLowering<RawBufferAtomicSmaxOp,
4630 ROCDL::RawPtrBufferAtomicSmaxOp>,
4631 RawBufferOpLowering<RawBufferAtomicUminOp,
4632 ROCDL::RawPtrBufferAtomicUminOp>,
4633 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
4634 ROCDL::RawPtrBufferAtomicCmpSwap>,
4635 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
4636 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
4637 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
4638 SparseWMMAOpLowering, DotOpLowering, ExtPackedFp8OpLowering,
4639 ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
4640 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
4641 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
4642 GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
4643 GlobalTransposeLoadOpLowering, AMDGPUPermlaneLowering,
4644 AMDGPUPermlaneVarLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
4645 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
4646 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
4647 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
4648 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
4649 ROCDL::TensorLoadToLDSOp>,
4650 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
4651 ROCDL::TensorStoreFromLDSOp>,
4652 DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
4653 DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering,
4654 GlobalPrefetchOpLowering>(converter, chipset);
4655 patterns.add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
4656 DsBarrierStatePendingCountOpLowering,
4657 DsBarrierStateInitCountOpLowering,
4658 DsBarrierStatePhaseParityLowering>(converter);
4659}
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static bool hasDot10Insts(const Chipset &chipset)
static bool hasDot7Insts(const Chipset &chipset)
static std::optional< SparseWMMAOpInfo > sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset)
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
static std::optional< ScaledMFMAIntrinsic > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static bool hasDot11Insts(const Chipset &chipset)
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static bool hasDot9Insts(const Chipset &chipset)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
constexpr Chipset kGfx1200
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< ROCDL::WMMAMatrixScaleFormat > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value convertPackedVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts packed vector operands to the expected ROCDL types.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static bool hasDot8Insts(const Chipset &chipset)
static bool hasDot2Insts(const Chipset &chipset)
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static std::optional< ROCDL::MatrixFormat > smallFloatTypeToMatrixFormat(Type mlirElemType)
std::tuple< StringRef, ROCDL::MatrixFormat, ROCDL::MatrixFormat > ScaledMFMAIntrinsic
If there is a scaled MFMA instruction for the input element types aType and bType,...
static bool hasDot12Insts(const Chipset &chipset)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static bool hasDot1Insts(const Chipset &chipset)
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
typename SourceOp::Adaptor OpAdaptor
Definition Pattern.h:229
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:66
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:209
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
result_range getResults()
Definition Operation.h:440
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF8E5M2() const
Definition Types.cpp:45
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isF8E4M3FN() const
Definition Types.cpp:44
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:603
int32_t getGlobalPrefetchLLVMEncoding(amdgpu::LoadTemporalHint hint, amdgpu::Scope scope, bool isSpeculative)
Definition AMDGPUEnums.h:18
bool hasOcpFp8(const Chipset &chipset)
Definition Chipset.h:52
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Returns the rocdl intrinsic corresponding to a SparseWMMA operation swmmac if one exists.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
unsigned minorVersion
Definition Chipset.h:24
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14
unsigned steppingVersion
Definition Chipset.h:25