// MLIR 23.0.0git
// AMDGPUToROCDL.cpp
// (Doxygen page header residue; not part of the original source file.)
1//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/AMDGPUAddrSpace.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
#include <cstdint>
#include <optional>
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44using namespace mlir::amdgpu;
45
// Define commonly used chipsets versions for convenience. These are compared
// against the target chipset to gate feature availability in the lowerings
// below.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
52
53// Predicates mirroring the LLVM AMDGPU `HasDot{N}Insts` features that gate
54// the `v_dot*` instructions consumed by the `amdgpu.dot` lowering.
55static bool hasDot1Insts(const Chipset &chipset) {
56 if (chipset.majorVersion == 9)
57 return chipset >= Chipset(9, 0, 6);
58 if (chipset.majorVersion == 10) {
59 if (chipset.minorVersion == 1)
60 return chipset.steppingVersion == 1u || chipset.steppingVersion == 2u;
61 return chipset.minorVersion >= 3u;
62 }
63 return false;
64}
65
// `HasDot2Insts` is enabled on exactly the same chipsets as `HasDot1Insts`.
static bool hasDot2Insts(const Chipset &chipset) {
  return hasDot1Insts(chipset);
}
69
70static bool hasDot7Insts(const Chipset &chipset) {
71 return chipset.majorVersion >= 11 || hasDot1Insts(chipset);
72}
73
// `HasDot8Insts`: gfx11 and later only.
static bool hasDot8Insts(const Chipset &chipset) {
  return chipset.majorVersion >= 11;
}
77
78static bool hasDot9Insts(const Chipset &chipset) {
79 if (chipset.majorVersion == 11)
80 return true;
81 return chipset.majorVersion == 12 && chipset.minorVersion == 0;
82}
83
84static bool hasDot10Insts(const Chipset &chipset) {
85 if (chipset.majorVersion == 11)
86 return true;
87 if (chipset.majorVersion == 12)
88 return chipset.minorVersion == 0;
89 return hasDot1Insts(chipset);
90}
91
92static bool hasDot11Insts(const Chipset &chipset) {
93 if (chipset.majorVersion == 11)
94 return chipset.minorVersion == 7u;
95 return chipset.majorVersion == 12 && chipset.minorVersion == 0;
96}
97
98static bool hasDot12Insts(const Chipset &chipset) {
99 if (chipset == Chipset(9, 5, 0))
100 return true;
101 if (chipset.majorVersion == 11)
102 return true;
103 return chipset.majorVersion == 12 && chipset.minorVersion == 0;
104}
105
106/// Convert an unsigned number `val` to i32.
107static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
108 Location loc, Value val) {
109 IntegerType i32 = rewriter.getI32Type();
110 // Force check that `val` is of int type.
111 auto valTy = cast<IntegerType>(val.getType());
112 if (i32 == valTy)
113 return val;
114 return valTy.getWidth() > 32
115 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
116 : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
117}
118
/// Materialize an i32 LLVM constant with the given value.
static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
}
123
124/// Convert an unsigned number `val` to i64.
125static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
126 Location loc, Value val) {
127 IntegerType i64 = rewriter.getI64Type();
128 // Force check that `val` is of int type.
129 auto valTy = cast<IntegerType>(val.getType());
130 if (i64 == valTy)
131 return val;
132 return valTy.getWidth() > 64
133 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
134 : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
135}
136
/// Materialize an i64 LLVM constant with the given value.
static Value createI64Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int64_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}
141
142/// Returns the linear index used to access an element in the memref.
143static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
144 Location loc, MemRefDescriptor &memRefDescriptor,
146 IntegerType i32 = rewriter.getI32Type();
147 Value index;
148 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
149 if (stride != 1) { // Skip if stride is 1.
150 Value strideValue =
151 ShapedType::isDynamic(stride)
152 ? convertUnsignedToI32(rewriter, loc,
153 memRefDescriptor.stride(rewriter, loc, i))
154 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
155 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
156 }
157 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
158 : increment;
159 }
160 return index ? index : createI32Constant(rewriter, loc, 0);
161}
162
/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides, int64_t elementByteWidth,
                           amdgpu::Chipset chipset, bool boundsCheck) {
  // gfx1250+ with bounds checking disabled: use the largest 45-bit extent so
  // hardware out-of-bounds handling never fires.
  if (chipset >= kGfx1250 && !boundsCheck) {
    constexpr int64_t first45bits = (1ll << 45) - 1;
    return createI64Constant(rewriter, loc, first45bits);
  }
  // Fully static shape and strides: fold the extent into a constant.
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    // A rank-0 memref holds exactly one element.
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    // The extent is max_i(shape[i] * strides[i]), in elements.
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    return createI64Constant(rewriter, loc, size);
  }
  // Dynamic case: compute max_i(size_i * stride_i) from the descriptor at
  // runtime, then scale by the element byte width.
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
    maxIndex = maxIndex
                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
  Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
}
197
/// Build a ROCDL `make.buffer.rsrc` producing a buffer resource for
/// `basePointer` spanning `numRecords` bytes. `addressSpace` defaults to 8
/// (buffer resource); pass 7 for a buffer fat pointer. `cacheSwizzleStride`,
/// when non-null and supported (gfx942+), enables cache swizzling.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
                                /*isDisjoint=*/true);
  } else {
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));
  }

  uint32_t flags = 0;
  if (chipset >= kGfx1250) {
    // Flag word:
    // bit 0: swizzle
    // bit 1: 0 means (total_offset + payload > numRecords)
    //        1 means ((total_offset + payload >) numRecords) || ((offset +
    //        payload) > stride) only applied when swizzle_enable = 0. keep at
    //        zero.
    //        whether oob is done depends on numRecords.
    // bits 2-3: Type (must be 0)
  } else {
    // Get the number of elements.
    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: data format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    flags |= (7 << 12) | (4 << 15);
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = boundsCheck ? 3 : 2;
      flags |= (oob << 28);
    }
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}
260
261namespace {
262struct FatRawBufferCastLowering
263 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
264 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
265 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
266 chipset(chipset) {}
267
268 Chipset chipset;
269
270 LogicalResult
271 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
272 ConversionPatternRewriter &rewriter) const override {
273 Location loc = op.getLoc();
274 Value memRef = adaptor.getSource();
275 Value unconvertedMemref = op.getSource();
276 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
277 MemRefDescriptor descriptor(memRef);
278
279 DataLayout dataLayout = DataLayout::closest(op);
280 int64_t elementByteWidth =
281 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
282
283 int64_t unusedOffset = 0;
284 SmallVector<int64_t, 5> strideVals;
285 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
286 return op.emitOpError("Can't lower non-stride-offset memrefs");
287
288 Value numRecords = adaptor.getValidBytes();
289 if (!numRecords)
290 numRecords =
291 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
292 elementByteWidth, chipset, adaptor.getBoundsCheck());
293
294 Value basePointer =
295 adaptor.getResetOffset()
296 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
297 memrefType)
298 : descriptor.alignedPtr(rewriter, loc);
299
300 Value offset = adaptor.getResetOffset()
301 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
302 rewriter.getIndexAttr(0))
303 : descriptor.offset(rewriter, loc);
304
305 bool hasSizes = memrefType.getRank() > 0;
306 // No need to unpack() and pack() all the individual sizes and strides,
307 // so we'll just extract the arrays.
308 Value sizes = hasSizes
309 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
311 : Value{};
312 Value strides =
313 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
315 : Value{};
316
317 Value fatPtr = makeBufferRsrc(
318 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
319 chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
320
321 Value result = MemRefDescriptor::poison(
322 rewriter, loc,
323 getTypeConverter()->convertType(op.getResult().getType()));
324 SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
325 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
326 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
328 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
330 if (hasSizes) {
331 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
333 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
335 }
336 rewriter.replaceOp(op, result);
337 return success();
338 }
339};
340
341/// Define lowering patterns for raw buffer ops
342template <typename GpuOp, typename Intrinsic>
343struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
344 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
345 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
346
347 Chipset chipset;
348 static constexpr uint32_t maxVectorOpWidth = 128;
349
350 LogicalResult
351 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
352 ConversionPatternRewriter &rewriter) const override {
353 Location loc = gpuOp.getLoc();
354 Value memref = adaptor.getMemref();
355 Value unconvertedMemref = gpuOp.getMemref();
356 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
357
358 if (chipset.majorVersion < 9)
359 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
360
361 Value storeData = adaptor.getODSOperands(0)[0];
362 if (storeData == memref) // no write component to this op
363 storeData = Value();
364 Type wantedDataType;
365 if (storeData)
366 wantedDataType = storeData.getType();
367 else
368 wantedDataType = gpuOp.getODSResults(0)[0].getType();
369
370 Value atomicCmpData = Value();
371 // Operand index 1 of a load is the indices, trying to read them can crash.
372 if (storeData) {
373 Value maybeCmpData = adaptor.getODSOperands(1)[0];
374 if (maybeCmpData != memref)
375 atomicCmpData = maybeCmpData;
376 }
377
378 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
379
380 Type i32 = rewriter.getI32Type();
381
382 // Get the type size in bytes.
383 DataLayout dataLayout = DataLayout::closest(gpuOp);
384 int64_t elementByteWidth =
385 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
386 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
387
388 // If we want to load a vector<NxT> with total size <= 32
389 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
390 // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
391 // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
392 // so bitcast any floats to integers.
393 Type llvmBufferValType = llvmWantedDataType;
394 if (atomicCmpData) {
395 if (auto floatType = dyn_cast<FloatType>(wantedDataType))
396 llvmBufferValType = this->getTypeConverter()->convertType(
397 rewriter.getIntegerType(floatType.getWidth()));
398 }
399 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
400 uint32_t vecLen = dataVector.getNumElements();
401 uint32_t elemBits =
402 dataLayout.getTypeSizeInBits(dataVector.getElementType());
403 uint32_t totalBits = elemBits * vecLen;
404 bool usePackedFp16 =
405 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
406 if (totalBits > maxVectorOpWidth)
407 return gpuOp.emitOpError(
408 "Total width of loads or stores must be no more than " +
409 Twine(maxVectorOpWidth) + " bits, but we call for " +
410 Twine(totalBits) +
411 " bits. This should've been caught in validation");
412 if (!usePackedFp16 && elemBits < 32) {
413 if (totalBits > 32) {
414 if (totalBits % 32 != 0)
415 return gpuOp.emitOpError("Load or store of more than 32-bits that "
416 "doesn't fit into words. Can't happen\n");
417 llvmBufferValType = this->typeConverter->convertType(
418 VectorType::get(totalBits / 32, i32));
419 } else {
420 llvmBufferValType = this->typeConverter->convertType(
421 rewriter.getIntegerType(totalBits));
422 }
423 }
424 }
425 if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
426 // Buffer intrinsics doesn't support 1-element vectors, cast them to
427 // scalars.
428 if (vecType.getNumElements() == 1)
429 llvmBufferValType = vecType.getElementType();
430 }
431
432 SmallVector<Value, 6> args;
433 if (storeData) {
434 if (llvmBufferValType != llvmWantedDataType) {
435 Value castForStore = LLVM::BitcastOp::create(
436 rewriter, loc, llvmBufferValType, storeData);
437 args.push_back(castForStore);
438 } else {
439 args.push_back(storeData);
440 }
441 }
442
443 if (atomicCmpData) {
444 if (llvmBufferValType != llvmWantedDataType) {
445 Value castForCmp = LLVM::BitcastOp::create(
446 rewriter, loc, llvmBufferValType, atomicCmpData);
447 args.push_back(castForCmp);
448 } else {
449 args.push_back(atomicCmpData);
450 }
451 }
452
453 // Construct buffer descriptor from memref, attributes
454 int64_t offset = 0;
455 SmallVector<int64_t, 5> strides;
456 if (failed(memrefType.getStridesAndOffset(strides, offset)))
457 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
458
459 MemRefDescriptor memrefDescriptor(memref);
460
461 Value ptr = memrefDescriptor.bufferPtr(
462 rewriter, loc, *this->getTypeConverter(), memrefType);
463 Value numRecords =
464 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
465 elementByteWidth, chipset, adaptor.getBoundsCheck());
466 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
467 adaptor.getBoundsCheck(), chipset);
468 args.push_back(resource);
469
470 // Indexing (voffset)
471 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
472 adaptor.getIndices(), strides);
473 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
474 indexOffset && *indexOffset > 0) {
475 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
476 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
477 extraOffsetConst)
478 : extraOffsetConst;
479 }
480 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
481 args.push_back(voffset);
482
483 // SGPR offset.
484 Value sgprOffset = adaptor.getSgprOffset();
485 if (!sgprOffset)
486 sgprOffset = createI32Constant(rewriter, loc, 0);
487 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
488 args.push_back(sgprOffset);
489
490 // bit 0: GLC = 0 (atomics drop value, less coherency)
491 // bits 1-2: SLC, DLC = 0 (similarly)
492 // bit 3: swizzled (0 for raw)
493 args.push_back(createI32Constant(rewriter, loc, 0));
494
495 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
496 llvmBufferValType);
497 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
498 ArrayRef<NamedAttribute>());
499 if (lowered->getNumResults() == 1) {
500 Value replacement = lowered->getResult(0);
501 if (llvmBufferValType != llvmWantedDataType) {
502 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
504 }
505 rewriter.replaceOp(gpuOp, replacement);
506 } else {
507 rewriter.eraseOp(gpuOp);
508 }
509 return success();
510 }
511};
512
// TODO: AMDGPU backend already have all this bitpacking logic, we should move
// it to some common place.
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
///     Vmcnt = Waitcnt[3:0]        (pre-gfx9)
///     Vmcnt = Waitcnt[15:14,3:0]  (gfx9,10)
///     Vmcnt = Waitcnt[15:10]      (gfx11)
///     Expcnt = Waitcnt[6:4]       (pre-gfx11)
///     Expcnt = Waitcnt[2:0]       (gfx11)
///     Lgkmcnt = Waitcnt[11:8]     (pre-gfx10)
///     Lgkmcnt = Waitcnt[13:8]     (gfx10)
///     Lgkmcnt = Waitcnt[9:4]      (gfx11)
///
/// Counts are clamped to the per-generation maximum; returns failure() for
/// generations with no s_waitcnt encoding here (gfx12+ uses separate ops).
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  // Pre-gfx9: 4-bit vmcnt, 3-bit expcnt, 4-bit lgkmcnt.
  if (chipset.majorVersion < 9) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }
  // gfx9: vmcnt widened to 6 bits, split across [3:0] and [15:14].
  if (chipset.majorVersion == 9) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  // gfx10: same split vmcnt; lgkmcnt widened to 6 bits at [13:8].
  if (chipset.majorVersion == 10) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  // gfx11: re-arranged fields, all contiguous.
  if (chipset.majorVersion == 11) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }
  return failure();
}
558
/// Lower amdgpu.memory_counter_wait: on gfx12+ each counter gets its own
/// dedicated ROCDL wait op; earlier targets pack all counters into a single
/// s_waitcnt immediate via encodeWaitcnt().
struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // gfx12+: emit one dedicated wait op per counter that was specified.
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      if (std::optional<int> tensor = adaptor.getTensor())
        ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);

      rewriter.eraseOp(op);
      return success();
    }

    // Pre-gfx12 hardware has no tensor counter.
    if (adaptor.getTensor())
      return op.emitOpError("unsupported chipset");

    auto getVal = [](Attribute attr) -> unsigned {
      if (attr)
        return cast<IntegerAttr>(attr).getInt();

      // This value will be clamped to the maximum value for the chipset.
      return 1024;
    };
    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    // Pre-gfx12, loads and stores share the vmcnt counter; if both were
    // specified, summing them is the conservative combined wait threshold.
    unsigned vmcnt = 1024;
    Attribute load = adaptor.getLoadAttr();
    Attribute store = adaptor.getStoreAttr();
    if (load && store) {
      vmcnt = getVal(load) + getVal(store);
    } else if (load) {
      vmcnt = getVal(load);
    } else if (store) {
      vmcnt = getVal(store);
    }

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
    if (failed(waitcnt))
      return op.emitOpError("unsupported chipset");

    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
    return success();
  }
};
624
/// Lower amdgpu.lds_barrier to a release fence, a workgroup barrier, and an
/// acquire fence, with MMRA metadata restricting the fences to LDS (local)
/// memory so no global-memory waits are introduced.
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // This ensures that waits on global memory aren't introduced on
    // chips that don't have the BackOffBarrier feature enabled in LLVM.
    bool requiresInlineAsm = chipset < kGfx90a;

    Attribute mmra =
        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
    // Note: while there *is* a workgroup-one-as scope, this, when combined with
    // the MMRA, will lead to the fence having no effect. This is because the
    // codepaths for an atomic load or store will observe that a
    // one-address-space atomic to LDS requires no synchronization because
    // operations on LDS are totally ordered with respect to each other, and so
    // will not emit the correct waitcnt operations that these fences are
    // intended to produce. Therefore, we use a broader type of fence and rely
    // on the MMRA to relax it to the semantics we want.
    StringRef scope = "workgroup";

    auto relFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::release, scope);
    relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    if (requiresInlineAsm) {
      // Pre-gfx90a: raw s_barrier via inline asm (see comment above).
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          rewriter, loc,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      // gfx12+ splits the barrier into signal + wait ops.
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }

    auto acqFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::acquire, scope);
    acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    rewriter.replaceOp(op, acqFence);
    return success();
  }
};
680
/// Lower amdgpu.sched_barrier directly to ROCDL's sched.barrier intrinsic,
/// forwarding the option mask unchanged.
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};
695
696} // namespace
697
698/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
699/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
700///
701/// Specifically:
702/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
703/// allows bf16. Newer MFMAs support bf16 types on operand, check
704/// IntrinsicsAMDGPU.td file for reference.
705/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
706/// instead, which is what the f8f6f4 intrinsics use.
707/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
708/// integer.
709///
710/// Note that the type of `input` has already been LLVM type converted:
711/// therefore 8-bit and smaller floats are represented as their corresponding
712/// `iN` integers.
713static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
714 Location loc, Value input,
715 bool allowBf16 = true) {
716 Type inputType = input.getType();
717 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
718 if (vectorType.getElementType().isBF16() && !allowBf16)
719 return LLVM::BitcastOp::create(
720 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
721 if (vectorType.getElementType().isInteger(8) &&
722 vectorType.getNumElements() <= 8)
723 return LLVM::BitcastOp::create(
724 rewriter, loc,
725 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
726 if (isa<IntegerType>(vectorType.getElementType()) &&
727 vectorType.getElementTypeBitWidth() <= 8) {
728 int64_t numWords = llvm::divideCeil(
729 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
730 32);
731 return LLVM::BitcastOp::create(
732 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
733 input);
734 }
735 }
736 return input;
737}
738
739/// Converts packed vector operands to the expected ROCDL types.
740static Value convertPackedVectorOperand(ConversionPatternRewriter &rewriter,
741 Location loc, Value input,
742 bool allowBf16 = true) {
743 Type inputType = input.getType();
744 auto vectorType = cast<VectorType>(inputType);
745 // bf16 -> i16 when not allowed (pre-gfx950).
746 if (vectorType.getElementType().isBF16() && !allowBf16)
747 return LLVM::BitcastOp::create(
748 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
749 // i8/fp8 vectors -> vector<Nxi32>.
750 if (isa<IntegerType>(vectorType.getElementType()) &&
751 vectorType.getElementTypeBitWidth() <= 8) {
752 int64_t numWords = llvm::divideCeil(
753 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
754 Type castType = (numWords > 1)
755 ? Type{VectorType::get(numWords, rewriter.getI32Type())}
756 : rewriter.getI32Type();
757 return LLVM::BitcastOp::create(rewriter, loc, castType, input);
758 }
759 return input;
760}
761
762/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
763/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
764///
765/// Specifically:
766/// 1. If `input` is a i8 value, zero extend it to i32
767/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
768///
769/// Note that the type of `input` has already been LLVM type converted:
770/// therefore 8-bit and smaller floats are represented as their corresponding
771/// `iN` integers.
772static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
773 Value input) {
774 return TypeSwitch<Type, Value>(input.getType())
775 .Case([&](IntegerType) {
776 // Handle scalar i8: zero extend to i32.
777 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
778 input);
779 })
780 .Case([&](VectorType vectorType) {
781 // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
782 int64_t numElements = vectorType.getNumElements();
783 assert((numElements == 4 || numElements == 8) &&
784 "scale operand must be a vector of length 4 or 8");
785 IntegerType outputType =
786 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
787 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
788 })
789 .DefaultUnreachable("unexpected input type for scale operand");
790}
791
792/// Maps f8 scale element types to WMMA scale format codes.
793static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
795 .Case([](Float8E8M0FNUType) { return 0; })
796 .Case([](Float8E4M3FNType) { return 2; })
797 .Default(std::nullopt);
798}
799
800/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
801/// and scale block size (16 or 32).
802static std::optional<StringRef>
804 if (m == 16 && n == 16 && k == 128)
805 return isScale16
806 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
807 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
808
809 if (m == 32 && n == 16 && k == 128)
810 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
811 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
812
813 return std::nullopt;
814}
815
/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
    ConversionPatternRewriter &rewriter, Location loc,
    const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
    Value mlirInput, SmallVectorImpl<Value> &operands,
    SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  // Scalar operands are forwarded unchanged.
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();
  // Elements wider than 8 bits (f16/bf16/i16/...) need no repacking either.
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
  // fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    // The intrinsic attribute records "is signed", hence the negation.
    attrs.push_back(
        NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
  }

  // Pack the sub-byte vector into a single iN (N <= 32) or a <numBits/32 x
  // i32> vector, matching what the intrinsics expect.
  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
}
874
/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper 16 bits of the VGPRs or in the lower part. To store
/// the result in the lower 16 bits, set subwordOffset to 1, otherwise the
/// result will be stored in the upper part. The subwordOffset must not be set
/// for gfx12, as the instructions have been changed to return fewer registers
/// instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVectorImpl<Value> &operands,
  Type inputType = output.getType();
  // NOTE(review): no null check after dyn_cast — callers are assumed to always
  // pass a vector-typed accumulator; confirm before reusing this helper.
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    // 16-bit results: `opsel` selects which half of the VGPR to write.
    attrs.push_back(
        NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
  } else if (elemType.isInteger(32)) {
    // i32 results instead take a `clamp` flag.
    attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
  }
}
899
900/// Return true if `type` is the E5M2 variant of an 8-bit float that is
901/// supported by the `_bf8` instructions on the given `chipset`.
902static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
903 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
904 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
905}
906
907/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
908/// supported by the `_fp8` instructions on the given `chipset`.
909static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
910 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
911 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
912}
913
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
///
/// The selection keys on (source element type, dest element type) first, then
/// on the (m, n, k, blocks) problem size, with chipset checks gating the
/// entries that only exist on newer architectures.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  // f32 x f32 -> f32. The xf32 (reduced-precision) variants require gfx942+
  // and an explicit request via the reducePrecision flag.
  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  // f16 x f16 -> f32. gfx950 adds larger-K variants.
  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  // bf16 x bf16 -> f32, in three tiers: gfx950 large-K variants, gfx90a
  // "_1k" variants, and the base variants available on older chips.
  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  // i8 x i8 -> i32. Larger-K entries are gated on gfx950 / gfx942.
  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  // f64 x f64 -> f64, available from gfx90a onward.
  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  // bf8 source A: source B may independently be bf8 or fp8.
  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  // fp8 source A: source B may independently be bf8 or fp8.
  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
1061
/// Map a small-float element type to the numeric format code that the scaled
/// MFMA/WMMA intrinsics expect for their type-selection fields (0 = E4M3FN,
/// 1 = E5M2, 2 = E2M3FN, 3 = E3M2FN, 4 = E2M1FN). Returns std::nullopt for
/// any other type.
static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default(std::nullopt);
}
1071
1072/// If there is a scaled MFMA instruction for the input element types `aType`
1073/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
1074/// blocks) on the given `chipset`, return a tuple consisting of the
1075/// OperationName of the intrinsic and the type codes that need to be passed to
1076/// that intrinsic. Note that this is also used to implement some un-scaled
1077/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
1078/// MFMA with a scale of 0.
1079static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1080mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
1081 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
1082 aType = getElementTypeOrSelf(aType);
1083 bType = getElementTypeOrSelf(bType);
1084 destType = getElementTypeOrSelf(destType);
1085
1086 if (chipset < kGfx950)
1087 return std::nullopt;
1088 if (!isa<Float32Type>(destType))
1089 return std::nullopt;
1090
1091 std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
1092 std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
1093 if (!aTypeCode || !bTypeCode)
1094 return std::nullopt;
1095
1096 if (m == 32 && n == 32 && k == 64 && b == 1)
1097 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1098 *aTypeCode, *bTypeCode};
1099 if (m == 16 && n == 16 && k == 128 && b == 1)
1100 return std::tuple{
1101 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1102 *bTypeCode};
1103
1104 return std::nullopt;
1105}
1106
/// Overload: derive the scaled-MFMA intrinsic for an MFMAOp by forwarding its
/// operand/result types and full (m, n, k, blocks) problem size.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}
1114
1115static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1116mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
1117 return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
1118 smfma.getSourceB().getType(),
1119 smfma.getDestC().getType(), smfma.getM(),
1120 smfma.getN(), smfma.getK(), 1u, chipset);
1121}
1122
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for RDNA3/4 architectures.
///
/// `isRDNA3` selects the RDNA3 (gfx11) table; otherwise the RDNA4 table,
/// which additionally supports fp8/bf8 sources, is used. Returns std::nullopt
/// when no intrinsic matches the element types and `k`.
static std::optional<StringRef>
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
                      Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  // Handle k == 16 for RDNA3/4.
  if (k == 16) {
    // Common patterns for RDNA3 and RDNA4.
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();

    // RDNA3 specific patterns.
    if (isRDNA3) {
      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
      return std::nullopt;
    }

    // RDNA4 specific patterns (fp8/bf8).
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

    return std::nullopt;
  }

  // Handle k == 32 for RDNA4 (int4 sources only).
  if (k == 32 && !isRDNA3) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
  }

  return std::nullopt;
}
1179
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for the gfx1250 architecture.
///
/// Entries are grouped by `k`: 4 (f32), 32 (f16/bf16), 64 and 128 (fp8/bf8
/// pairs, plus i8 at k == 64). Returns std::nullopt when no intrinsic
/// matches.
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
                                                         Type elemBSourceType,
                                                         Type elemDestType,
                                                         uint32_t k) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  // k == 4: f32 sources only.
  if (k == 4) {
    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();

    return std::nullopt;
  }

  // k == 32: f16/bf16 sources with f32 or matching 16-bit accumulators.
  if (k == 32) {
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();

    return std::nullopt;
  }

  // k == 64: all four fp8/bf8 A/B combinations (f32 or f16 dest), plus i8.
  if (k == 64) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
    }
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();

    return std::nullopt;
  }

  // k == 128: all four fp8/bf8 A/B combinations (f32 or f16 dest).
  if (k == 128) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
    }

    return std::nullopt;
  }

  return std::nullopt;
}
1271
/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
/// operation if one exists. This includes checking to ensure the intrinsic is
/// supported on the architecture you are compiling for.
///
/// Entries are grouped by (m, n, k); within each group, the gfx950-only
/// variants are checked first, then the ones available on older chipsets.
static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
                                                    Chipset chipset) {
  bool isGfx950 = chipset >= kGfx950;
  // Chipset-aware fp8/bf8 matching (FNUZ on gfx942, OCP types elsewhere).
  auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
  auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };

  uint32_t m = op.getM(), n = op.getN(), k = op.getK();
  Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
  Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
  Type destElem = getElementTypeOrSelf(op.getDestC().getType());

  // 16x16x32: f16/bf16 only.
  if (m == 16 && n == 16 && k == 32) {
    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
  }

  // 16x16x64: f16/bf16 (gfx950 only), i8, and the four fp8/bf8 pairings.
  if (m == 16 && n == 16 && k == 64) {
    if (isGfx950) {
      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
        return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
        return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
    }
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
  }

  // 16x16x128 (gfx950 only): i8 and the four fp8/bf8 pairings.
  if (m == 16 && n == 16 && k == 128 && isGfx950) {
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
  }

  // 32x32x16: f16/bf16 only.
  if (m == 32 && n == 32 && k == 16) {
    if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
    if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
  }

  // 32x32x32: f16/bf16 (gfx950 only), i8, and the four fp8/bf8 pairings.
  if (m == 32 && n == 32 && k == 32) {
    if (isGfx950) {
      if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
        return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
      if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
        return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
    }
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
  }

  // 32x32x64 (gfx950 only): i8 and the four fp8/bf8 pairings.
  if (m == 32 && n == 32 && k == 64 && isGfx950) {
    if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
        destElem.isInteger(32))
      return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
  }

  return std::nullopt;
}
1370
1371/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1372/// if one exists. This includes checking to ensure the intrinsic is supported
1373/// on the architecture you are compiling for.
1374static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1375 Chipset chipset) {
1376 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1377 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1378 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1379 Type elemSourceType = sourceVectorType.getElementType();
1380 Type elemBSourceType = sourceBVectorType.getElementType();
1381 Type elemDestType = destVectorType.getElementType();
1382
1383 const uint32_t k = wmma.getK();
1384 const bool isRDNA3 = chipset.majorVersion == 11;
1385 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1386
1387 // Handle RDNA3 and RDNA4.
1388 if (isRDNA3 || isRDNA4)
1389 return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1390 k, isRDNA3);
1391
1392 // Handle gfx1250.
1393 if (chipset == kGfx1250)
1394 return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1395 elemDestType, k);
1396
1397 return std::nullopt;
1398}
1399
1400/// Returns the `rocdl` intrinsic corresponding to a SparseWMMA operation
1401/// `swmmac` if one exists. This includes checking to ensure the intrinsic is
1402/// supported on the architecture you are compiling for.
1404 StringRef name;
1408};
1409
/// Returns the SparseWMMAOpInfo (intrinsic name plus flag settings) for a
/// SparseWMMAOp `swmmac` if a matching intrinsic exists on `chipset`, or
/// std::nullopt otherwise. Only 16x16xK shapes are supported; RDNA4
/// (gfx120x) and gfx1250-wave32 have separate tables.
static std::optional<SparseWMMAOpInfo>
sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset) {
  Type sourceAElem = getElementTypeOrSelf(swmmac.getSourceA().getType());
  Type sourceBElem = getElementTypeOrSelf(swmmac.getSourceB().getType());
  Type destElem = getElementTypeOrSelf(swmmac.getDestC().getType());

  uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();

  // Every sparse WMMA intrinsic below is 16x16xK.
  if ((m != 16) || (n != 16))
    return std::nullopt;

  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
  if (isRDNA4) {
    if (k == 32) {
      if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x32_f16::getOperationName(), false, false,
            false};
      if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(), false, false,
            false};
      if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f16_16x16x32_f16::getOperationName(), false, false,
            false};
      if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(), false, false,
            false};
      if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
          sourceBElem.isInteger(8))
        return SparseWMMAOpInfo{
            ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(), true, false,
            true};
      if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
          sourceBElem.isInteger(4))
        return SparseWMMAOpInfo{
            ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(), true, false,
            true};
      if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
          sourceBElem.isF8E4M3FN())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(), false,
            false, false};
      if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
          sourceBElem.isF8E5M2())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(), false,
            false, false};
      if (destElem.isF32() && sourceAElem.isF8E5M2() &&
          sourceBElem.isF8E4M3FN())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(), false,
            false, false};
      if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(), false,
            false, false};
    }
    // RDNA4 k == 64 only exists for int4 sources.
    if (k == 64) {
      if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
          sourceBElem.isInteger(4))
        return SparseWMMAOpInfo{
            ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(), true, false,
            true};
    }
  }

  // gfx1250 table, wave32 mode only.
  const bool isGFX1250 = chipset == kGfx1250;
  const bool isWavesize64 = swmmac.getWave64();
  if (isGFX1250 && !isWavesize64) {
    if (k == 64) {
      if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x64_f16::getOperationName(), true, true,
            false};
      if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(), true, true,
            false};
      if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f16_16x16x64_f16::getOperationName(), true, true,
            false};
      if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(), true, true,
            false};
    }
    if (k == 128) {
      if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
          sourceBElem.isF8E4M3FN())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(), false,
            true, false};
      if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
          sourceBElem.isF8E5M2())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(), false,
            true, false};
      if (destElem.isF32() && sourceAElem.isF8E5M2() &&
          sourceBElem.isF8E4M3FN())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(), false,
            true, false};
      if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(), false,
            true, false};
      if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
          sourceBElem.isF8E4M3FN())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(), false,
            true, false};
      if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
          sourceBElem.isF8E5M2())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(), false,
            true, false};
      if (destElem.isF16() && sourceAElem.isF8E5M2() &&
          sourceBElem.isF8E4M3FN())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(), false,
            true, false};
      if (destElem.isF16() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
            true, false};
      // NOTE(review): this case maps i8 sources with an f16 destination to the
      // *bf8_bf8* intrinsic — it duplicates the bf8/bf8 case directly above
      // and looks like a copy-paste error (an i8->f16 swmmac intrinsic does
      // not appear elsewhere in this table). Confirm against the ROCDL op
      // definitions; it may be dead code the op verifier never reaches.
      if (destElem.isF16() && sourceAElem.isInteger(8) &&
          sourceBElem.isInteger(8))
        return SparseWMMAOpInfo{
            ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
            true, false};
      if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
          sourceBElem.isInteger(8))
        return SparseWMMAOpInfo{
            ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(), true, true,
            true};
    }
  }

  return std::nullopt;
}
1554
1555namespace {
1556struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
1557 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1558 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1559
1560 Chipset chipset;
1561
1562 LogicalResult
1563 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1564 ConversionPatternRewriter &rewriter) const override {
1565 Location loc = op.getLoc();
1566 Type outType = typeConverter->convertType(op.getDestD().getType());
1567 Type intrinsicOutType = outType;
1568 if (auto outVecType = dyn_cast<VectorType>(outType))
1569 if (outVecType.getElementType().isBF16())
1570 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1571
1572 if (chipset.majorVersion != 9 || chipset < kGfx908)
1573 return op->emitOpError("MFMA only supported on gfx908+");
1574 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
1575 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1576 if (chipset < kGfx942)
1577 return op.emitOpError("negation unsupported on older than gfx942");
1578 getBlgpField |=
1579 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1580 }
1581 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
1582 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1583 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1584 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1585 return op.emitOpError("no intrinsic matching MFMA size on given chipset");
1586
1587 bool isScaled =
1588 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1589 if (isScaled &&
1590 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1591 return op.emitOpError(
1592 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1593 "be scaled as those fields are used for type information");
1594 }
1595
1596 StringRef intrinsicName =
1597 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1598 // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
1599 // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
1600 bool allowBf16 = [&]() {
1601 if (chipset < kGfx950)
1602 return false;
1603 if (isScaled)
1604 return true;
1605 return intrinsicName.contains("16x16x32.bf16") ||
1606 intrinsicName.contains("32x32x16.bf16");
1607 }();
1608 OperationState loweredOp(loc, intrinsicName);
1609 loweredOp.addTypes(intrinsicOutType);
1610 loweredOp.addOperands({packSmallFloatVectorOperand(
1611 rewriter, loc, adaptor.getSourceA(), allowBf16),
1613 rewriter, loc, adaptor.getSourceB(), allowBf16),
1614 adaptor.getDestC()});
1615 if (isScaled) {
1616 Value zero = createI32Constant(rewriter, loc, 0);
1617 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1618 loweredOp.addOperands({/*scale A=*/zero, /*scale B=*/zero});
1619 loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1620 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1621 {"opselA", rewriter.getI32IntegerAttr(0)},
1622 {"opselB", rewriter.getI32IntegerAttr(0)}});
1623 } else {
1624 loweredOp.addAttributes(
1625 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1626 {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1627 {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
1628 };
1629 Value lowered = rewriter.create(loweredOp)->getResult(0);
1630 if (outType != intrinsicOutType)
1631 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1632 rewriter.replaceOp(op, lowered);
1633 return success();
1634 }
1635};
1636
1637struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1638 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1639 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1640
1641 Chipset chipset;
1642
1643 LogicalResult
1644 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1645 ConversionPatternRewriter &rewriter) const override {
1646 Location loc = op.getLoc();
1647 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1648
1649 if (chipset.majorVersion != 9 || chipset < kGfx950)
1650 return op->emitOpError("scaled MFMA only supported on gfx908+");
1651 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1652 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1653 if (!maybeScaledIntrinsic.has_value())
1654 return op.emitOpError(
1655 "no intrinsic matching scaled MFMA size on given chipset");
1656
1657 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1658 OperationState loweredOp(loc, intrinsicName);
1659 loweredOp.addTypes(intrinsicOutType);
1660 loweredOp.addOperands(
1661 {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
1662 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
1663 adaptor.getDestC()});
1664 loweredOp.addOperands(
1665 {/*scales A*/
1666 castScaleOperand(rewriter, loc, adaptor.getScalesA()),
1667 /*scales B*/
1668 castScaleOperand(rewriter, loc, adaptor.getScalesB())});
1669 loweredOp.addAttributes(
1670 {{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1671 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1672 {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1673 {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1674
1675 Value lowered = rewriter.create(loweredOp)->getResult(0);
1676 rewriter.replaceOp(op, lowered);
1677 return success();
1678 }
1679};
1680
1681struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
1682 SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1683 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1684
1685 Chipset chipset;
1686
1687 LogicalResult
1688 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1689 ConversionPatternRewriter &rewriter) const override {
1690 Location loc = op.getLoc();
1691 auto outType =
1692 typeConverter->convertType<VectorType>(op.getDestC().getType());
1693 if (!outType)
1694 return rewriter.notifyMatchFailure(op, "type conversion failed");
1695
1696 // smfmac is supported on gfx942 and gfx950.
1697 if (chipset.majorVersion != 9 || chipset < kGfx942)
1698 return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
1699 bool isGfx950 = chipset >= kGfx950;
1700
1701 Value a = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceA(),
1702 isGfx950);
1703 Value b = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceB(),
1704 isGfx950);
1705 Value c = adaptor.getDestC();
1706
1707 std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
1708 if (!maybeIntrinsic.has_value())
1709 return op.emitOpError(
1710 "no intrinsic matching sparse MFMA on the given chipset");
1711
1712 // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
1713 // gfx950 8-bit variants already carry the index as i32; skip the bitcast.
1714 Value sparseIdx = adaptor.getSparseIdx();
1715 Type i32Type = rewriter.getI32Type();
1716 if (sparseIdx.getType() != i32Type)
1717 sparseIdx = LLVM::BitcastOp::create(rewriter, loc, i32Type, sparseIdx);
1718
1719 OperationState loweredOp(loc, maybeIntrinsic.value());
1720 loweredOp.addTypes(outType);
1721 loweredOp.addOperands({a, b, c, sparseIdx});
1722 loweredOp.addAttributes(
1723 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1724 {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1725 Value lowered = rewriter.create(loweredOp)->getResult(0);
1726 rewriter.replaceOp(op, lowered);
1727 return success();
1728 }
1729};
1730
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  /// Lower amdgpu.wmma to the ROCDL WMMA intrinsic chosen by
  /// wmmaOpToIntrinsic, routing bf16 vectors through i16 where the pre-gfx1250
  /// intrinsics require it.
  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    bool isGFX1250 = chipset >= kGfx1250;

    // The WMMA operations represent vectors of bf16s as vectors of i16s
    // (except on gfx1250), so we need to bitcast bfloats to i16 and then
    // bitcast them back.
    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
    // rawOutType is the type the intrinsic actually produces; it differs from
    // outType only in the bf16-as-i16 case handled above.
    VectorType rawOutType = outType;
    if (castOutToI16)
      rawOutType = outType.clone(rewriter.getI16Type());
    Value a = adaptor.getSourceA();
    if (castAToI16)
      a = LLVM::BitcastOp::create(rewriter, loc,
                                  aType.clone(rewriter.getI16Type()), a);
    Value b = adaptor.getSourceB();
    if (castBToI16)
      b = LLVM::BitcastOp::create(rewriter, loc,
                                  bType.clone(rewriter.getI16Type()), b);
    Value destC = adaptor.getDestC();
    if (castDestCToI16)
      destC = LLVM::BitcastOp::create(
          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    // The helpers append operands and, where needed, sign/clamp/opsel
    // attributes in the order the intrinsic expects.
    SmallVector<Value, 4> operands;
    SmallVector<NamedAttribute, 4> attrs;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
                         op.getSourceA(), operands, attrs, "signA");
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
                         op.getSourceB(), operands, attrs, "signB");
    wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
                          op.getSubwordOffset(), op.getClamp(), operands,
                          attrs);

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands(operands);
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    // Undo the bf16->i16 result representation, if it was applied.
    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
1810
/// Families of ROCDL dot-product intrinsic ops, distinguished by the
/// attribute set each intrinsic expects when built.
enum class DotFamily {
  /// ROCDL_Dot_IntrOp: single `clamp` attribute.
  Clamp,
  /// ROCDL_Dot_NoClamp_IntrOp: no attributes.
  NoClamp,
  /// ROCDL_Sudot_IntrOp: `signA`, `signB`, and `clamp` attributes.
  Sudot,
};
1819
/// Map an `amdgpu.dot` op to the name of the ROCDL `v_dot*` intrinsic that
/// implements it on `chipset`, paired with the intrinsic's attribute family.
/// Returns std::nullopt when the operand/result type combination has no
/// intrinsic or the chipset lacks the required Dot{N} instruction feature
/// (checked via the hasDot*Insts predicates defined earlier in this file).
static std::optional<std::pair<StringRef, DotFamily>>
dotOpToIntrinsic(DotOp op, Chipset chipset) {
  Type aElem = cast<VectorType>(op.getSourceA().getType()).getElementType();
  Type bElem = cast<VectorType>(op.getSourceB().getType()).getElementType();
  Type dest = op.getDestC().getType();
  bool uA = op.getUnsignedA();
  bool uB = op.getUnsignedB();

  // f16 x f16 -> f32 / f16.
  if (aElem.isF16() && bElem.isF16()) {
    if (dest.isF32() && hasDot10Insts(chipset))
      return {{ROCDL::fdot2::getOperationName(), DotFamily::Clamp}};
    if (dest.isF16() && hasDot9Insts(chipset))
      return {{ROCDL::fdot2_f16_f16::getOperationName(), DotFamily::NoClamp}};
    return std::nullopt;
  }

  // bf16 x bf16 -> f32 / bf16.
  if (aElem.isBF16() && bElem.isBF16()) {
    if (dest.isF32() && hasDot12Insts(chipset))
      return {{ROCDL::fdot2_f32_bf16::getOperationName(), DotFamily::Clamp}};
    if (dest.isBF16() && hasDot9Insts(chipset))
      return {{ROCDL::fdot2_bf16_bf16::getOperationName(), DotFamily::NoClamp}};
    return std::nullopt;
  }

  // Integer sources -> i32.
  if (isa<IntegerType>(aElem) && isa<IntegerType>(bElem) &&
      dest.isInteger(32)) {
    bool mixedSign = (uA != uB);
    unsigned elemWidth = aElem.getIntOrFloatBitWidth();

    // Mixed signedness maps to the sudot intrinsics, which take explicit
    // per-operand sign attributes.
    if (mixedSign) {
      if (!hasDot8Insts(chipset))
        return std::nullopt;
      StringRef name;
      switch (elemWidth) {
      case 8:
        name = ROCDL::sudot4::getOperationName();
        break;
      case 4:
        name = ROCDL::sudot8::getOperationName();
        break;
      default:
        return std::nullopt;
      }
      return {{name, DotFamily::Sudot}};
    }

    // Uniform signedness: pick the signed or unsigned intrinsic by width,
    // gated by the matching Dot{N} feature predicate.
    StringRef name;
    bool supported = false;
    switch (elemWidth) {
    case 16:
      supported = hasDot2Insts(chipset);
      name = uA ? ROCDL::udot2::getOperationName()
                : ROCDL::sdot2::getOperationName();
      break;
    case 8:
      supported = uA ? hasDot7Insts(chipset)
                     : hasDot1Insts(chipset) || hasDot8Insts(chipset);
      name = uA ? ROCDL::udot4::getOperationName()
                : ROCDL::sdot4::getOperationName();
      break;
    case 4:
      supported = uA ? hasDot7Insts(chipset)
                     : hasDot1Insts(chipset) || hasDot8Insts(chipset);
      name = uA ? ROCDL::udot8::getOperationName()
                : ROCDL::sdot8::getOperationName();
      break;
    default:
      return std::nullopt;
    }
    if (!supported)
      return std::nullopt;
    return {{name, DotFamily::Clamp}};
  }

  // fp8/bf8 x fp8/bf8 -> f32.
  bool aIsFp8 = isa<Float8E4M3FNType>(aElem);
  bool aIsBf8 = isa<Float8E5M2Type>(aElem);
  bool bIsFp8 = isa<Float8E4M3FNType>(bElem);
  bool bIsBf8 = isa<Float8E5M2Type>(bElem);
  if ((aIsFp8 || aIsBf8) && (bIsFp8 || bIsBf8) && dest.isF32()) {
    if (!hasDot11Insts(chipset))
      return std::nullopt;
    StringRef name;
    if (aIsFp8 && bIsFp8)
      name = ROCDL::dot4_f32_fp8_fp8::getOperationName();
    else if (aIsFp8 && bIsBf8)
      name = ROCDL::dot4_f32_fp8_bf8::getOperationName();
    else if (aIsBf8 && bIsFp8)
      name = ROCDL::dot4_f32_bf8_fp8::getOperationName();
    else
      name = ROCDL::dot4_f32_bf8_bf8::getOperationName();
    return {{name, DotFamily::NoClamp}};
  }

  return std::nullopt;
}
1919
1920struct DotOpLowering : public ConvertOpToLLVMPattern<DotOp> {
1921 DotOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1922 : ConvertOpToLLVMPattern<DotOp>(converter), chipset(chipset) {}
1923
1924 Chipset chipset;
1925
1926 LogicalResult
1927 matchAndRewrite(DotOp op, DotOpAdaptor adaptor,
1928 ConversionPatternRewriter &rewriter) const override {
1929 Location loc = op.getLoc();
1930
1931 std::optional<std::pair<StringRef, DotFamily>> maybeIntrinsic =
1932 dotOpToIntrinsic(op, chipset);
1933 if (!maybeIntrinsic)
1934 return op.emitOpError("no intrinsic matching dot on the given chipset: ")
1935 << op.getSourceA().getType() << " * " << op.getSourceB().getType()
1936 << " + " << op.getDestC().getType();
1937
1938 auto [intrinsicName, family] = maybeIntrinsic.value();
1939
1940 Value a = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceA());
1941 Value b = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceB());
1942 Value c = adaptor.getDestC();
1943
1944 SmallVector<NamedAttribute, 3> attrs;
1945 if (family == DotFamily::Sudot) {
1946 attrs.push_back(rewriter.getNamedAttr(
1947 "signA", rewriter.getBoolAttr(!op.getUnsignedA())));
1948 attrs.push_back(rewriter.getNamedAttr(
1949 "signB", rewriter.getBoolAttr(!op.getUnsignedB())));
1950 }
1951
1952 if (family != DotFamily::NoClamp && op.getClamp())
1953 attrs.push_back(
1954 rewriter.getNamedAttr("clamp", rewriter.getBoolAttr(true)));
1955
1956 Type resultType = typeConverter->convertType(op.getDestD().getType());
1957
1958 OperationState loweredOp(loc, intrinsicName);
1959 loweredOp.addTypes(resultType);
1960 loweredOp.addOperands({a, b, c});
1961 loweredOp.addAttributes(attrs);
1962 Operation *lowered = rewriter.create(loweredOp);
1963 rewriter.replaceOp(op, lowered->getResults());
1964 return success();
1965 }
1966};
1967
struct SparseWMMAOpLowering : public ConvertOpToLLVMPattern<SparseWMMAOp> {
  SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  /// Lower amdgpu.sparse_wmma to the ROCDL swmmac intrinsic described by
  /// sparseWMMAOpToIntrinsic. Each optional op attribute (sign/reuse/clamp)
  /// is forwarded only when the selected intrinsic supports it, and rejected
  /// otherwise.
  LogicalResult
  matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    std::optional<SparseWMMAOpInfo> maybeIntrinsic =
        sparseWMMAOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching Sparse WMMA on the given chipset");
    SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();

    SmallVector<NamedAttribute> attrs;

    // signA/signB: only for intrinsics that model operand signedness.
    if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.useSign)
      return op->emitOpError("intrinsic doesn't support unsign");
    if (intrinsic.useSign) {
      if (auto attr = op.getUnsignedAAttr())
        attrs.push_back({"signA", attr});
      if (auto attr = op.getUnsignedBAttr())
        attrs.push_back({"signB", attr});
    }

    // reuseA/reuseB: only for intrinsics with operand-reuse controls.
    if ((op.getReuseA() || op.getReuseB()) && !intrinsic.useReuse)
      return op->emitOpError("intrinsic doesn't support reuse");
    if (intrinsic.useReuse) {
      if (auto attr = op.getReuseAAttr())
        attrs.push_back({"reuseA", attr});
      if (auto attr = op.getReuseBAttr())
        attrs.push_back({"reuseB", attr});
    }

    // clamp: only for intrinsics that can saturate the result.
    if (op.getClamp() && !intrinsic.useClamp)
      return op->emitOpError("intrinsic doesn't support clamp");
    if (intrinsic.useClamp && op.getClampAttr())
      attrs.push_back({"clamp", op.getClampAttr()});

    const bool isGFX1250orHigher =
        chipset.majorVersion == 12 && chipset.minorVersion >= 5;
    Value a = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceA(),
                                         isGFX1250orHigher);
    Value b = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceB(),
                                         isGFX1250orHigher);
    Value c = adaptor.getDestC();
    // Pre-gfx1250 the accumulator is also packed, so the intrinsic's raw
    // result type follows the packed accumulator type; bitcast back below.
    VectorType rawOutType = outType;
    if (!isGFX1250orHigher) {
      c = convertPackedVectorOperand(rewriter, loc, adaptor.getDestC(), false);
      rawOutType = cast<VectorType>(c.getType());
    }

    // Bitcast sparse indices from vector<4xi8> to i32.
    Value sparseIdx = LLVM::BitcastOp::create(
        rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());

    OperationState loweredOp(loc, intrinsic.name);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands({a, b, c, sparseIdx});
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
2048
struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  /// Lower amdgpu.scaled_wmma (gfx1250+) to the ROCDL scaled-WMMA intrinsic
  /// selected by the op's MxNxK dimensions and scale vector length. Operand
  /// format codes, scale formats, and scale-lane selection are passed as
  /// attributes; sources and scales are packed into the intrinsic's storage
  /// types first.
  LogicalResult
  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset < kGfx1250)
      return op->emitOpError("WMMA scale only supported on gfx1250+");

    int64_t m = op.getM();
    int64_t n = op.getN();
    int64_t k = op.getK();

    // Format codes describing the small-float element types of A and B.
    Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
    Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());

    std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
    std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);

    if (!aFmtCode || !bFmtCode)
      return op.emitOpError("unsupported element types for scaled_wmma");

    // Get scale vector types and determine variant (scale vs scale16).
    auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
    auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());

    if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
      return op.emitOpError("scaleA and scaleB must have equal vector length");

    // Extract scale format from element types.
    Type scaleAElemType = scaleAVecType.getElementType();
    Type scaleBElemType = scaleBVecType.getElementType();

    std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
    std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);

    if (!scaleAFmt || !scaleBFmt)
      return op.emitOpError("unsupported scale element types");

    // Determine which intrinsic to use based on dimensions.
    bool isScale16 = (scaleAVecType.getNumElements() == 8);
    std::optional<StringRef> intrinsicName =
        getScaledWmmaIntrinsicName(m, n, k, isScale16);
    if (!intrinsicName)
      return op.emitOpError("unsupported scaled_wmma dimensions: ")
             << m << "x" << n << "x" << k;

    SmallVector<NamedAttribute, 8> attrs;

    // The f4 variant does not have fmtA and fmtB attributes.
    bool is32x16 = (m == 32 && n == 16 && k == 128);
    if (!is32x16) {
      attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
      attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
    }

    // modC uses default value of 0.
    attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));

    // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
    // half of the wave that is being selected (0 or 1).
    attrs.emplace_back(
        "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
    attrs.emplace_back(
        "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));

    // Reuse flags use default value of false.
    attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
    attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));

    // Convert typed float vectors to packed format.
    Value sourceA =
        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
    Value sourceB =
        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());

    // Pack scale vectors into i32/i64.
    Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
    Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());

    // Create the intrinsic call.
    OperationState loweredOp(loc, *intrinsicName);
    loweredOp.addTypes(outType);
    loweredOp.addOperands(
        {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
    loweredOp.addAttributes(attrs);

    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());

    return success();
  }
};
2153
struct TransposeLoadOpLowering
    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  Chipset chipset;

  /// Lower amdgpu.transpose_load to the ds_read_tr* intrinsic matching the
  /// result element width (gfx950 only). Sub-16-bit variants return vectors
  /// of i32, which are bitcast to the converted result type.
  LogicalResult
  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx950)
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());

    // Elements in subbyte memrefs are stored non-contiguously,
    // reject if source is sub-byte memref. Use emulated memrefs instead.
    size_t srcElementSize =
        srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcElementSize < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")
             << srcElementSize;

    auto resultType = cast<VectorType>(op.getResult().getType());
    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));

    size_t numElements = resultType.getNumElements();
    size_t elementTypeSize =
        resultType.getElementType().getIntOrFloatBitWidth();

    // ROCDL transpose load intrinsics return vectors of 32-bit integers, if
    // the element size is smaller than 16 bits.
    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
                                           rewriter.getIntegerType(32));
    Type llvmResultType = typeConverter->convertType(resultType);

    // Dispatch on element width; the asserts encode the fixed element counts
    // each ds_read_tr* variant produces.
    switch (elementTypeSize) {
    case 4: {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 6: {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 8: {
      assert(numElements == 8);
      auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 16: {
      // 16-bit elements are returned directly in the converted result type;
      // no i32 round-trip is needed.
      assert(numElements == 4);
      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
                                                           srcPtr);
      break;
    }
    default:
      return op.emitOpError("Unsupported element size for transpose load");
    }
    return success();
  }
};
2228
2229struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
2230 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2231 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
2232
2233 Chipset chipset;
2234
2235 LogicalResult
2236 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
2237 ConversionPatternRewriter &rewriter) const override {
2238 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
2239 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
2240
2241 Location loc = op.getLoc();
2242
2243 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2244 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2245
2246 // TODO: instead of only transfering one element per thread, we could
2247 // augment it to transfer multiple elements per thread by issuing multiple
2248 // `global_load_lds` instructions.
2249 Type transferType = op.getTransferType();
2250 int loadWidth = [&]() -> int {
2251 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
2252 return (transferVectorType.getNumElements() *
2253 transferVectorType.getElementTypeBitWidth()) /
2254 8;
2255 }
2256 return transferType.getIntOrFloatBitWidth() / 8;
2257 }();
2258
2259 // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
2260 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
2261 return op.emitOpError("chipset unsupported element size");
2262
2263 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
2264 return op.emitOpError("Gather to LDS instructions with 12-byte and "
2265 "16-byte load widths are only supported on gfx950");
2266
2267 Value srcPtr =
2268 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
2269 (adaptor.getSrcIndices()));
2270 Value dstPtr =
2271 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
2272 (adaptor.getDstIndices()));
2273
2274 if (op.getAsync()) {
2275 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
2276 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2277 /*offset=*/rewriter.getI32IntegerAttr(0),
2278 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
2279 ArrayAttr{});
2280 } else {
2281 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
2282 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2283 /*offset=*/rewriter.getI32IntegerAttr(0),
2284 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
2285 ArrayAttr{});
2286 }
2287
2288 return success();
2289 }
2290};
2291
struct GlobalLoadAsyncToLDSOpLowering
    : public ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp> {
  GlobalLoadAsyncToLDSOpLowering(const LLVMTypeConverter &converter,
                                 Chipset chipset)
      : ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  /// Lower amdgpu.global_load_async_to_lds (gfx1250+) to the width-specific
  /// ROCDL GlobalLoadAsyncToLDSB{8,32,64,128} op. An optional mask is
  /// implemented by selecting the LDS null pointer as the destination for
  /// masked-off lanes.
  LogicalResult
  matchAndRewrite(GlobalLoadAsyncToLDSOp op,
                  GlobalLoadAsyncToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op.emitOpError(
          "global_load_async_to_lds is only supported on gfx1250+");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // Total transfer size in bits, used to pick the B8/B32/B64/B128 variant.
    Type transferType = op.getTransferType();
    int transferBits =
        isa<VectorType>(transferType)
            ? cast<VectorType>(transferType).getNumElements() *
                  cast<VectorType>(transferType).getElementTypeBitWidth()
            : transferType.getIntOrFloatBitWidth();

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             adaptor.getSrcIndices());
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             adaptor.getDstIndices());

    if (op.getMask()) {
      // When masked off, redirect the store to the LDS null pointer value
      // (queried from the AMDGPU backend for the LOCAL address space).
      Value mask = adaptor.getMask();
      int64_t nullptrVal =
          llvm::AMDGPU::getNullPointerValue(llvm::AMDGPUAS::LOCAL_ADDRESS);
      Value nullInt =
          createI32Constant(rewriter, loc, static_cast<int32_t>(nullptrVal));
      Value nullPtr =
          LLVM::IntToPtrOp::create(rewriter, loc, dstPtr.getType(), nullInt);
      dstPtr = LLVM::SelectOp::create(rewriter, loc, mask, dstPtr, nullPtr);
    }

    auto offset = rewriter.getI32IntegerAttr(0);
    auto aux = rewriter.getI32IntegerAttr(0);

    switch (transferBits) {
    case 8:
      rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB8Op>(
          op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
          ArrayAttr{});
      break;
    case 32:
      rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB32Op>(
          op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
          ArrayAttr{});
      break;
    case 64:
      rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB64Op>(
          op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
          ArrayAttr{});
      break;
    case 128:
      rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB128Op>(
          op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
          ArrayAttr{});
      break;
    default:
      return op.emitOpError("unsupported transfer width");
    }
    return success();
  }
};
2368
2369namespace {
/// Pattern lowering amdgpu.ext_packed_fp8; matchAndRewrite is defined
/// out-of-line.
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
2381
/// Pattern lowering amdgpu.scaled_ext_packed_matrix; matchAndRewrite is
/// defined out-of-line.
struct ScaledExtPackedMatrixOpLowering final
    : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
  ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedMatrixOp op,
                  ScaledExtPackedMatrixOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
2395
2396struct PackedTrunc2xFp8OpLowering final
2397 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
2398 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
2399 Chipset chipset)
2400 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
2401 chipset(chipset) {}
2402 Chipset chipset;
2403
2404 LogicalResult
2405 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2406 ConversionPatternRewriter &rewriter) const override;
2407};
2408
2409struct PackedStochRoundFp8OpLowering final
2410 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
2411 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
2412 Chipset chipset)
2413 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
2414 chipset(chipset) {}
2415 Chipset chipset;
2416
2417 LogicalResult
2418 matchAndRewrite(PackedStochRoundFp8Op op,
2419 PackedStochRoundFp8OpAdaptor adaptor,
2420 ConversionPatternRewriter &rewriter) const override;
2421};
2422
2423struct ScaledExtPackedOpLowering final
2424 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
2425 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2426 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
2427 chipset(chipset) {}
2428 Chipset chipset;
2429
2430 LogicalResult
2431 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2432 ConversionPatternRewriter &rewriter) const override;
2433};
2434
2435struct PackedScaledTruncOpLowering final
2436 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
2437 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
2438 Chipset chipset)
2439 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
2440 chipset(chipset) {}
2441 Chipset chipset;
2442
2443 LogicalResult
2444 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2445 ConversionPatternRewriter &rewriter) const override;
2446};
2447
2448} // end namespace
2449
2450LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
2451 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2452 ConversionPatternRewriter &rewriter) const {
2453 Location loc = op.getLoc();
2454 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2455 return rewriter.notifyMatchFailure(
2456 loc, "Fp8 conversion instructions are not available on target "
2457 "architecture and their emulation is not implemented");
2458 Type v4i8 =
2459 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
2460 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2461 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
2462
2463 Value source = adaptor.getSource();
2464 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
2465 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
2466 Type sourceElemType = getElementTypeOrSelf(op.getSource());
2467 // Extend to a v4i8
2468 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
2469 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
2470 if (!sourceVecType) {
2471 longVec = LLVM::InsertElementOp::create(
2472 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2473 } else {
2474 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2475 Value idx = createI32Constant(rewriter, loc, i);
2476 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2477 longVec =
2478 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2479 }
2480 }
2481 source = longVec;
2482 }
2483 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2484 if (resultVecType) {
2485 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
2486 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
2487 op.getIndex());
2488 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
2489 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
2490 op.getIndex());
2491 }
2492 } else {
2493 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
2494 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
2495 op.getIndex());
2496 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
2497 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
2498 op.getIndex());
2499 }
2500 }
2501 return success();
2502}
2503
/// Packs the high-level attributes of `amdgpu.scaled_ext_packed_matrix` into
/// the single `scaleSel` immediate consumed by the `rocdl.cvt.scale.pk*.f*.f*`
/// operations.
///
/// \p blockSize is the scale block size (16 or 32), \p bitWidth the width of
/// the source element type (4, 6, or 8), \p scaleWaveHalf the half of the wave
/// supplying the scales (derived from the op's firstScaleLane attribute, not a
/// high-level attribute itself), and \p firstScaleByte the byte offset of the
/// first scale value.
static int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
                           int32_t scaleWaveHalf, int32_t firstScaleByte) {
  assert((blockSize == 16 || blockSize == 32) && "invalid block size");
  assert((bitWidth == 4u || bitWidth == 6u || bitWidth == 8u) &&
         "invalid source element width");
  assert((scaleWaveHalf == 0 || scaleWaveHalf == 1) && "invalid wave half");

  const bool isFp8 = bitWidth == 8;
  const bool isBlock16 = blockSize == 16;
  const int32_t bit0 = isBlock16;

  if (!isFp8) {
    // 4- and 6-bit sources: bit 0 = block-16, bit 1 = scales start at byte 2,
    // bit 2 = upper half of the wave.
    assert(0 <= firstScaleByte && firstScaleByte <= 2 &&
           "invalid first scale byte");
    int32_t bit1 = (firstScaleByte == 2) << 1;
    int32_t bit2 = scaleWaveHalf << 2;
    return bit2 | bit1 | bit0;
  }

  // 8-bit sources: bit 0 = block-16, bits 2:1 = firstScaleByte (guaranteed to
  // fit in two bits), bit 3 = upper half of the wave.
  assert(0 <= firstScaleByte && firstScaleByte <= 3 &&
         "invalid first scale byte");
  int32_t bits2and1 = firstScaleByte << 1;
  int32_t bit3 = scaleWaveHalf << 3;
  int32_t bits = bit3 | bits2and1 | bit0;
  // Reject the encodings that have no hardware meaning.
  assert(bits != 0b0011 && bits != 0b0101 && bits != 0b0111 && bits != 0b1000 &&
         bits != 0b1001 && bits != 0b1011 && bits != 0b1111 &&
         "invalid scaleSel encoding");
  return bits;
}
2538
2539static std::optional<StringRef>
2540scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2541 using fp4 = Float4E2M1FNType;
2542 using fp8 = Float8E4M3FNType;
2543 using bf8 = Float8E5M2Type;
2544 using fp6 = Float6E2M3FNType;
2545 using bf6 = Float6E3M2FNType;
2546 if (isa<fp4>(srcElemType)) {
2547 if (destElemType.isF16())
2548 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2549 if (destElemType.isBF16())
2550 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2551 if (destElemType.isF32())
2552 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2553 return std::nullopt;
2554 }
2555 if (isa<fp8>(srcElemType)) {
2556 if (destElemType.isF16())
2557 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2558 if (destElemType.isBF16())
2559 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2560 if (destElemType.isF32())
2561 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2562 return std::nullopt;
2563 }
2564 if (isa<bf8>(srcElemType)) {
2565 if (destElemType.isF16())
2566 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2567 if (destElemType.isBF16())
2568 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2569 if (destElemType.isF32())
2570 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2571 return std::nullopt;
2572 }
2573 if (isa<fp6>(srcElemType)) {
2574 if (destElemType.isF16())
2575 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2576 if (destElemType.isBF16())
2577 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2578 if (destElemType.isF32())
2579 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2580 return std::nullopt;
2581 }
2582 if (isa<bf6>(srcElemType)) {
2583 if (destElemType.isF16())
2584 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2585 if (destElemType.isBF16())
2586 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2587 if (destElemType.isF32())
2588 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2589 return std::nullopt;
2590 }
2591 llvm_unreachable("invalid combination of element types for packed conversion "
2592 "instructions");
2593}
2594
2595LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2596 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2597 ConversionPatternRewriter &rewriter) const {
2598 using fp4 = Float4E2M1FNType;
2599 using fp8 = Float8E4M3FNType;
2600 using bf8 = Float8E5M2Type;
2601 using fp6 = Float6E2M3FNType;
2602 using bf6 = Float6E3M2FNType;
2603 Location loc = op.getLoc();
2604 if (chipset != kGfx1250) {
2605 return rewriter.notifyMatchFailure(
2606 loc,
2607 "Scaled fp packed conversion instructions are not available on target "
2608 "architecture and their emulation is not implemented");
2609 }
2610 // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
2611 // is being selected.
2612 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2613 int32_t firstScaleByte = op.getFirstScaleByte();
2614 int32_t blockSize = op.getBlockSize();
2615 auto sourceType = cast<VectorType>(op.getSource().getType());
2616 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2617 unsigned bitWidth = srcElemType.getWidth();
2618
2619 auto targetType = cast<VectorType>(op.getResult().getType());
2620 auto destElemType = cast<FloatType>(targetType.getElementType());
2621
2622 IntegerType i32 = rewriter.getI32Type();
2623 Value source = adaptor.getSource();
2624 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2625 Type packedType = nullptr;
2626 if (isa<fp4>(srcElemType)) {
2627 packedType = i32;
2628 packedType = getTypeConverter()->convertType(packedType);
2629 } else if (isa<fp8, bf8>(srcElemType)) {
2630 packedType = VectorType::get(2, i32);
2631 packedType = getTypeConverter()->convertType(packedType);
2632 } else if (isa<fp6, bf6>(srcElemType)) {
2633 packedType = VectorType::get(3, i32);
2634 packedType = getTypeConverter()->convertType(packedType);
2635 } else {
2636 llvm_unreachable("invalid element type for packed scaled ext");
2637 }
2638
2639 if (!packedType || !llvmResultType) {
2640 return rewriter.notifyMatchFailure(op, "type conversion failed");
2641 }
2642
2643 std::optional<StringRef> maybeIntrinsic =
2644 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2645 if (!maybeIntrinsic.has_value())
2646 return op.emitOpError(
2647 "no intrinsic matching packed scaled conversion on the given chipset");
2648
2649 int32_t scaleSel =
2650 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2651 Value castedScale =
2652 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2653 Value castedSource =
2654 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2655
2656 OperationState loweredOp(loc, *maybeIntrinsic);
2657 loweredOp.addTypes({llvmResultType});
2658 loweredOp.addOperands({castedSource, castedScale});
2659
2660 SmallVector<NamedAttribute, 1> attrs;
2661 attrs.push_back(
2662 NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2663
2664 loweredOp.addAttributes(attrs);
2665 Operation *lowered = rewriter.create(loweredOp);
2666 rewriter.replaceOp(op, lowered);
2667
2668 return success();
2669}
2670
2671LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2672 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2673 ConversionPatternRewriter &rewriter) const {
2674 Location loc = op.getLoc();
2675 if (chipset != kGfx950)
2676 return rewriter.notifyMatchFailure(
2677 loc, "Scaled fp conversion instructions are not available on target "
2678 "architecture and their emulation is not implemented");
2679 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2680
2681 Value source = adaptor.getSource();
2682 Value scale = adaptor.getScale();
2683
2684 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2685 Type sourceElemType = sourceVecType.getElementType();
2686 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2687 Type destElemType = destVecType.getElementType();
2688
2689 VectorType packedVecType;
2690 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2691 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2692 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2693 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
2694 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2695 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2696 } else {
2697 llvm_unreachable("invalid element type for scaled ext");
2698 }
2699
2700 // Extend to a packedVectorType
2701 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2702 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2703 if (!sourceVecType) {
2704 longVec = LLVM::InsertElementOp::create(
2705 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2706 } else {
2707 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2708 Value idx = createI32Constant(rewriter, loc, i);
2709 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2710 longVec =
2711 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2712 }
2713 }
2714 source = longVec;
2715 }
2716 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2717
2718 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
2719 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2720 op, destVecType, i32Source, scale, op.getIndex());
2721 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
2722 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2723 op, destVecType, i32Source, scale, op.getIndex());
2724 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
2725 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2726 op, destVecType, i32Source, scale, op.getIndex());
2727 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
2728 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2729 op, destVecType, i32Source, scale, op.getIndex());
2730 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
2731 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2732 op, destVecType, i32Source, scale, op.getIndex());
2733 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
2734 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2735 op, destVecType, i32Source, scale, op.getIndex());
2736 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
2737 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2738 op, destVecType, i32Source, scale, op.getIndex());
2739 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
2740 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2741 op, destVecType, i32Source, scale, op.getIndex());
2742 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
2743 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2744 op, destVecType, i32Source, scale, op.getIndex());
2745 else
2746 return failure();
2747
2748 return success();
2749}
2750
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  // The cvt.scalef32.pk.* truncation intrinsics only exist on gfx950.
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  // The intrinsics yield their packed result as an i32 for fp4 outputs and
  // as a v2i16 for the fp8/bf8 outputs; the final bitcast below converts
  // back to the op's declared result type.
  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  // `existing` lets the op update lanes of a previously produced packed word;
  // when absent, start from an all-zero word.
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
  else
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);

  // Pad a single-element source to the two lanes the packed conversion
  // consumes (the second lane is zero).
  if (sourceVecType.getNumElements() < 2) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
  }

  // The f32-source intrinsics take the two lanes as separate scalar operands;
  // the f16/bf16 variants consume the source vector directly.
  Value sourceA, sourceB;
  if (sourceElemType.isF32()) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
  }

  // Dispatch on the (source element type, result element type) pair.
  Value result;
  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else
    return failure();

  // Reinterpret the packed integer result as the op's declared fp vector type.
  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
2832
2833LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2834 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2835 ConversionPatternRewriter &rewriter) const {
2836 Location loc = op.getLoc();
2837 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2838 return rewriter.notifyMatchFailure(
2839 loc, "Fp8 conversion instructions are not available on target "
2840 "architecture and their emulation is not implemented");
2841 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2842
2843 Type resultType = op.getResult().getType();
2844 Type resultElemType = getElementTypeOrSelf(resultType);
2845
2846 Value sourceA = adaptor.getSourceA();
2847 Value sourceB = adaptor.getSourceB();
2848 if (!sourceB)
2849 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
2850 Value existing = adaptor.getExisting();
2851 if (existing)
2852 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2853 else
2854 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2855
2856 Value result;
2857 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2858 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2859 existing, op.getWordIndex());
2860 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2861 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2862 existing, op.getWordIndex());
2863
2864 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2865 op, getTypeConverter()->convertType(resultType), result);
2866 return success();
2867}
2868
2869LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2870 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2871 ConversionPatternRewriter &rewriter) const {
2872 Location loc = op.getLoc();
2873 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2874 return rewriter.notifyMatchFailure(
2875 loc, "Fp8 conversion instructions are not available on target "
2876 "architecture and their emulation is not implemented");
2877 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2878
2879 Type resultType = op.getResult().getType();
2880 Type resultElemType = getElementTypeOrSelf(resultType);
2881
2882 Value source = adaptor.getSource();
2883 Value stoch = adaptor.getStochiasticParam();
2884 Value existing = adaptor.getExisting();
2885 if (existing)
2886 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2887 else
2888 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2889
2890 Value result;
2891 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2892 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2893 existing, op.getStoreIndex());
2894 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2895 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2896 existing, op.getStoreIndex());
2897
2898 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2899 op, getTypeConverter()->convertType(resultType), result);
2900 return success();
2901}
2902
2903// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
2904// operation into the corresponding ROCDL instructions.
2905struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
2906 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
2907 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2908 Chipset chipset;
2909
2910 LogicalResult
2911 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2912 ConversionPatternRewriter &rewriter) const override {
2913
2914 // Convert the source operand to the corresponding LLVM type
2915 Location loc = DppOp.getLoc();
2916 Value src = adaptor.getSrc();
2917 Value old = adaptor.getOld();
2918 Type srcType = src.getType();
2919 Type oldType = old.getType();
2920 Type llvmType = nullptr;
2921 if (srcType.getIntOrFloatBitWidth() < 32) {
2922 llvmType = rewriter.getI32Type();
2923 } else if (isa<FloatType>(srcType)) {
2924 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2925 ? rewriter.getF32Type()
2926 : rewriter.getF64Type();
2927 } else if (isa<IntegerType>(srcType)) {
2928 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
2929 ? rewriter.getI32Type()
2930 : rewriter.getI64Type();
2931 }
2932 auto llvmSrcIntType = typeConverter->convertType(
2933 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
2934
2935 // If the source type is less of 32, use bitcast to convert it to i32.
2936 auto convertOperand = [&](Value operand, Type operandType) {
2937 if (operandType.getIntOrFloatBitWidth() <= 16) {
2938 if (llvm::isa<FloatType>(operandType)) {
2939 operand =
2940 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2941 }
2942 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2943 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2944 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2945 operand =
2946 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2947 createI32Constant(rewriter, loc, 0));
2948 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2949 }
2950 return operand;
2951 };
2952
2953 src = convertOperand(src, srcType);
2954 old = convertOperand(old, oldType);
2955
2956 // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
2957 enum DppCtrl : unsigned {
2958 ROW_SHL0 = 0x100,
2959 ROW_SHR0 = 0x110,
2960 ROW_ROR0 = 0x120,
2961 WAVE_SHL1 = 0x130,
2962 WAVE_ROL1 = 0x134,
2963 WAVE_SHR1 = 0x138,
2964 WAVE_ROR1 = 0x13C,
2965 ROW_MIRROR = 0x140,
2966 ROW_HALF_MIRROR = 0x141,
2967 BCAST15 = 0x142,
2968 BCAST31 = 0x143,
2969 };
2970
2971 auto kind = DppOp.getKind();
2972 auto permArgument = DppOp.getPermArgument();
2973 uint32_t DppCtrl = 0;
2974
2975 switch (kind) {
2976
2977 case DPPPerm::quad_perm: {
2978 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
2979 int32_t i = 0;
2980 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2981 uint32_t num = elem.getInt();
2982 DppCtrl |= num << (i * 2);
2983 i++;
2984 }
2985 break;
2986 }
2987 case DPPPerm::row_shl: {
2988 auto intAttr = cast<IntegerAttr>(*permArgument);
2989 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2990 break;
2991 }
2992 case DPPPerm::row_shr: {
2993 auto intAttr = cast<IntegerAttr>(*permArgument);
2994 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2995 break;
2996 }
2997 case DPPPerm::row_ror: {
2998 auto intAttr = cast<IntegerAttr>(*permArgument);
2999 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
3000 break;
3001 }
3002 case DPPPerm::wave_shl:
3003 DppCtrl = DppCtrl::WAVE_SHL1;
3004 break;
3005 case DPPPerm::wave_shr:
3006 DppCtrl = DppCtrl::WAVE_SHR1;
3007 break;
3008 case DPPPerm::wave_rol:
3009 DppCtrl = DppCtrl::WAVE_ROL1;
3010 break;
3011 case DPPPerm::wave_ror:
3012 DppCtrl = DppCtrl::WAVE_ROR1;
3013 break;
3014 case DPPPerm::row_mirror:
3015 DppCtrl = DppCtrl::ROW_MIRROR;
3016 break;
3017 case DPPPerm::row_half_mirror:
3018 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
3019 break;
3020 case DPPPerm::row_bcast_15:
3021 DppCtrl = DppCtrl::BCAST15;
3022 break;
3023 case DPPPerm::row_bcast_31:
3024 DppCtrl = DppCtrl::BCAST31;
3025 break;
3026 }
3027
3028 // Check for row_mask, bank_mask, bound_ctrl if they exist and create
3029 // constants
3030 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
3031 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
3032 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
3033
3034 // create a ROCDL_DPPMovOp instruction with the appropriate attributes
3035 auto dppMovOp =
3036 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
3037 rowMask, bankMask, boundCtrl);
3038
3039 Value result = dppMovOp.getRes();
3040 if (srcType.getIntOrFloatBitWidth() < 32) {
3041 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
3042 if (!llvm::isa<IntegerType>(srcType)) {
3043 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
3044 }
3045 }
3046
3047 // We are replacing the AMDGPU_DPPOp instruction with the new
3048 // ROCDL_DPPMovOp instruction
3049 rewriter.replaceOp(DppOp, ValueRange(result));
3050 return success();
3051 }
3052};
3053
3054struct AMDGPUSwizzleBitModeLowering
3055 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
3057
  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    // ds_swizzle operates on 32-bit lanes, so wider/narrower sources are
    // first split into i32 pieces.
    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to decompose value to i32");
    unsigned andMask = op.getAndMask();
    unsigned orMask = op.getOrMask();
    unsigned xorMask = op.getXorMask();

    // bit 15 is 0 for the BitMode swizzle.
    // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
    // Field layout: [4:0] and_mask, [9:5] or_mask, [14:10] xor_mask.
    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
    Value maskValue = createI32Constant(rewriter, loc, mask);
    // Swizzle each 32-bit piece independently, then reassemble the result.
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res =
          ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
      swizzled.emplace_back(res);
    }

    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
3087};
3088
3089struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
3091
3092 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
3093 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
3094 Chipset chipset;
3095
  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The permlane16/32_swap instructions only exist on gfx950+.
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    unsigned rowLength = op.getRowLength();
    bool fi = op.getFetchInactive();
    bool boundctrl = op.getBoundCtrl();

    // The permlane ops work on 32-bit registers; split wider sources into
    // i32 pieces first.
    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to decompose value to i32");

    SmallVector<Value> permuted;
    for (Value v : decomposed) {
      Value res;
      // The ROCDL swap ops return both halves of the exchange as a struct of
      // two i32 values.
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});

      if (rowLength == 16)
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else if (rowLength == 32)
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else
        llvm_unreachable("unsupported row length");

      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});

      Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
                                           LLVM::ICmpPredicate::eq, vdst0, v);

      // Per `permlane(16|32)` semantics: if the first extracted element equals
      // 'v', the result is the second element; otherwise it is the first.
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
3146};
3147
3148//===----------------------------------------------------------------------===//
3149// In-LDS Barrier Operations
3150//===----------------------------------------------------------------------===//
3151
// Bit layout of ds_barrier_state (as i64):
// [63:32] init count (32 bits)
// [31:29] phase (3 bits)
// [28:0] pending count (29 bits)
// Width of the pending-count field in the barrier state word.
constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
// The phase field starts immediately above the pending count.
constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
// Bit position of the init-count field (upper 32 bits of the i64 state).
constexpr int32_t kDsBarrierInitCountPos = 32;
// Mask selecting the low 29 pending-count bits.
constexpr int32_t kDsBarrierPendingCountMask =
    (1 << kDsBarrierPendingCountBitWidth) - 1;
3161
// Lowers `amdgpu.ds_barrier_init` to an atomic release store of the packed
// i64 barrier state word (layout documented above).
struct DsBarrierInitOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierInitOp> {
  // Target chipset; the barrier state ops require gfx1250+.
  Chipset chipset;

  DsBarrierInitOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    // Resolve the barrier's address from the memref base and indices.
    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // Note: We give participants as the number of arrivals that have to occur
    // before the phase changes. Hardware changes the phase when updating the
    // pending count would underflow, so we subtract 1 to get the behavior we're
    // looking for.
    Value initCount =
        LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
                            createI32Constant(rewriter, loc, 1));

    // Just a bit of paranoia, but this also allows for configurable width if
    // that becomes a thing.
    Value countMask =
        createI32Constant(rewriter, loc, kDsBarrierPendingCountMask);
    Value maskedCount32 =
        LLVM::AndOp::create(rewriter, loc, initCount, countMask);
    Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);

    // Build the i64 state word: the init count goes into [63:32] and the same
    // value seeds the pending count in the low bits; the phase bits start 0.
    Value initCountShifted = LLVM::ShlOp::create(
        rewriter, loc, maskedCount,
        createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
    Value barrierState =
        LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);

    // Release-ordered, workgroup-scoped store so other lanes observe a fully
    // initialized barrier state.
    LLVM::StoreOp::create(
        rewriter, loc, barrierState, ptr, /*alignment=*/8, /*isVolatile=*/false,
        /*isNonTemporal=*/false,
        /*isInvariantGroup=*/false, LLVM::AtomicOrdering::release,
        /*syncscope=*/"workgroup");

    rewriter.eraseOp(op);
    return success();
  }
};
3214
3215struct DsBarrierPollStateOpLowering
3216 : public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
3217 Chipset chipset;
3218
3219 DsBarrierPollStateOpLowering(const LLVMTypeConverter &converter,
3220 Chipset chipset)
3221 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
3222 chipset(chipset) {}
3223
3224 LogicalResult
3225 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
3226 ConversionPatternRewriter &rewriter) const override {
3227 if (chipset < kGfx1250)
3228 return op->emitOpError("only supported on gfx1250+");
3229
3230 Location loc = op.getLoc();
3231 Type i64 = rewriter.getI64Type();
3232
3233 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3234 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3235 adaptor.getBase(), adaptor.getIndices());
3236
3237 // Atomic load with workgroup scope and acquire ordering should be what
3238 // we're looking for.
3239 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
3240 op, i64, ptr, /*alignment=*/8, /*volatile_=*/false,
3241 /*nontemporal=*/false, /*invariant=*/false,
3242 /*invariantGroup=*/false, LLVM::AtomicOrdering::acquire,
3243 /*syncscope=*/"workgroup");
3244 return success();
3245 }
3246};
3247
3248struct DsAsyncBarrierArriveOpLowering
3249 : public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
3250 Chipset chipset;
3251
3252 DsAsyncBarrierArriveOpLowering(const LLVMTypeConverter &converter,
3253 Chipset chipset)
3254 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
3255 chipset(chipset) {}
3256
3257 LogicalResult
3258 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
3259 ConversionPatternRewriter &rewriter) const override {
3260 if (chipset < kGfx1250)
3261 return op->emitOpError("only supported on gfx1250+");
3262
3263 Location loc = op.getLoc();
3264
3265 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3266 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3267 adaptor.getBase(), adaptor.getIndices());
3268
3269 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
3270 op, ptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
3271 /*tbaa=*/nullptr);
3272 return success();
3273 }
3274};
3275
3276struct DsBarrierArriveOpLowering
3277 : public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
3278 Chipset chipset;
3279
3280 DsBarrierArriveOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
3281 : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
3282 }
3283
3284 LogicalResult
3285 matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
3286 ConversionPatternRewriter &rewriter) const override {
3287 if (chipset < kGfx1250)
3288 return op->emitOpError("only supported on gfx1250+");
3289
3290 Location loc = op.getLoc();
3291 Type i64 = rewriter.getI64Type();
3292
3293 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3294 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3295 adaptor.getBase(), adaptor.getIndices());
3296
3297 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
3298 op, i64, ptr, adaptor.getCount(), /*alias_scopes=*/nullptr,
3299 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3300 return success();
3301 }
3302};
3303
3304struct DsBarrierStatePhaseOpLowering
3305 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
3307
3308 LogicalResult
3309 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
3310 ConversionPatternRewriter &rewriter) const override {
3311 Location loc = op.getLoc();
3312 Type i32 = rewriter.getI32Type();
3313
3314 Value state = adaptor.getState();
3315
3316 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3317 Value phase = LLVM::LShrOp::create(
3318 rewriter, loc, noInitCount,
3319 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3320
3321 rewriter.replaceOp(op, phase);
3322 return success();
3323 }
3324};
3325
3326struct DsBarrierStatePendingCountOpLowering
3327 : public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
3329
3330 LogicalResult
3331 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
3332 ConversionPatternRewriter &rewriter) const override {
3333 Location loc = op.getLoc();
3334 Type i32 = rewriter.getI32Type();
3335
3336 Value state = adaptor.getState();
3337
3338 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3339 Value pendingCount = LLVM::AndOp::create(
3340 rewriter, loc, noInitCount,
3341 createI32Constant(rewriter, loc,
3342 static_cast<uint32_t>(kDsBarrierPendingCountMask)));
3343
3344 rewriter.replaceOp(op, pendingCount);
3345 return success();
3346 }
3347};
3348
3349struct DsBarrierStateInitCountOpLowering
3350 : public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
3352
3353 LogicalResult
3354 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
3355 ConversionPatternRewriter &rewriter) const override {
3356 Location loc = op.getLoc();
3357 Type i32 = rewriter.getI32Type();
3358
3359 Value state = adaptor.getState();
3360
3361 Value initCountI64 = LLVM::LShrOp::create(
3362 rewriter, loc, state,
3363 createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
3364 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
3365
3366 rewriter.replaceOp(op, initCount);
3367 return success();
3368 }
3369};
3370
3371struct DsBarrierStatePhaseParityLowering
3372 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
3374
3375 LogicalResult
3376 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
3377 ConversionPatternRewriter &rewriter) const override {
3378 Location loc = op.getLoc();
3379 Type i1 = rewriter.getI1Type();
3380
3381 Value state = adaptor.getState();
3382
3383 Value noInitCount =
3384 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
3385 Value phase = LLVM::LShrOp::create(
3386 rewriter, loc, noInitCount,
3387 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3388 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
3389
3390 rewriter.replaceOp(op, parity);
3391 return success();
3392 }
3393};
3394
3395//===----------------------------------------------------------------------===//
3396// Tensor Data Mover (TDM)
3397//===----------------------------------------------------------------------===//
3398
3399static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
3400 Value accumulator, Value value, int64_t shift) {
3401 shift = shift % 32;
3402 Value shiftAmount;
3403 if (shift != 0) {
3404 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
3405 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
3406 }
3407
3408 if (matchPattern(accumulator, mlir::m_Zero()))
3409 return value;
3410
3411 constexpr bool isDisjoint = true;
3412 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
3413}
3414
3415template <typename BaseOp>
3416struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
3417 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3418 using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
3419
3420 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
3421 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3422 Chipset chipset;
3423
3424 LogicalResult
3425 matchAndRewrite(BaseOp op, Adaptor adaptor,
3426 ConversionPatternRewriter &rewriter) const override {
3427 if (chipset < kGfx1250)
3428 return op->emitOpError("make_dma_base is only supported on gfx1250");
3429
3430 Location loc = op.getLoc();
3431
3432 constexpr int32_t constlen = 4;
3433 Value consts[constlen];
3434 for (int64_t i = 0; i < constlen; ++i)
3435 consts[i] = createI32Constant(rewriter, loc, i);
3436
3437 constexpr int32_t sgprslen = constlen;
3438 Value sgprs[sgprslen];
3439 for (int64_t i = 0; i < sgprslen; ++i) {
3440 sgprs[i] = consts[0];
3441 }
3442
3443 sgprs[0] = consts[1];
3444
3445 if constexpr (BaseOp::isGather()) {
3446 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3447
3448 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3449 Type indexType = type.getIndexType();
3450 unsigned indexSize = indexType.getIntOrFloatBitWidth();
3451 assert(llvm::is_contained({16u, 32u}, indexSize) &&
3452 "expected index_size to be 16 or 32");
3453 unsigned idx = (indexSize / 16) - 1;
3454
3455 if (idx)
3456 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
3457 }
3458
3459 ValueRange ldsIndices = adaptor.getLdsIndices();
3460 Value lds = adaptor.getLds();
3461 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3462
3464 rewriter, loc, ldsMemRefType, lds, ldsIndices);
3465
3466 ValueRange globalIndices = adaptor.getGlobalIndices();
3467 Value global = adaptor.getGlobal();
3468 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3469
3471 rewriter, loc, globalMemRefType, global, globalIndices);
3472
3473 Type i32 = rewriter.getI32Type();
3474 Type i64 = rewriter.getI64Type();
3475
3476 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3477 Value castForGlobalAddr =
3478 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
3479
3480 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
3481
3482 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3483 createI64Constant(rewriter, loc, 32));
3484
3485 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
3486
3487 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
3488 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3489
3490 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
3491
3492 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3493 assert(v4i32 && "expected type conversion to succeed");
3494 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3495
3496 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3497 result =
3498 LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
3499
3500 rewriter.replaceOp(op, result);
3501 return success();
3502 }
3503};
3504
/// Shared lowering for the TDM descriptor-creation ops (regular and gather
/// variants). Fields are packed into i32 "SGPR" lanes via setValueAtOffset;
/// offsets in the helpers below are absolute bit positions within the
/// descriptor group and are reduced modulo 32 per lane.
template <typename DescriptorOp>
struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
  using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
  using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;

  AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
  Chipset chipset;

  // DGroup0 is the pre-assembled TDM base descriptor operand; forward as-is.
  Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
3515
  // Packs the optional workgroup mask into the low 16 bits of sgpr0.
  // NOTE(review): the mask is read from `op` rather than `adaptor`; this is
  // only correct if the mask's type is unchanged by conversion — confirm.
  Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0) const {
    Value mask = op.getWorkgroupMask();
    if (!mask)
      return sgpr0;

    // Reinterpret the mask as an i16, then widen to i32 for packing.
    Type i16 = rewriter.getI16Type();
    mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
    Type i32 = rewriter.getI32Type();
    Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
    return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
  }
3529
3530 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
3531 ConversionPatternRewriter &rewriter, Location loc,
3532 Value sgpr0, ArrayRef<Value> consts) const {
3533 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
3534 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3535 "expected type width to be 8, 16, 32, or 64.");
3536 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3537 Value size = consts[idx];
3538 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
3539 }
3540
  // Sets bit 18 of sgpr0 when an atomic-barrier address is present.
  Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
  }
3549
  // Sets bit 19 of sgpr0 when a global increment is present (iterate mode).
  Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    if (!adaptor.getGlobalIncrement())
      return sgpr0;

    // Value is ignored when in gather mode.
    // TODO: emit error earlier?
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
  }
3560
  // Sets bit 20 of sgpr0 when a pad amount is present (padding enabled).
  Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
  }
3569
  // Sets bit 21 of sgpr0 when a workgroup mask is present.
  // NOTE(review): "early timeout" is gated on the workgroup mask rather than
  // a dedicated attribute — confirm this coupling is intentional.
  Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getWorkgroupMask())
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
  }
3578
  // Encodes cttz(padInterval) - 1 into sgpr0 at bit offset 22; only emitted
  // when padding is enabled (a pad amount is present).
  Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    // pre-condition: padInterval can be a power of two between 2 and 256.
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime.
    IntegerType i32 = rewriter.getI32Type();
    Value padInterval = adaptor.getPadInterval();
    padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
                                                     padInterval, false);
    padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
    // post-condition: padInterval can be a value between 0 and 7.
    return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
  }
3599
  // Encodes padAmount - 1 into sgpr0 at bit offset 25 when padding is
  // enabled.
  Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    // pre-condition: padAmount is a value between 1-128.
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime.
    Value padAmount = adaptor.getPadAmount();
    padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
    // post-condition: padAmount is a value between 0-127.
    return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
  }
3617
  // Packs the atomic-barrier address — divided by 8 (consts[3] shift) and
  // truncated to 16 bits — into the descriptor at bit offset 32 (the low 16
  // bits of sgpr1).
  Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgpr1,
                                ArrayRef<Value> consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr1;

    Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
    auto barrierAddressTy =
        cast<MemRefType>(op.getAtomicBarrierAddress().getType());
    ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
    atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
        rewriter, loc, barrierAddressTy, atomicBarrierAddress,
        atomicBarrierIndices);
    IntegerType i32 = rewriter.getI32Type();
    // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
    // that the 3 LSBs are zero.
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface
    // that instruments conditions that need to be checked at runtime.
    atomicBarrierAddress =
        LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
    atomicBarrierAddress =
        LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
    Value mask = createI32Constant(rewriter, loc, 0xFFFF);
    atomicBarrierAddress =
        LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
    return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
  }
3647
  // Writes global tensor size `dimX` (counted from the innermost dimension)
  // as a 32-bit field starting at `offset`: the part that fits in sgpr1's
  // lane is ORed there, and the bits above 16 are ORed into sgpr2 at
  // `offset` + 16. Missing (out-of-rank) dims leave both registers unchanged.
  std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts, uint64_t dimX,
                                        uint32_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= dimX)
      return {sgpr1, sgpr2};

    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    // pre-condition: tensorDimX is less than 2^32-1
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface that instruments
    // conditions that need to be checked at runtime. This could also be fixed
    // by saying that mixedGlobalSizes is a DynamicI32List.
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);

    Value c16 = createI32Constant(rewriter, loc, 16);
    Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
    sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
    return {sgpr1, sgpr2};
  }
3683
  // Tensor dim 0 (innermost) starts at bit offset 48, split over sgpr1/sgpr2.
  std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
                         48);
  }
3691
  // Tensor dim 1 starts at bit offset 80, split over sgpr2/sgpr3.
  std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr2, Value sgpr3,
                                        ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
                         80);
  }
3699
  // Writes tile (shared/LDS) size `dimX` (counted from the innermost
  // dimension) into `sgpr` at `offset`. Missing dims leave `sgpr` unchanged.
  Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr, ArrayRef<Value> consts, size_t dimX,
                    int64_t offset) const {
    ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
    ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
    SmallVector<OpFoldResult> mixedSharedSizes =
        getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
    if (mixedSharedSizes.size() <= dimX)
      return sgpr;

    OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
    // pre-condition: tileDimX is less than 2^16-1
    // TODO: Validation if the value breaks the pre-condition.
    // If the pre-condition fails, there is a possibility of
    // affecting the higher bits. In a following PR implement
    // RuntimeVerifiableOpInterface that instruments conditions that need to be
    // checked at runtime. This could also be fixed by saying that
    // mixedSharedSizes is a DynamicI16List.
    Value tileDimX;
    if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
      tileDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tileDimX = cast<Value>(tileDimXOpFoldResult);
      tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
  }
3731
  // Tile dim 0 (innermost) lives at bit offset 112 in sgpr3.
  Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr3, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
  }
3737
  // Tile dim 1 lives at bit offset 128 in sgpr4.
  Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
  }
3743
3744 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3745 ConversionPatternRewriter &rewriter, Location loc,
3746 Value sgpr4, ArrayRef<Value> consts) const {
3747 auto type = cast<VectorType>(op.getIndices().getType());
3748 ArrayRef<int64_t> shape = type.getShape();
3749 assert(shape.size() == 1 && "expected shape to be of rank 1.");
3750 unsigned length = shape.back();
3751 assert(0 < length && length <= 16 && "expected length to be at most 16.");
3752 Value value = createI32Constant(rewriter, loc, length);
3753 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
3754 }
3755
  // The field at offset 128 holds the valid-index count in gather mode and
  // tile dim 1 otherwise.
  Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr4,
                                  ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
    return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
  }
3764
  // Tile dim 2 lives at bit offset 144; skipped entirely in gather mode.
  Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    // Value is ignored when in gather mode.
    if constexpr (DescriptorOp::isGather())
      return sgpr4;
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
  }
3773
  // Writes the 48-bit global stride of dim `dimX` + 1 starting at `offset`:
  // the low 32 bits go into sgprY at `offset`, and the remaining high bits
  // into sgprZ at the next lane boundary. Missing strides leave both
  // registers unchanged.
  std::pair<Value, Value>
  setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgprY, Value sgprZ, ArrayRef<Value> consts,
                      size_t dimX, int64_t offset) const {
    ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
    ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
    SmallVector<OpFoldResult> mixedGlobalStrides =
        getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);

    if (mixedGlobalStrides.size() <= (dimX + 1))
      return {sgprY, sgprZ};

    OpFoldResult tensorDimXStrideOpFoldResult =
        *(mixedGlobalStrides.rbegin() + dimX + 1);
    // pre-condition: tensorDimXStride is less than 2^48-1
    // TODO: Validation if the value breaks the pre-condition.
    // In a following PR implement RuntimeVerifiableOpInterface that instruments
    // conditions that need to be checked at runtime.
    Value tensorDimXStride;
    if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
      tensorDimXStride =
          createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);

    constexpr int64_t first48bits = (1ll << 48) - 1;
    Value mask = createI64Constant(rewriter, loc, first48bits);
    tensorDimXStride =
        LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
    IntegerType i32 = rewriter.getI32Type();
    Value tensorDimXStrideLow =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
    sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);

    // Shift so the high part starts at the next 32-bit lane boundary.
    int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
    Value shiftVal = createI64Constant(rewriter, loc, shift);
    Value tensorDimXStrideHigh =
        LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
    tensorDimXStrideHigh =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
    sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
                             offset + shift);
    return {sgprY, sgprZ};
  }
3819
  // Stride of dim 1 (for tensor dim 0) starts at bit offset 160.
  std::pair<Value, Value>
  setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
                               0, 160);
  }
3827
  // Stride of dim 2 (for tensor dim 1) starts at bit offset 208; skipped in
  // gather mode.
  std::pair<Value, Value>
  setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    // Value is ignored when in gather mode.
    if constexpr (DescriptorOp::isGather())
      return {sgpr5, sgpr6};
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
                               1, 208);
  }
3838
  // Assembles DGroup1 — the second, 8 x i32 descriptor group — by packing
  // each field into its register lane, then inserting the lanes into a v8i32
  // vector. Lanes start as zero (consts[0]) so setValueAtOffset can fold the
  // first OR of each lane.
  Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    Value sgprs[8];
    for (int64_t i = 0; i < 8; ++i) {
      sgprs[i] = consts[0];
    }

    // Flags and pad configuration, all in lane 0.
    sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
    sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);

    // Tensor/tile geometry; several fields straddle two lanes.
    sgprs[1] =
        setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    std::tie(sgprs[2], sgprs[3]) =
        setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);

    sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
    sgprs[4] =
        setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
    sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
    std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
        op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
    std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
        op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);

    IntegerType i32 = rewriter.getI32Type();
    Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
    assert(v8i32 && "expected type conversion to succeed");
    Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);

    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
      dgroup1 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
    }

    return dgroup1;
  }
3884
  // Single-register variant of setTensorDimX: writes global tensor size
  // `dimX` (truncated to i32) into `sgpr0` at `offset`, without splitting the
  // field across two lanes. Missing dims leave the register unchanged.
  Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
                      int64_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
      return sgpr0;

    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
  }
3909
  // Tensor dim 2 lives at bit offset 0 of this descriptor group.
  Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
  }
3915
3916 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3917 Location loc, Value accumulator,
3918 Value value, int64_t shift) const {
3919
3920 IntegerType i32 = rewriter.getI32Type();
3921 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3922 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3923 }
3924
3925 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3926 ConversionPatternRewriter &rewriter, Location loc,
3927 Value sgpr1, ArrayRef<Value> consts,
3928 int64_t offset) const {
3929 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3930 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3931 }
3932
  // Splits the i64 global address increment: low 32 bits into sgpr2 at
  // `offset`, next 16 bits into sgpr3 at `offset` + 32.
  // NOTE(review): the final AND masks the whole sgpr3 accumulator to its low
  // 16 bits, clearing any fields previously packed into sgpr3 — confirm
  // callers only populate sgpr3 after this runs. Also, `first16BitsHigh`
  // names the low-16-bit mask (1 << 16) - 1, which is misleading.
  std::pair<Value, Value>
  setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
                         int64_t offset) const {
    Value globalAddrIncrement = adaptor.getGlobalIncrement();
    sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
                                        globalAddrIncrement, offset);
    Value shift = createI64Constant(rewriter, loc, 32);
    globalAddrIncrement =
        LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
    constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
    sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
                                        globalAddrIncrement, offset + 32);
    Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
    sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
    return {sgpr2, sgpr3};
  }
3951
  // The field at offset 32 holds the LDS address increment when iterating,
  // and tensor dim 3 otherwise.
  Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1,
                                        ArrayRef<Value> consts) const {
    Value ldsIncrement = op.getLdsIncrement();
    constexpr int64_t dim = 3;
    constexpr int64_t offset = 32;
    if (!ldsIncrement)
      return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
                           offset);
    return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
                               offset);
  }
3965
3966 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3967 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3968 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
3969 Value globalIncrement = op.getGlobalIncrement();
3970 constexpr int32_t dim = 2;
3971 constexpr int32_t offset = 64;
3972 if (!globalIncrement)
3973 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3974 consts, dim, offset);
3975 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3976 consts, offset);
3977 }
3978
3979 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3980 ConversionPatternRewriter &rewriter, Location loc,
3981 Value sgpr3, ArrayRef<Value> consts,
3982 int32_t offset) const {
3983 Value iterationCount = adaptor.getIterationCount();
3984 IntegerType i32 = rewriter.getI32Type();
3985 // pre-condition: iterationCount is in the inclusive interval [1, 256].
3986 // TODO: validation if the value breaks the pre-condition.
3987 // If the pre-condition fails, there is a possibility of
3988 // affecting the higher bits. In a following PR implement
3989 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3990 // checked at runtime.
3991 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3992 iterationCount =
3993 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3994 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3995 }
3996
  /// Word 3 (bits starting at 112) holds a tile dimension unless the op
  /// carries an iteration count, in which case the biased count is stored.
  Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr3,
                                  ArrayRef<Value> consts) const {
    Value iterateCount = op.getIterationCount();
    // NOTE(review): despite "Dim3" in the function name, the tile dimension
    // index used here is 2 — confirm this matches the descriptor layout.
    constexpr int32_t dim = 2;
    constexpr int32_t offset = 112;
    if (!iterateCount)
      return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
                         offset);

    return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
  }
4010
4011 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
4012 ConversionPatternRewriter &rewriter, Location loc,
4013 ArrayRef<Value> consts) const {
4014 if constexpr (DescriptorOp::isGather())
4015 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
4016 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
4017 }
4018
4019 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
4020 ConversionPatternRewriter &rewriter, Location loc,
4021 ArrayRef<Value> consts) const {
4022 IntegerType i32 = rewriter.getI32Type();
4023 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4024 assert(v4i32 && "expected type conversion to succeed.");
4025
4026 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4027 if (onlyNeedsTwoDescriptors)
4028 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4029
4030 constexpr int64_t sgprlen = 4;
4031 Value sgprs[sgprlen];
4032 for (int i = 0; i < sgprlen; ++i)
4033 sgprs[i] = consts[0];
4034
4035 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
4036 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
4037 sgprs[1], consts);
4038 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
4039 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
4040 sgprs[3] =
4041 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
4042
4043 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4044 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
4045 dgroup2 =
4046 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
4047
4048 return dgroup2;
4049 }
4050
  /// Builds a <4 x i32> word containing one half of the gather indices.
  /// i32 indices occupy a full word each (up to 4 per half); narrower
  /// (zero-extended) indices are packed two per word (up to 8 per half).
  /// `firstHalf` selects indices [0, maxLength) vs [maxLength, 2*maxLength).
  Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts, bool firstHalf) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    Value indices = adaptor.getIndices();
    auto vectorType = cast<VectorType>(indices.getType());
    unsigned length = vectorType.getShape().back();
    Type elementType = vectorType.getElementType();
    // Capacity of one half: 4 full words for i32, 8 packed for narrower.
    unsigned maxLength = elementType == i32 ? 4 : 8;
    int32_t offset = firstHalf ? 0 : maxLength;
    // Indices remaining for this half. The unsigned subtraction is cast to
    // int32 so a too-short vector clamps to 0 instead of wrapping.
    unsigned discountedLength =
        std::max(static_cast<int32_t>(length - offset), 0);

    unsigned targetSize = std::min(maxLength, discountedLength);

    // Extract this half's scalar indices, reusing the cached constants where
    // possible (consts holds the i32 constants 0..7).
    SmallVector<Value> indicesVector;
    for (unsigned i = offset; i < targetSize + offset; ++i) {
      Value idx;
      if (i < consts.size())
        idx = consts[i];
      else
        idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
      indicesVector.push_back(elem);
    }

    // Widen sub-i32 indices to i32 so they can be paired below.
    SmallVector<Value> indicesI32Vector;
    if (elementType == i32) {
      indicesI32Vector = indicesVector;
    } else {
      for (unsigned i = 0; i < targetSize; ++i) {
        Value index = indicesVector[i];
        indicesI32Vector.push_back(
            LLVM::ZExtOp::create(rewriter, loc, i32, index));
      }
      if ((targetSize % 2) != 0)
        // Add padding when not divisible by two.
        indicesI32Vector.push_back(consts[0]);
    }

    // For narrow indices, merge each pair into one word, the second index in
    // the upper 16 bits.
    SmallVector<Value> indicesToInsert;
    if (elementType == i32) {
      indicesToInsert = indicesI32Vector;
    } else {
      unsigned size = indicesI32Vector.size() / 2;
      for (unsigned i = 0; i < size; ++i) {
        Value first = indicesI32Vector[2 * i];
        Value second = indicesI32Vector[2 * i + 1];
        Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
        indicesToInsert.push_back(joined);
      }
    }

    // Insert the words into a <4 x i32>; unused lanes remain poison.
    Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
      dgroup =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);

    return dgroup;
  }
4114
4115 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
4116 ConversionPatternRewriter &rewriter, Location loc,
4117 ArrayRef<Value> consts) const {
4118 return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
4119 }
4120
4121 std::pair<Value, Value>
4122 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
4123 ConversionPatternRewriter &rewriter, Location loc,
4124 Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
4125 constexpr int32_t dim = 3;
4126 constexpr int32_t offset = 0;
4127 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
4128 dim, offset);
4129 }
4130
4131 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
4132 ConversionPatternRewriter &rewriter,
4133 Location loc, Value sgpr1, Value sgpr2,
4134 ArrayRef<Value> consts) const {
4135 constexpr int32_t dim = 4;
4136 constexpr int32_t offset = 48;
4137 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
4138 offset);
4139 }
4140
4141 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
4142 ConversionPatternRewriter &rewriter, Location loc,
4143 Value sgpr2, ArrayRef<Value> consts) const {
4144 constexpr int32_t dim = 4;
4145 constexpr int32_t offset = 80;
4146 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
4147 }
4148
4149 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
4150 ConversionPatternRewriter &rewriter, Location loc,
4151 ArrayRef<Value> consts) const {
4152 if constexpr (DescriptorOp::isGather())
4153 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
4154 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
4155 }
4156
4157 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
4158 ConversionPatternRewriter &rewriter, Location loc,
4159 ArrayRef<Value> consts) const {
4160 IntegerType i32 = rewriter.getI32Type();
4161 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4162 assert(v4i32 && "expected type conversion to succeed.");
4163 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4164 if (onlyNeedsTwoDescriptors)
4165 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4166
4167 constexpr int32_t sgprlen = 4;
4168 Value sgprs[sgprlen];
4169 for (int i = 0; i < sgprlen; ++i)
4170 sgprs[i] = consts[0];
4171
4172 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
4173 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
4174 std::tie(sgprs[1], sgprs[2]) =
4175 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
4176 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
4177
4178 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4179 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
4180 dgroup3 =
4181 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
4182
4183 return dgroup3;
4184 }
4185
4186 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
4187 ConversionPatternRewriter &rewriter, Location loc,
4188 ArrayRef<Value> consts) const {
4189 return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
4190 }
4191
4192 LogicalResult
4193 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
4194 ConversionPatternRewriter &rewriter) const override {
4195 if (chipset < kGfx1250)
4196 return op->emitOpError(
4197 "make_dma_descriptor is only supported on gfx1250");
4198
4199 Location loc = op.getLoc();
4200
4201 SmallVector<Value> consts;
4202 for (int64_t i = 0; i < 8; ++i)
4203 consts.push_back(createI32Constant(rewriter, loc, i));
4204
4205 Value dgroup0 = this->getDGroup0(adaptor);
4206 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
4207 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
4208 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
4209 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
4210 rewriter.replaceOpWithMultiple(op, {results});
4211 return success();
4212 }
4213};
4214
4215template <typename SourceOp, typename TargetOp>
4216struct AMDGPUTensorLoadStoreOpLowering
4217 : public ConvertOpToLLVMPattern<SourceOp> {
4218 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
4220 AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
4221 Chipset chipset)
4222 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
4223 Chipset chipset;
4224
4225 LogicalResult
4226 matchAndRewrite(SourceOp op, Adaptor adaptor,
4227 ConversionPatternRewriter &rewriter) const override {
4228 if (chipset < kGfx1250)
4229 return op->emitOpError("is only supported on gfx1250");
4230
4231 ValueRange desc = adaptor.getDesc();
4232 // Create a <v8 x i32> 0 as the fifth argument to match llvm intrinsic. It
4233 // will move into the TDM descriptor once it becomes relevant for future use
4234 auto v8i32 = VectorType::get(8, rewriter.getI32Type());
4235 Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
4236 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
4237 desc[3], dgroup4, /*cachePolicy=*/0,
4238 /*alias_scopes=*/nullptr,
4239 /*noalias_scopes=*/nullptr,
4240 /*tbaa=*/nullptr);
4241 return success();
4242 }
4243};
4244
4245struct GlobalPrefetchOpLowering
4246 : public ConvertOpToLLVMPattern<GlobalPrefetchOp> {
4247 GlobalPrefetchOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
4248 : ConvertOpToLLVMPattern<GlobalPrefetchOp>(converter), chipset(chipset) {}
4249
4250 LogicalResult
4251 matchAndRewrite(GlobalPrefetchOp op, GlobalPrefetchOpAdaptor adaptor,
4252 ConversionPatternRewriter &rewriter) const override {
4253 if (chipset < kGfx1250)
4254 return op->emitOpError("is only supported on gfx1250+");
4255
4256 const bool isSpeculative = op.getSpeculative();
4257 const int32_t immArgValue = getGlobalPrefetchLLVMEncoding(
4258 op.getTemporalHint(), op.getCacheScope(), isSpeculative);
4259 IntegerAttr immArgAttr = rewriter.getI32IntegerAttr(immArgValue);
4260
4261 ValueRange indices = adaptor.getIndices();
4262 Value memRef = adaptor.getSrc();
4263 MemRefDescriptor descriptor(memRef);
4264 MemRefType memRefType = op.getSrc().getType();
4265 Location loc = op->getLoc();
4266 auto inboundsFlags = isSpeculative ? LLVM::GEPNoWrapFlags::none
4267 : LLVM::GEPNoWrapFlags::inbounds |
4268 LLVM::GEPNoWrapFlags::nuw;
4269 Value prefetchPtr = getStridedElementPtr(
4270 rewriter, loc, memRefType, descriptor, indices, inboundsFlags);
4271
4272 rewriter.replaceOpWithNewOp<ROCDL::GlobalPrefetchOp>(
4273 op, prefetchPtr, immArgAttr, mlir::ArrayAttr{}, mlir::ArrayAttr{},
4274 mlir::ArrayAttr{});
4275 return success();
4276 }
4277
4278private:
4279 Chipset chipset;
4280};
4281
4282struct ConvertAMDGPUToROCDLPass
4283 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
4284 using Base::Base;
4285
4286 void runOnOperation() override {
4287 MLIRContext *ctx = &getContext();
4288 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
4289 if (failed(maybeChipset)) {
4290 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
4291 return signalPassFailure();
4292 }
4293
4294 RewritePatternSet patterns(ctx);
4295 LLVMTypeConverter converter(ctx);
4296
4297 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
4298 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
4299 LLVMConversionTarget target(getContext());
4300 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
4301 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
4302 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
4303 if (failed(applyPartialConversion(getOperation(), target,
4304 std::move(patterns))))
4305 signalPassFailure();
4306 }
4307};
4308} // namespace
4309
// NOTE(review): the rendering of this chunk dropped this function's
// signature line (the symbol index lists it as
// `void populateCommonGPUTypeAndAttributeConversions(TypeConverter &)`) and
// the line invoking the helper (presumably
// `populateGpuMemorySpaceAttributeConversions(` per the call shape below) —
// restore both from the authoritative source before editing.
    TypeConverter &typeConverter) {
      // Map each gpu::AddressSpace enumerator onto the corresponding ROCDL
      // LLVM address-space constant.
      typeConverter, [](gpu::AddressSpace space) {
        switch (space) {
        case gpu::AddressSpace::Global:
          return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
        case gpu::AddressSpace::Workgroup:
          return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
        case gpu::AddressSpace::Private:
          return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
        case gpu::AddressSpace::Constant:
          return ROCDL::ROCDLDialect::kConstantMemoryAddressSpace;
        }
        // All enumerators are handled above; anything else is a bug.
        llvm_unreachable("unknown address space enum value");
      });
}
4327
// NOTE(review): the rendering dropped this function's signature line; the
// symbol index lists it as
// `void populateAMDGPUTypeAndAttributeConversions(TypeConverter &)`.
    TypeConverter &typeConverter) {
  // Map amdgpu address-space attributes on memrefs to their numeric LLVM
  // address spaces: 7 = fat raw buffer, 8 = buffer resource, 9 = fat
  // structured buffer.
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
          -> TypeConverter::AttributeConversionResult {
        MLIRContext *ctx = as.getContext();
        Type i64 = IntegerType::get(ctx, 64);
        switch (as.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          return IntegerAttr::get(i64, 7);
        case amdgpu::AddressSpace::BufferRsrc:
          return IntegerAttr::get(i64, 8);
        case amdgpu::AddressSpace::FatStructuredBuffer:
          return IntegerAttr::get(i64, 9);
        }
        // Any other enumerator aborts this attribute conversion.
        return TypeConverter::AttributeConversionResult::abort();
      });
  // DsBarrierState lowers to a plain i64.
  typeConverter.addConversion([&](DsBarrierStateType type) -> Type {
    return IntegerType::get(type.getContext(), 64);
  });
  // TDM base types (plain and gather) lower to <4 x i32>.
  typeConverter.addConversion([&](TDMBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  // A TDM descriptor is a 1:N conversion into four vectors:
  // <4 x i32>, <8 x i32>, <4 x i32>, <4 x i32>.
  typeConverter.addConversion(
      [&](TDMDescriptorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type i32 = IntegerType::get(type.getContext(), 32);
        Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
        Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
        llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
        return success();
      });

  // NOTE(review): the rendering dropped one parameter line of this lambda
  // (presumably `Location loc) -> ... {`; `loc` is used in the body below) —
  // restore it from the authoritative source before editing.
  auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
                              ValueRange inputs,
    // Only create unrealized_conversion_cast for TDMDescriptorType.
    // All other types which are not expected, should be
    // materialized by other target materialization functions.
    if (inputs.size() != 1)
      return {};

    if (!isa<TDMDescriptorType>(inputs[0].getType()))
      return {};

    auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
    return cast.getResults();
  };

  typeConverter.addTargetMaterialization(addUnrealizedCast);
}
4384
// NOTE(review): the rendering dropped this function's signature line (the
// symbol index lists it as `void populateAMDGPUToROCDLConversionPatterns(
// LLVMTypeConverter &, RewritePatternSet &, Chipset)`) and one statement
// before the first `patterns.add` — per the header note that this function
// "also adds conversions for the AMDGPU-specific address spaces and types",
// presumably the call registering those conversions. Restore from the
// authoritative source before editing.
    RewritePatternSet &patterns,
    Chipset chipset) {
  // Chipset-dependent pattern set: each of these lowerings consults
  // `chipset` to select intrinsics or reject unsupported targets.
  patterns
      .add<FatRawBufferCastLowering,
           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
           SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
           SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
           SparseWMMAOpLowering, DotOpLowering, ExtPackedFp8OpLowering,
           ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
           GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
           AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
           AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
           AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
                                           ROCDL::TensorLoadToLDSOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
                                           ROCDL::TensorStoreFromLDSOp>,
           DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
           DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering,
           GlobalPrefetchOpLowering>(converter, chipset);
  // Chipset-independent pattern set.
  patterns.add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
               DsBarrierStatePendingCountOpLowering,
               DsBarrierStateInitCountOpLowering,
               DsBarrierStatePhaseParityLowering>(converter);
}
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static bool hasDot10Insts(const Chipset &chipset)
static bool hasDot7Insts(const Chipset &chipset)
static std::optional< SparseWMMAOpInfo > sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static bool hasDot11Insts(const Chipset &chipset)
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static bool hasDot9Insts(const Chipset &chipset)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value convertPackedVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts packed vector operands to the expected ROCDL types.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static bool hasDot8Insts(const Chipset &chipset)
static bool hasDot2Insts(const Chipset &chipset)
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static bool hasDot12Insts(const Chipset &chipset)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static bool hasDot1Insts(const Chipset &chipset)
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
typename SourceOp::Adaptor OpAdaptor
Definition Pattern.h:229
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:66
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:209
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF8E5M2() const
Definition Types.cpp:45
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isF8E4M3FN() const
Definition Types.cpp:44
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:603
int32_t getGlobalPrefetchLLVMEncoding(amdgpu::LoadTemporalHint hint, amdgpu::Scope scope, bool isSpeculative)
Definition AMDGPUEnums.h:18
bool hasOcpFp8(const Chipset &chipset)
Definition Chipset.h:52
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Returns the rocdl intrinsic corresponding to a SparseWMMA operation swmmac if one exists.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
unsigned minorVersion
Definition Chipset.h:24
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14
unsigned steppingVersion
Definition Chipset.h:25