MLIR 23.0.0git
AMDGPUToROCDL.cpp
Go to the documentation of this file.
1//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Matchers.h"
26#include "mlir/Pass/Pass.h"
27
29
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/AMDGPUAddrSpace.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/ErrorHandling.h"
35#include <cstdint>
36#include <optional>
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44using namespace mlir::amdgpu;
45
46// Define commonly used chipsets versions for convenience.
47constexpr Chipset kGfx908 = Chipset(9, 0, 8);
48constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
49constexpr Chipset kGfx942 = Chipset(9, 4, 2);
50constexpr Chipset kGfx950 = Chipset(9, 5, 0);
51constexpr Chipset kGfx1200 = Chipset(12, 0, 0);
52constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
53
54// Predicates mirroring the LLVM AMDGPU `HasDot{N}Insts` features that gate
55// the `v_dot*` instructions consumed by the `amdgpu.dot` lowering.
56static bool hasDot1Insts(const Chipset &chipset) {
57 if (chipset.majorVersion == 9)
58 return chipset >= Chipset(9, 0, 6);
59 if (chipset.majorVersion == 10) {
60 if (chipset.minorVersion == 1)
61 return chipset.steppingVersion == 1u || chipset.steppingVersion == 2u;
62 return chipset.minorVersion >= 3u;
63 }
64 return false;
65}
66
67static bool hasDot2Insts(const Chipset &chipset) {
68 return hasDot1Insts(chipset);
69}
70
71static bool hasDot7Insts(const Chipset &chipset) {
72 return chipset.majorVersion >= 11 || hasDot1Insts(chipset);
73}
74
75static bool hasDot8Insts(const Chipset &chipset) {
76 return chipset.majorVersion >= 11;
77}
78
79static bool hasDot9Insts(const Chipset &chipset) {
80 if (chipset.majorVersion == 11)
81 return true;
82 return chipset.majorVersion == 12 && chipset.minorVersion == 0;
83}
84
85static bool hasDot10Insts(const Chipset &chipset) {
86 if (chipset.majorVersion == 11)
87 return true;
88 if (chipset.majorVersion == 12)
89 return chipset.minorVersion == 0;
90 return hasDot1Insts(chipset);
91}
92
93static bool hasDot11Insts(const Chipset &chipset) {
94 if (chipset.majorVersion == 11)
95 return chipset.minorVersion == 7u;
96 return chipset.majorVersion == 12 && chipset.minorVersion == 0;
97}
98
99static bool hasDot12Insts(const Chipset &chipset) {
100 if (chipset == Chipset(9, 5, 0))
101 return true;
102 if (chipset.majorVersion == 11)
103 return true;
104 return chipset.majorVersion == 12 && chipset.minorVersion == 0;
105}
106
107/// Convert an unsigned number `val` to i32.
108static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
109 Location loc, Value val) {
110 IntegerType i32 = rewriter.getI32Type();
111 // Force check that `val` is of int type.
112 auto valTy = cast<IntegerType>(val.getType());
113 if (i32 == valTy)
114 return val;
115 return valTy.getWidth() > 32
116 ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
117 : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
118}
119
120static Value createI32Constant(ConversionPatternRewriter &rewriter,
121 Location loc, int32_t value) {
122 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
123}
124
125/// Convert an unsigned number `val` to i64.
126static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
127 Location loc, Value val) {
128 IntegerType i64 = rewriter.getI64Type();
129 // Force check that `val` is of int type.
130 auto valTy = cast<IntegerType>(val.getType());
131 if (i64 == valTy)
132 return val;
133 return valTy.getWidth() > 64
134 ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
135 : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
136}
137
138static Value createI64Constant(ConversionPatternRewriter &rewriter,
139 Location loc, int64_t value) {
140 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
141}
142
143/// Returns the linear index used to access an element in the memref.
144static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
145 Location loc, MemRefDescriptor &memRefDescriptor,
147 IntegerType i32 = rewriter.getI32Type();
148 Value index;
149 for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
150 if (stride != 1) { // Skip if stride is 1.
151 Value strideValue =
152 ShapedType::isDynamic(stride)
153 ? convertUnsignedToI32(rewriter, loc,
154 memRefDescriptor.stride(rewriter, loc, i))
155 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
156 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
157 }
158 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
159 : increment;
160 }
161 return index ? index : createI32Constant(rewriter, loc, 0);
162}
163
164/// Compute the contents of the `num_records` field for a given memref
165/// descriptor - that is, the number of bytes that's one element past the
166/// greatest possible valid index into the memref.
167static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
168 MemRefType memrefType,
169 MemRefDescriptor &memrefDescriptor,
170 ArrayRef<int64_t> strides, int64_t elementByteWidth,
171 amdgpu::Chipset chipset, bool boundsCheck) {
172 if (chipset >= kGfx1250 && !boundsCheck) {
173 constexpr int64_t first45bits = (1ll << 45) - 1;
174 return createI64Constant(rewriter, loc, first45bits);
175 }
176 if (memrefType.hasStaticShape() &&
177 !llvm::any_of(strides, ShapedType::isDynamic)) {
178 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
179 ArrayRef<int64_t> shape = memrefType.getShape();
180 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
181 size = std::max(shape[i] * strides[i], size);
182 size = size * elementByteWidth;
183 return createI64Constant(rewriter, loc, size);
184 }
185 Value maxIndex;
186 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
187 Value size = memrefDescriptor.size(rewriter, loc, i);
188 Value stride = memrefDescriptor.stride(rewriter, loc, i);
189 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
190 maxIndex = maxIndex
191 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
192 : maxThisDim;
193 }
194 Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
195 Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
196 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
197}
198
199static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
200 Value basePointer, Value numRecords,
201 bool boundsCheck, amdgpu::Chipset chipset,
202 Value cacheSwizzleStride = nullptr,
203 unsigned addressSpace = 8) {
204 // The stride value is generally 0. However, on MI-300 and onward, you can
205 // enable a cache swizzling mode by setting bit 14 of the stride field
206 // and setting that stride to a cache stride.
207 Type i16 = rewriter.getI16Type();
208 Value stride;
209 if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
210 Value cacheStrideZext =
211 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
212 Value swizzleBit = LLVM::ConstantOp::create(
213 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
214 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
215 /*isDisjoint=*/true);
216 } else {
217 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
218 rewriter.getI16IntegerAttr(0));
219 }
220
221 uint32_t flags = 0;
222 if (chipset >= kGfx1250) {
223 // Flag word:
224 // bit 0: swizzle
225 // bit 1: 0 means (total_offset + payload > numRecords)
226 // 1 means ((total_offset + payload >) numRecords) || ((offset +
227 // payload) > stride) only applied when swizzle_enable = 0. keep at
228 // zero.
229 // whether oob is done depends on numRecords.
230 // bits 2-3: Type (must be 0)
231 } else {
232 // Get the number of elements.
233 // Flag word:
234 // bits 0-11: dst sel, ignored by these intrinsics
235 // bits 12-14: data format (ignored, must be nonzero, 7=float)
236 // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
237 // bit 19: In nested heap (0 here)
238 // bit 20: Behavior on unmap (0 means "return 0 / ignore")
239 // bits 21-22: Index stride for swizzles (N/A)
240 // bit 23: Add thread ID (0)
241 // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
242 // bits 25-26: Reserved (0)
243 // bit 27: Buffer is non-volatile (CDNA only)
244 // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
245 // none, 3 = either swizzles or testing against offset field) RDNA only
246 // bits 30-31: Type (must be 0)
247 flags |= (7 << 12) | (4 << 15);
248 if (chipset.majorVersion >= 10) {
249 flags |= (1 << 24);
250 uint32_t oob = boundsCheck ? 3 : 2;
251 flags |= (oob << 28);
252 }
253 }
254 Value flagsConst = createI32Constant(rewriter, loc, flags);
255 Type rsrcType =
256 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
257 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
258 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
259 return resource;
260}
261
262namespace {
263struct FatRawBufferCastLowering
264 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
265 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
266 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
267 chipset(chipset) {}
268
269 Chipset chipset;
270
271 LogicalResult
272 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter) const override {
274 Location loc = op.getLoc();
275 Value memRef = adaptor.getSource();
276 Value unconvertedMemref = op.getSource();
277 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
278 MemRefDescriptor descriptor(memRef);
279
280 DataLayout dataLayout = DataLayout::closest(op);
281 int64_t elementByteWidth =
282 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
283
284 int64_t unusedOffset = 0;
285 SmallVector<int64_t, 5> strideVals;
286 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
287 return op.emitOpError("Can't lower non-stride-offset memrefs");
288
289 Value numRecords = adaptor.getValidBytes();
290 if (!numRecords)
291 numRecords =
292 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
293 elementByteWidth, chipset, adaptor.getBoundsCheck());
294
295 Value basePointer =
296 adaptor.getResetOffset()
297 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
298 memrefType)
299 : descriptor.alignedPtr(rewriter, loc);
300
301 Value offset = adaptor.getResetOffset()
302 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
303 rewriter.getIndexAttr(0))
304 : descriptor.offset(rewriter, loc);
305
306 bool hasSizes = memrefType.getRank() > 0;
307 // No need to unpack() and pack() all the individual sizes and strides,
308 // so we'll just extract the arrays.
309 Value sizes = hasSizes
310 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
312 : Value{};
313 Value strides =
314 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
316 : Value{};
317
318 Value fatPtr = makeBufferRsrc(
319 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
320 chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
321
322 Value result = MemRefDescriptor::poison(
323 rewriter, loc,
324 getTypeConverter()->convertType(op.getResult().getType()));
325 SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
326 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
327 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
329 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
331 if (hasSizes) {
332 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
334 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
336 }
337 rewriter.replaceOp(op, result);
338 return success();
339 }
340};
341
342/// Define lowering patterns for raw buffer ops
343template <typename GpuOp, typename Intrinsic>
344struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
345 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
346 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
347
348 Chipset chipset;
349 static constexpr uint32_t maxVectorOpWidth = 128;
350
351 LogicalResult
352 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
353 ConversionPatternRewriter &rewriter) const override {
354 Location loc = gpuOp.getLoc();
355 Value memref = adaptor.getMemref();
356 Value unconvertedMemref = gpuOp.getMemref();
357 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
358
359 if (chipset.majorVersion < 9)
360 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
361
362 Value storeData = adaptor.getODSOperands(0)[0];
363 if (storeData == memref) // no write component to this op
364 storeData = Value();
365 Type wantedDataType;
366 if (storeData)
367 wantedDataType = storeData.getType();
368 else
369 wantedDataType = gpuOp.getODSResults(0)[0].getType();
370
371 Value atomicCmpData = Value();
372 // Operand index 1 of a load is the indices, trying to read them can crash.
373 if (storeData) {
374 Value maybeCmpData = adaptor.getODSOperands(1)[0];
375 if (maybeCmpData != memref)
376 atomicCmpData = maybeCmpData;
377 }
378
379 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
380
381 Type i32 = rewriter.getI32Type();
382
383 // Get the type size in bytes.
384 DataLayout dataLayout = DataLayout::closest(gpuOp);
385 int64_t elementByteWidth =
386 dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
387 Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
388
389 // If we want to load a vector<NxT> with total size <= 32
390 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
391 // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
392 // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
393 // so bitcast any floats to integers.
394 Type llvmBufferValType = llvmWantedDataType;
395 if (atomicCmpData) {
396 if (auto floatType = dyn_cast<FloatType>(wantedDataType))
397 llvmBufferValType = this->getTypeConverter()->convertType(
398 rewriter.getIntegerType(floatType.getWidth()));
399 }
400 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
401 uint32_t vecLen = dataVector.getNumElements();
402 uint32_t elemBits =
403 dataLayout.getTypeSizeInBits(dataVector.getElementType());
404 uint32_t totalBits = elemBits * vecLen;
405 bool usePackedFp16 =
406 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
407 if (totalBits > maxVectorOpWidth)
408 return gpuOp.emitOpError(
409 "Total width of loads or stores must be no more than " +
410 Twine(maxVectorOpWidth) + " bits, but we call for " +
411 Twine(totalBits) +
412 " bits. This should've been caught in validation");
413 if (!usePackedFp16 && elemBits < 32) {
414 if (totalBits > 32) {
415 if (totalBits % 32 != 0)
416 return gpuOp.emitOpError("Load or store of more than 32-bits that "
417 "doesn't fit into words. Can't happen\n");
418 llvmBufferValType = this->typeConverter->convertType(
419 VectorType::get(totalBits / 32, i32));
420 } else {
421 llvmBufferValType = this->typeConverter->convertType(
422 rewriter.getIntegerType(totalBits));
423 }
424 }
425 }
426 if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
427 // Buffer intrinsics doesn't support 1-element vectors, cast them to
428 // scalars.
429 if (vecType.getNumElements() == 1)
430 llvmBufferValType = vecType.getElementType();
431 }
432
433 SmallVector<Value, 6> args;
434 if (storeData) {
435 if (llvmBufferValType != llvmWantedDataType) {
436 Value castForStore = LLVM::BitcastOp::create(
437 rewriter, loc, llvmBufferValType, storeData);
438 args.push_back(castForStore);
439 } else {
440 args.push_back(storeData);
441 }
442 }
443
444 if (atomicCmpData) {
445 if (llvmBufferValType != llvmWantedDataType) {
446 Value castForCmp = LLVM::BitcastOp::create(
447 rewriter, loc, llvmBufferValType, atomicCmpData);
448 args.push_back(castForCmp);
449 } else {
450 args.push_back(atomicCmpData);
451 }
452 }
453
454 // Construct buffer descriptor from memref, attributes
455 int64_t offset = 0;
456 SmallVector<int64_t, 5> strides;
457 if (failed(memrefType.getStridesAndOffset(strides, offset)))
458 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
459
460 MemRefDescriptor memrefDescriptor(memref);
461
462 Value ptr = memrefDescriptor.bufferPtr(
463 rewriter, loc, *this->getTypeConverter(), memrefType);
464 Value numRecords =
465 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
466 elementByteWidth, chipset, adaptor.getBoundsCheck());
467 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
468 adaptor.getBoundsCheck(), chipset);
469 args.push_back(resource);
470
471 // Indexing (voffset)
472 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
473 adaptor.getIndices(), strides);
474 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
475 indexOffset && *indexOffset > 0) {
476 Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
477 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
478 extraOffsetConst)
479 : extraOffsetConst;
480 }
481 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
482 args.push_back(voffset);
483
484 // SGPR offset.
485 Value sgprOffset = adaptor.getSgprOffset();
486 if (!sgprOffset)
487 sgprOffset = createI32Constant(rewriter, loc, 0);
488 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
489 args.push_back(sgprOffset);
490
491 // bit 0: GLC = 0 (atomics drop value, less coherency)
492 // bits 1-2: SLC, DLC = 0 (similarly)
493 // bit 3: swizzled (0 for raw)
494 args.push_back(createI32Constant(rewriter, loc, 0));
495
496 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
497 llvmBufferValType);
498 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
499 ArrayRef<NamedAttribute>());
500 if (lowered->getNumResults() == 1) {
501 Value replacement = lowered->getResult(0);
502 if (llvmBufferValType != llvmWantedDataType) {
503 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
505 }
506 rewriter.replaceOp(gpuOp, replacement);
507 } else {
508 rewriter.eraseOp(gpuOp);
509 }
510 return success();
511 }
512};
513
514// TODO: AMDGPU backend already have all this bitpacking logic, we should move
515// it to some common place.
516/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
517/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
518/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
519/// Vmcnt = Waitcnt[15:10] (gfx11)
520/// Expcnt = Waitcnt[6:4] (pre-gfx11)
521/// Expcnt = Waitcnt[2:0] (gfx11)
522/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
523/// Lgkmcnt = Waitcnt[13:8] (gfx10)
524/// Lgkmcnt = Waitcnt[9:4] (gfx11)
525static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
526 unsigned expcnt, unsigned lgkmcnt) {
527 if (chipset.majorVersion < 9) {
528 vmcnt = std::min(15u, vmcnt);
529 expcnt = std::min(7u, expcnt);
530 lgkmcnt = std::min(15u, lgkmcnt);
531 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
532 }
533 if (chipset.majorVersion == 9) {
534 vmcnt = std::min(63u, vmcnt);
535 expcnt = std::min(7u, expcnt);
536 lgkmcnt = std::min(15u, lgkmcnt);
537 unsigned lowBits = vmcnt & 0xF;
538 unsigned highBits = (vmcnt >> 4) << 14;
539 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
540 return lowBits | highBits | otherCnts;
541 }
542 if (chipset.majorVersion == 10) {
543 vmcnt = std::min(63u, vmcnt);
544 expcnt = std::min(7u, expcnt);
545 lgkmcnt = std::min(63u, lgkmcnt);
546 unsigned lowBits = vmcnt & 0xF;
547 unsigned highBits = (vmcnt >> 4) << 14;
548 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
549 return lowBits | highBits | otherCnts;
550 }
551 if (chipset.majorVersion == 11) {
552 vmcnt = std::min(63u, vmcnt);
553 expcnt = std::min(7u, expcnt);
554 lgkmcnt = std::min(63u, lgkmcnt);
555 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
556 }
557 return failure();
558}
559
560struct MemoryCounterWaitOpLowering
561 : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
562 MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
563 Chipset chipset)
564 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
565 chipset(chipset) {}
566
567 Chipset chipset;
568
569 LogicalResult
570 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
571 ConversionPatternRewriter &rewriter) const override {
572 if (chipset.majorVersion >= 12) {
573 Location loc = op.getLoc();
574 if (std::optional<int> ds = adaptor.getDs())
575 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
576
577 if (std::optional<int> load = adaptor.getLoad())
578 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
579
580 if (std::optional<int> store = adaptor.getStore())
581 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
582
583 if (std::optional<int> exp = adaptor.getExp())
584 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
585
586 if (std::optional<int> tensor = adaptor.getTensor())
587 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
588
589 rewriter.eraseOp(op);
590 return success();
591 }
592
593 if (adaptor.getTensor())
594 return op.emitOpError("unsupported chipset");
595
596 auto getVal = [](Attribute attr) -> unsigned {
597 if (attr)
598 return cast<IntegerAttr>(attr).getInt();
599
600 // This value will be clamped to the maximum value for the chipset.
601 return 1024;
602 };
603 unsigned ds = getVal(adaptor.getDsAttr());
604 unsigned exp = getVal(adaptor.getExpAttr());
605
606 unsigned vmcnt = 1024;
607 Attribute load = adaptor.getLoadAttr();
608 Attribute store = adaptor.getStoreAttr();
609 if (load && store) {
610 vmcnt = getVal(load) + getVal(store);
611 } else if (load) {
612 vmcnt = getVal(load);
613 } else if (store) {
614 vmcnt = getVal(store);
615 }
616
617 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
618 if (failed(waitcnt))
619 return op.emitOpError("unsupported chipset");
620
621 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
622 return success();
623 }
624};
625
626struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
627 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
628 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
629
630 Chipset chipset;
631
632 LogicalResult
633 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
634 ConversionPatternRewriter &rewriter) const override {
635 Location loc = op.getLoc();
636 // This ensures that waits on global memory aren't introduced on
637 // chips that don't have the BackOffBarrier feature enabled in LLVM.
638 bool requiresInlineAsm = chipset < kGfx90a;
639
640 Attribute mmra =
641 rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
642 // Note: while there *is* a workgroup-one-as scope, this, when combined with
643 // the MMRA, will lead to the fence having no effect. This is because the
644 // codepaths for an atomic load or store will observe that a
645 // one-address-space atomic to LDS requires no synchronization because
646 // operations on LDS are totally ordered with respect to each other, and so
647 // will not emit the correct waitcnt operations that these fences are
648 // intended to produce. Therefore, we use a broader type of fence and rely
649 // on the MMRA to relax it to the semantics we want.
650 StringRef scope = "workgroup";
651
652 auto relFence = LLVM::FenceOp::create(rewriter, loc,
653 LLVM::AtomicOrdering::release, scope);
654 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
655 if (requiresInlineAsm) {
656 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
657 LLVM::AsmDialect::AD_ATT);
658 const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
659 const char *constraints = "";
660 LLVM::InlineAsmOp::create(
661 rewriter, loc,
662 /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
663 /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
664 /*is_align_stack=*/false, LLVM::TailCallKind::None,
665 /*asm_dialect=*/asmDialectAttr,
666 /*operand_attrs=*/ArrayAttr());
667 } else if (chipset.majorVersion < 12) {
668 ROCDL::SBarrierOp::create(rewriter, loc);
669 } else {
670 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
671 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
672 }
673
674 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
675 LLVM::AtomicOrdering::acquire, scope);
676 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
677 rewriter.replaceOp(op, acqFence);
678 return success();
679 }
680};
681
682struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
683 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
684 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
685
686 Chipset chipset;
687
688 LogicalResult
689 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
690 ConversionPatternRewriter &rewriter) const override {
691 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
692 (uint32_t)op.getOpts());
693 return success();
694 }
695};
696
697} // namespace
698
699/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
700/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
701///
702/// Specifically:
703/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
704/// allows bf16. Newer MFMAs support bf16 types on operand, check
705/// IntrinsicsAMDGPU.td file for reference.
706/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
707/// instead, which is what the f8f6f4 intrinsics use.
708/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
709/// integer.
710///
711/// Note that the type of `input` has already been LLVM type converted:
712/// therefore 8-bit and smaller floats are represented as their corresponding
713/// `iN` integers.
714static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
715 Location loc, Value input,
716 bool allowBf16 = true) {
717 Type inputType = input.getType();
718 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
719 if (vectorType.getElementType().isBF16() && !allowBf16)
720 return LLVM::BitcastOp::create(
721 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
722 if (vectorType.getElementType().isInteger(8) &&
723 vectorType.getNumElements() <= 8)
724 return LLVM::BitcastOp::create(
725 rewriter, loc,
726 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
727 if (isa<IntegerType>(vectorType.getElementType()) &&
728 vectorType.getElementTypeBitWidth() <= 8) {
729 int64_t numWords = llvm::divideCeil(
730 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
731 32);
732 return LLVM::BitcastOp::create(
733 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
734 input);
735 }
736 }
737 return input;
738}
739
740/// Converts packed vector operands to the expected ROCDL types.
741static Value convertPackedVectorOperand(ConversionPatternRewriter &rewriter,
742 Location loc, Value input,
743 bool allowBf16 = true) {
744 Type inputType = input.getType();
745 auto vectorType = cast<VectorType>(inputType);
746 // bf16 -> i16 when not allowed (pre-gfx950).
747 if (vectorType.getElementType().isBF16() && !allowBf16)
748 return LLVM::BitcastOp::create(
749 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
750 // i8/fp8 vectors -> vector<Nxi32>.
751 if (isa<IntegerType>(vectorType.getElementType()) &&
752 vectorType.getElementTypeBitWidth() <= 8) {
753 int64_t numWords = llvm::divideCeil(
754 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
755 Type castType = (numWords > 1)
756 ? Type{VectorType::get(numWords, rewriter.getI32Type())}
757 : rewriter.getI32Type();
758 return LLVM::BitcastOp::create(rewriter, loc, castType, input);
759 }
760 return input;
761}
762
763/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
764/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
765///
766/// Specifically:
767/// 1. If `input` is a i8 value, zero extend it to i32
768/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
769///
770/// Note that the type of `input` has already been LLVM type converted:
771/// therefore 8-bit and smaller floats are represented as their corresponding
772/// `iN` integers.
773static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
774 Value input) {
775 return TypeSwitch<Type, Value>(input.getType())
776 .Case([&](IntegerType) {
777 // Handle scalar i8: zero extend to i32.
778 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
779 input);
780 })
781 .Case([&](VectorType vectorType) {
782 // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
783 int64_t numElements = vectorType.getNumElements();
784 assert((numElements == 4 || numElements == 8) &&
785 "scale operand must be a vector of length 4 or 8");
786 IntegerType outputType =
787 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
788 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
789 })
790 .DefaultUnreachable("unexpected input type for scale operand");
791}
792
793/// Maps f8 scale element types to WMMA scale format codes.
794static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
796 .Case([](Float8E8M0FNUType) { return 0; })
797 .Case([](Float8E4M3FNType) { return 2; })
798 .Default(std::nullopt);
799}
800
801/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
802/// and scale block size (16 or 32).
803static std::optional<StringRef>
805 if (m == 16 && n == 16 && k == 128)
806 return isScale16
807 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
808 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
809
810 if (m == 32 && n == 16 && k == 128)
811 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
812 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
813
814 return std::nullopt;
815}
816
817/// Push an input operand. If it is a float type, nothing to do. If it is
818/// an integer type, then we need to also push its signdness (1 for signed, 0
819/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
820/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
821/// We also need to convert bfloat inputs to i16 to account for the bfloat
822/// intrinsics having been defined before the AMD backend supported bfloat. We
823/// similarly need to pack 8-bit float types into integers as if they were i8
824/// (which they are for the backend's purposes).
826 ConversionPatternRewriter &rewriter, Location loc,
827 const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
828 Value mlirInput, SmallVectorImpl<Value> &operands,
829 SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
830 Type inputType = llvmInput.getType();
831 auto vectorType = dyn_cast<VectorType>(inputType);
832 if (!vectorType) {
833 operands.push_back(llvmInput);
834 return;
835 }
836 Type elemType = vectorType.getElementType();
837 if (elemType.getIntOrFloatBitWidth() > 8) {
838 operands.push_back(llvmInput);
839 return;
840 }
841
842 // We need to check the type of the input before conversion to properly test
843 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
844 // fp8/int8 information is lost during the conversion process.
845 auto mlirInputType = cast<VectorType>(mlirInput.getType());
846 bool isInputInteger = mlirInputType.getElementType().isInteger();
847 if (isInputInteger) {
848 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
849 bool localIsUnsigned = isUnsigned;
850 if (elemType.isUnsignedInteger()) {
851 localIsUnsigned = true;
852 } else if (elemType.isSignedInteger()) {
853 localIsUnsigned = false;
854 }
855 attrs.push_back(
856 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
857 }
858
859 int64_t numBits =
860 vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
861 Type i32 = rewriter.getI32Type();
862 Type intrinsicInType = numBits <= 32
863 ? (Type)rewriter.getIntegerType(numBits)
864 : (Type)VectorType::get(numBits / 32, i32);
865 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
866 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
867 loc, llvmIntrinsicInType, llvmInput);
868 // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
869 // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
870 // Add in the zeros here.
871 if (numBits < 32)
872 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
873 operands.push_back(castInput);
874}
875
876/// Push the output operand. For many cases this is only pushing the output in
877/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
878/// since the same numbers of VGPRs is used, we need to decide if to store the
879/// result in the upper 16 bits of the VGPRs or in the lower part. To store the
880/// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
881/// be stored it in the upper part. The subwordOffset must not be set for gfx12,
882/// as the instructions have been changed to return fewer registers instead.
883static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
884 Location loc,
885 const TypeConverter *typeConverter,
886 Value output, int32_t subwordOffset,
887 bool clamp, SmallVectorImpl<Value> &operands,
889 Type inputType = output.getType();
890 auto vectorType = dyn_cast<VectorType>(inputType);
891 Type elemType = vectorType.getElementType();
892 operands.push_back(output);
893 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
894 attrs.push_back(
895 NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
896 } else if (elemType.isInteger(32)) {
897 attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
898 }
899}
900
901/// Return true if `type` is the E5M2 variant of an 8-bit float that is
902/// supported by the `_bf8` instructions on the given `chipset`.
903static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
904 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
905 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
906}
907
908/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
909/// supported by the `_fp8` instructions on the given `chipset`.
910static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
911 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
912 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
913}
914
915/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
916/// if one exists. This includes checking to ensure the intrinsic is supported
917/// on the architecture you are compiling for.
918static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
919 Chipset chipset) {
920 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
921 b = mfma.getBlocks();
922 Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
923 Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
924
925 if (sourceElem.isF32() && destElem.isF32()) {
926 if (mfma.getReducePrecision() && chipset >= kGfx942) {
927 if (m == 32 && n == 32 && k == 4 && b == 1)
928 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
929 if (m == 16 && n == 16 && k == 8 && b == 1)
930 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
931 }
932 if (m == 32 && n == 32 && k == 1 && b == 2)
933 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
934 if (m == 16 && n == 16 && k == 1 && b == 4)
935 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
936 if (m == 4 && n == 4 && k == 1 && b == 16)
937 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
938 if (m == 32 && n == 32 && k == 2 && b == 1)
939 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
940 if (m == 16 && n == 16 && k == 4 && b == 1)
941 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
942 }
943
944 if (sourceElem.isF16() && destElem.isF32()) {
945 if (chipset >= kGfx950) {
946 if (m == 32 && n == 32 && k == 16 && b == 1)
947 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
948 if (m == 16 && n == 16 && k == 32 && b == 1)
949 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
950 }
951 if (m == 32 && n == 32 && k == 4 && b == 2)
952 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
953 if (m == 16 && n == 16 && k == 4 && b == 4)
954 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
955 if (m == 4 && n == 4 && k == 4 && b == 16)
956 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
957 if (m == 32 && n == 32 && k == 8 && b == 1)
958 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
959 if (m == 16 && n == 16 && k == 16 && b == 1)
960 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
961 }
962
963 if (sourceElem.isBF16() && destElem.isF32()) {
964 if (chipset >= kGfx950) {
965 if (m == 32 && n == 32 && k == 16 && b == 1)
966 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
967 if (m == 16 && n == 16 && k == 32 && b == 1)
968 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
969 }
970 if (chipset >= kGfx90a) {
971 if (m == 32 && n == 32 && k == 4 && b == 2)
972 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
973 if (m == 16 && n == 16 && k == 4 && b == 4)
974 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
975 if (m == 4 && n == 4 && k == 4 && b == 16)
976 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
977 if (m == 32 && n == 32 && k == 8 && b == 1)
978 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
979 if (m == 16 && n == 16 && k == 16 && b == 1)
980 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
981 }
982 if (m == 32 && n == 32 && k == 2 && b == 2)
983 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
984 if (m == 16 && n == 16 && k == 2 && b == 4)
985 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
986 if (m == 4 && n == 4 && k == 2 && b == 16)
987 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
988 if (m == 32 && n == 32 && k == 4 && b == 1)
989 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
990 if (m == 16 && n == 16 && k == 8 && b == 1)
991 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
992 }
993
994 if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
995 if (chipset >= kGfx950) {
996 if (m == 32 && n == 32 && k == 32 && b == 1)
997 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
998 if (m == 16 && n == 16 && k == 64 && b == 1)
999 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
1000 }
1001 if (m == 32 && n == 32 && k == 4 && b == 2)
1002 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
1003 if (m == 16 && n == 16 && k == 4 && b == 4)
1004 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
1005 if (m == 4 && n == 4 && k == 4 && b == 16)
1006 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
1007 if (m == 32 && n == 32 && k == 8 && b == 1)
1008 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
1009 if (m == 16 && n == 16 && k == 16 && b == 1)
1010 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
1011 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
1012 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
1013 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
1014 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
1015 }
1016
1017 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
1018 if (m == 16 && n == 16 && k == 4 && b == 1)
1019 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
1020 if (m == 4 && n == 4 && k == 4 && b == 4)
1021 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
1022 }
1023
1024 if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
1025 // Known to be correct because there are no scalar f8 instructions and
1026 // because a length mismatch will have been caught by the verifier.
1027 Type sourceBElem =
1028 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
1029 if (m == 16 && n == 16 && k == 32 && b == 1) {
1030 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
1031 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
1032 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
1033 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
1034 }
1035 if (m == 32 && n == 32 && k == 16 && b == 1) {
1036 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
1037 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
1038 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
1039 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
1040 }
1041 }
1042
1043 if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
1044 Type sourceBElem =
1045 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
1046 if (m == 16 && n == 16 && k == 32 && b == 1) {
1047 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
1048 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
1049 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
1050 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
1051 }
1052 if (m == 32 && n == 32 && k == 16 && b == 1) {
1053 if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
1054 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
1055 if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
1056 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
1057 }
1058 }
1059
1060 return std::nullopt;
1061}
1062
1063static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
1065 .Case([](Float8E4M3FNType) { return 0u; })
1066 .Case([](Float8E5M2Type) { return 1u; })
1067 .Case([](Float6E2M3FNType) { return 2u; })
1068 .Case([](Float6E3M2FNType) { return 3u; })
1069 .Case([](Float4E2M1FNType) { return 4u; })
1070 .Default(std::nullopt);
1071}
1072
1073/// If there is a scaled MFMA instruction for the input element types `aType`
1074/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
1075/// blocks) on the given `chipset`, return a tuple consisting of the
1076/// OperationName of the intrinsic and the type codes that need to be passed to
1077/// that intrinsic. Note that this is also used to implement some un-scaled
1078/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
1079/// MFMA with a scale of 0.
1080static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1081mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
1082 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
1083 aType = getElementTypeOrSelf(aType);
1084 bType = getElementTypeOrSelf(bType);
1085 destType = getElementTypeOrSelf(destType);
1086
1087 if (chipset < kGfx950)
1088 return std::nullopt;
1089 if (!isa<Float32Type>(destType))
1090 return std::nullopt;
1091
1092 std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
1093 std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
1094 if (!aTypeCode || !bTypeCode)
1095 return std::nullopt;
1096
1097 if (m == 32 && n == 32 && k == 64 && b == 1)
1098 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1099 *aTypeCode, *bTypeCode};
1100 if (m == 16 && n == 16 && k == 128 && b == 1)
1101 return std::tuple{
1102 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1103 *bTypeCode};
1104
1105 return std::nullopt;
1106}
1107
1108static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1109mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
1111 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1112 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1113 mfma.getBlocks(), chipset);
1114}
1115
1116static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1117mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
1118 return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
1119 smfma.getSourceB().getType(),
1120 smfma.getDestC().getType(), smfma.getM(),
1121 smfma.getN(), smfma.getK(), 1u, chipset);
1122}
1123
1124/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1125/// for RDNA3/4 architectures.
1126static std::optional<StringRef>
1127wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
1128 Type elemDestType, uint32_t k, bool isRDNA3) {
1129 using fp8 = Float8E4M3FNType;
1130 using bf8 = Float8E5M2Type;
1131
1132 // Handle k == 16 for RDNA3/4.
1133 if (k == 16) {
1134 // Common patterns for RDNA3 and RDNA4.
1135 if (elemSourceType.isF16() && elemDestType.isF32())
1136 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1137 if (elemSourceType.isBF16() && elemDestType.isF32())
1138 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1139 if (elemSourceType.isF16() && elemDestType.isF16())
1140 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1141 if (elemSourceType.isBF16() && elemDestType.isBF16())
1142 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1143 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1144 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1145
1146 // RDNA3 specific patterns.
1147 if (isRDNA3) {
1148 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1149 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1150 return std::nullopt;
1151 }
1152
1153 // RDNA4 specific patterns (fp8/bf8).
1154 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1155 elemDestType.isF32())
1156 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1157 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1158 elemDestType.isF32())
1159 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1160 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1161 elemDestType.isF32())
1162 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1163 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1164 elemDestType.isF32())
1165 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1166 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1167 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1168
1169 return std::nullopt;
1170 }
1171
1172 // Handle k == 32 for RDNA4.
1173 if (k == 32 && !isRDNA3) {
1174 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1175 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1176 }
1177
1178 return std::nullopt;
1179}
1180
1181/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1182/// for the gfx1250 architecture.
1183static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
1184 Type elemBSourceType,
1185 Type elemDestType,
1186 uint32_t k) {
1187 using fp8 = Float8E4M3FNType;
1188 using bf8 = Float8E5M2Type;
1189
1190 if (k == 4) {
1191 if (elemSourceType.isF32() && elemDestType.isF32())
1192 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1193
1194 return std::nullopt;
1195 }
1196
1197 if (k == 32) {
1198 if (elemSourceType.isF16() && elemDestType.isF32())
1199 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1200 if (elemSourceType.isBF16() && elemDestType.isF32())
1201 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1202 if (elemSourceType.isF16() && elemDestType.isF16())
1203 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1204 if (elemSourceType.isBF16() && elemDestType.isBF16())
1205 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1206
1207 return std::nullopt;
1208 }
1209
1210 if (k == 64) {
1211 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1212 if (elemDestType.isF32())
1213 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1214 if (elemDestType.isF16())
1215 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1216 }
1217 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1218 if (elemDestType.isF32())
1219 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1220 if (elemDestType.isF16())
1221 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1222 }
1223 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1224 if (elemDestType.isF32())
1225 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1226 if (elemDestType.isF16())
1227 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1228 }
1229 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1230 if (elemDestType.isF32())
1231 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1232 if (elemDestType.isF16())
1233 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1234 }
1235 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1236 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1237
1238 return std::nullopt;
1239 }
1240
1241 if (k == 128) {
1242 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1243 if (elemDestType.isF32())
1244 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1245 if (elemDestType.isF16())
1246 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1247 }
1248 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1249 if (elemDestType.isF32())
1250 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1251 if (elemDestType.isF16())
1252 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1253 }
1254 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1255 if (elemDestType.isF32())
1256 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1257 if (elemDestType.isF16())
1258 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1259 }
1260 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1261 if (elemDestType.isF32())
1262 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1263 if (elemDestType.isF16())
1264 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1265 }
1266
1267 return std::nullopt;
1268 }
1269
1270 return std::nullopt;
1271}
1272
1273/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac)
1274/// operation if one exists. This includes checking to ensure the intrinsic is
1275/// supported on the architecture you are compiling for.
1276static std::optional<StringRef> smfmacOpToIntrinsic(SparseMFMAOp op,
1277 Chipset chipset) {
1278 bool isGfx950 = chipset >= kGfx950;
1279 auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); };
1280 auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); };
1281
1282 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1283 Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType());
1284 Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType());
1285 Type destElem = getElementTypeOrSelf(op.getDestC().getType());
1286
1287 if (m == 16 && n == 16 && k == 32) {
1288 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1289 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1290 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1291 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1292 }
1293
1294 if (m == 16 && n == 16 && k == 64) {
1295 if (isGfx950) {
1296 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1297 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1298 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1299 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1300 }
1301 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1302 destElem.isInteger(32))
1303 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1304 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1305 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1306 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1307 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1308 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1309 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1310 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1311 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1312 }
1313
1314 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1315 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1316 destElem.isInteger(32))
1317 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1318 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1319 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1320 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1321 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1322 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1323 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1324 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1325 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1326 }
1327
1328 if (m == 32 && n == 32 && k == 16) {
1329 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1330 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1331 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1332 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1333 }
1334
1335 if (m == 32 && n == 32 && k == 32) {
1336 if (isGfx950) {
1337 if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32())
1338 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1339 if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32())
1340 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1341 }
1342 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1343 destElem.isInteger(32))
1344 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1345 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1346 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1347 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1348 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1349 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1350 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1351 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1352 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1353 }
1354
1355 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1356 if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) &&
1357 destElem.isInteger(32))
1358 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1359 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1360 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1361 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1362 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1363 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
1364 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1365 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
1366 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1367 }
1368
1369 return std::nullopt;
1370}
1371
1372/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1373/// if one exists. This includes checking to ensure the intrinsic is supported
1374/// on the architecture you are compiling for.
1375static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1376 Chipset chipset) {
1377 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1378 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1379 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1380 Type elemSourceType = sourceVectorType.getElementType();
1381 Type elemBSourceType = sourceBVectorType.getElementType();
1382 Type elemDestType = destVectorType.getElementType();
1383
1384 const uint32_t k = wmma.getK();
1385 const bool isRDNA3 = chipset.majorVersion == 11;
1386 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1387
1388 // Handle RDNA3 and RDNA4.
1389 if (isRDNA3 || isRDNA4)
1390 return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1391 k, isRDNA3);
1392
1393 // Handle gfx1250.
1394 if (chipset == kGfx1250)
1395 return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1396 elemDestType, k);
1397
1398 return std::nullopt;
1399}
1400
1401/// Returns the `rocdl` intrinsic corresponding to a SparseWMMA operation
1402/// `swmmac` if one exists. This includes checking to ensure the intrinsic is
1403/// supported on the architecture you are compiling for.
1405 StringRef name;
1409};
1410
1411static std::optional<SparseWMMAOpInfo>
1412sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset) {
1413 Type sourceAElem = getElementTypeOrSelf(swmmac.getSourceA().getType());
1414 Type sourceBElem = getElementTypeOrSelf(swmmac.getSourceB().getType());
1415 Type destElem = getElementTypeOrSelf(swmmac.getDestC().getType());
1416
1417 uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
1418
1419 if ((m != 16) || (n != 16))
1420 return std::nullopt;
1421
1422 const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1423 if (isRDNA4) {
1424 if (k == 32) {
1425 if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
1426 return SparseWMMAOpInfo{
1427 ROCDL::swmmac_f32_16x16x32_f16::getOperationName(), false, false,
1428 false};
1429 if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
1430 return SparseWMMAOpInfo{
1431 ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(), false, false,
1432 false};
1433 if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
1434 return SparseWMMAOpInfo{
1435 ROCDL::swmmac_f16_16x16x32_f16::getOperationName(), false, false,
1436 false};
1437 if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
1438 return SparseWMMAOpInfo{
1439 ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(), false, false,
1440 false};
1441 if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
1442 sourceBElem.isInteger(8))
1443 return SparseWMMAOpInfo{
1444 ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(), true, false,
1445 true};
1446 if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
1447 sourceBElem.isInteger(4))
1448 return SparseWMMAOpInfo{
1449 ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(), true, false,
1450 true};
1451 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1452 sourceBElem.isF8E4M3FN())
1453 return SparseWMMAOpInfo{
1454 ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(), false,
1455 false, false};
1456 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1457 sourceBElem.isF8E5M2())
1458 return SparseWMMAOpInfo{
1459 ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(), false,
1460 false, false};
1461 if (destElem.isF32() && sourceAElem.isF8E5M2() &&
1462 sourceBElem.isF8E4M3FN())
1463 return SparseWMMAOpInfo{
1464 ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(), false,
1465 false, false};
1466 if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1467 return SparseWMMAOpInfo{
1468 ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(), false,
1469 false, false};
1470 }
1471 if (k == 64) {
1472 if (destElem.isInteger(32) && sourceAElem.isInteger(4) &&
1473 sourceBElem.isInteger(4))
1474 return SparseWMMAOpInfo{
1475 ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(), true, false,
1476 true};
1477 }
1478 }
1479
1480 const bool isGFX1250 = chipset == kGfx1250;
1481 const bool isWavesize64 = swmmac.getWave64();
1482 if (isGFX1250 && !isWavesize64) {
1483 if (k == 64) {
1484 if (destElem.isF32() && sourceAElem.isF16() && sourceBElem.isF16())
1485 return SparseWMMAOpInfo{
1486 ROCDL::swmmac_f32_16x16x64_f16::getOperationName(), true, true,
1487 false};
1488 if (destElem.isF32() && sourceAElem.isBF16() && sourceBElem.isBF16())
1489 return SparseWMMAOpInfo{
1490 ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(), true, true,
1491 false};
1492 if (destElem.isF16() && sourceAElem.isF16() && sourceBElem.isF16())
1493 return SparseWMMAOpInfo{
1494 ROCDL::swmmac_f16_16x16x64_f16::getOperationName(), true, true,
1495 false};
1496 if (destElem.isBF16() && sourceAElem.isBF16() && sourceBElem.isBF16())
1497 return SparseWMMAOpInfo{
1498 ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(), true, true,
1499 false};
1500 }
1501 if (k == 128) {
1502 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1503 sourceBElem.isF8E4M3FN())
1504 return SparseWMMAOpInfo{
1505 ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(), false,
1506 true, false};
1507 if (destElem.isF32() && sourceAElem.isF8E4M3FN() &&
1508 sourceBElem.isF8E5M2())
1509 return SparseWMMAOpInfo{
1510 ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(), false,
1511 true, false};
1512 if (destElem.isF32() && sourceAElem.isF8E5M2() &&
1513 sourceBElem.isF8E4M3FN())
1514 return SparseWMMAOpInfo{
1515 ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(), false,
1516 true, false};
1517 if (destElem.isF32() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1518 return SparseWMMAOpInfo{
1519 ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(), false,
1520 true, false};
1521 if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
1522 sourceBElem.isF8E4M3FN())
1523 return SparseWMMAOpInfo{
1524 ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(), false,
1525 true, false};
1526 if (destElem.isF16() && sourceAElem.isF8E4M3FN() &&
1527 sourceBElem.isF8E5M2())
1528 return SparseWMMAOpInfo{
1529 ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(), false,
1530 true, false};
1531 if (destElem.isF16() && sourceAElem.isF8E5M2() &&
1532 sourceBElem.isF8E4M3FN())
1533 return SparseWMMAOpInfo{
1534 ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(), false,
1535 true, false};
1536 if (destElem.isF16() && sourceAElem.isF8E5M2() && sourceBElem.isF8E5M2())
1537 return SparseWMMAOpInfo{
1538 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
1539 true, false};
1540 if (destElem.isF16() && sourceAElem.isInteger(8) &&
1541 sourceBElem.isInteger(8))
1542 return SparseWMMAOpInfo{
1543 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
1544 true, false};
1545 if (destElem.isInteger(32) && sourceAElem.isInteger(8) &&
1546 sourceBElem.isInteger(8))
1547 return SparseWMMAOpInfo{
1548 ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(), true, true,
1549 true};
1550 }
1551 }
1552
1553 return std::nullopt;
1554}
1555
1556namespace {
1557struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
1558 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1559 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1560
1561 Chipset chipset;
1562
1563 LogicalResult
1564 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1565 ConversionPatternRewriter &rewriter) const override {
1566 Location loc = op.getLoc();
1567 Type outType = typeConverter->convertType(op.getDestD().getType());
1568 Type intrinsicOutType = outType;
1569 if (auto outVecType = dyn_cast<VectorType>(outType))
1570 if (outVecType.getElementType().isBF16())
1571 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1572
1573 if (chipset.majorVersion != 9 || chipset < kGfx908)
1574 return op->emitOpError("MFMA only supported on gfx908+");
1575 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
1576 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1577 if (chipset < kGfx942)
1578 return op.emitOpError("negation unsupported on older than gfx942");
1579 getBlgpField |=
1580 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1581 }
1582 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
1583 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1584 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1585 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1586 return op.emitOpError("no intrinsic matching MFMA size on given chipset");
1587
1588 bool isScaled =
1589 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1590 if (isScaled &&
1591 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1592 return op.emitOpError(
1593 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1594 "be scaled as those fields are used for type information");
1595 }
1596
1597 StringRef intrinsicName =
1598 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1599 // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
1600 // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
1601 bool allowBf16 = [&]() {
1602 if (chipset < kGfx950)
1603 return false;
1604 if (isScaled)
1605 return true;
1606 return intrinsicName.contains("16x16x32.bf16") ||
1607 intrinsicName.contains("32x32x16.bf16");
1608 }();
1609 OperationState loweredOp(loc, intrinsicName);
1610 loweredOp.addTypes(intrinsicOutType);
1611 loweredOp.addOperands({packSmallFloatVectorOperand(
1612 rewriter, loc, adaptor.getSourceA(), allowBf16),
1614 rewriter, loc, adaptor.getSourceB(), allowBf16),
1615 adaptor.getDestC()});
1616 if (isScaled) {
1617 Value zero = createI32Constant(rewriter, loc, 0);
1618 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1619 loweredOp.addOperands({/*scale A=*/zero, /*scale B=*/zero});
1620 loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1621 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1622 {"opselA", rewriter.getI32IntegerAttr(0)},
1623 {"opselB", rewriter.getI32IntegerAttr(0)}});
1624 } else {
1625 loweredOp.addAttributes(
1626 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1627 {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1628 {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
1629 };
1630 Value lowered = rewriter.create(loweredOp)->getResult(0);
1631 if (outType != intrinsicOutType)
1632 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1633 rewriter.replaceOp(op, lowered);
1634 return success();
1635 }
1636};
1637
1638struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1639 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1640 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1641
1642 Chipset chipset;
1643
1644 LogicalResult
1645 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1646 ConversionPatternRewriter &rewriter) const override {
1647 Location loc = op.getLoc();
1648 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1649
1650 if (chipset.majorVersion != 9 || chipset < kGfx950)
1651 return op->emitOpError("scaled MFMA only supported on gfx908+");
1652 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1653 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1654 if (!maybeScaledIntrinsic.has_value())
1655 return op.emitOpError(
1656 "no intrinsic matching scaled MFMA size on given chipset");
1657
1658 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1659 OperationState loweredOp(loc, intrinsicName);
1660 loweredOp.addTypes(intrinsicOutType);
1661 loweredOp.addOperands(
1662 {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
1663 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
1664 adaptor.getDestC()});
1665 loweredOp.addOperands(
1666 {/*scales A*/
1667 castScaleOperand(rewriter, loc, adaptor.getScalesA()),
1668 /*scales B*/
1669 castScaleOperand(rewriter, loc, adaptor.getScalesB())});
1670 loweredOp.addAttributes(
1671 {{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1672 {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1673 {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1674 {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1675
1676 Value lowered = rewriter.create(loweredOp)->getResult(0);
1677 rewriter.replaceOp(op, lowered);
1678 return success();
1679 }
1680};
1681
1682struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern<SparseMFMAOp> {
1683 SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1684 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1685
1686 Chipset chipset;
1687
1688 LogicalResult
1689 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1690 ConversionPatternRewriter &rewriter) const override {
1691 Location loc = op.getLoc();
1692 auto outType =
1693 typeConverter->convertType<VectorType>(op.getDestC().getType());
1694 if (!outType)
1695 return rewriter.notifyMatchFailure(op, "type conversion failed");
1696
1697 // smfmac is supported on gfx942 and gfx950.
1698 if (chipset.majorVersion != 9 || chipset < kGfx942)
1699 return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
1700
1701 std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
1702 if (!maybeIntrinsic.has_value())
1703 return op.emitOpError(
1704 "no intrinsic matching sparse MFMA on the given chipset");
1705 bool isGfx942BF16 =
1706 (*maybeIntrinsic ==
1707 ROCDL::smfmac_f32_16x16x32_bf16::getOperationName() ||
1708 *maybeIntrinsic ==
1709 ROCDL::smfmac_f32_32x32x16_bf16::getOperationName());
1710 bool isGfx950 = (chipset >= kGfx950) && !isGfx942BF16;
1711
1712 Value a = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceA(),
1713 isGfx950);
1714 Value b = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceB(),
1715 isGfx950);
1716 Value c = adaptor.getDestC();
1717
1718 // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32.
1719 // gfx950 8-bit variants already carry the index as i32; skip the bitcast.
1720 Value sparseIdx = adaptor.getSparseIdx();
1721 Type i32Type = rewriter.getI32Type();
1722 if (sparseIdx.getType() != i32Type)
1723 sparseIdx = LLVM::BitcastOp::create(rewriter, loc, i32Type, sparseIdx);
1724
1725 OperationState loweredOp(loc, maybeIntrinsic.value());
1726 loweredOp.addTypes(outType);
1727 loweredOp.addOperands({a, b, c, sparseIdx});
1728 loweredOp.addAttributes(
1729 {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1730 {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1731 Value lowered = rewriter.create(loweredOp)->getResult(0);
1732 rewriter.replaceOp(op, lowered);
1733 return success();
1734 }
1735};
1736
1737struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
1738 WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1739 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1740
1741 Chipset chipset;
1742
1743 LogicalResult
1744 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1745 ConversionPatternRewriter &rewriter) const override {
1746 Location loc = op.getLoc();
1747 auto outType =
1748 typeConverter->convertType<VectorType>(op.getDestD().getType());
1749 if (!outType)
1750 return rewriter.notifyMatchFailure(op, "type conversion failed");
1751
1752 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1753 return op->emitOpError("WMMA only supported on gfx11 and gfx12");
1754
1755 bool isGFX1250 = chipset >= kGfx1250;
1756
1757 // The WMMA operations represent vectors of bf16s as vectors of i16s
1758 // (except on gfx1250), so we need to bitcast bfloats to i16 and then
1759 // bitcast them back.
1760 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1761 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1762 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1763 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1764 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1765 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1766 bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
1767 VectorType rawOutType = outType;
1768 if (castOutToI16)
1769 rawOutType = outType.clone(rewriter.getI16Type());
1770 Value a = adaptor.getSourceA();
1771 if (castAToI16)
1772 a = LLVM::BitcastOp::create(rewriter, loc,
1773 aType.clone(rewriter.getI16Type()), a);
1774 Value b = adaptor.getSourceB();
1775 if (castBToI16)
1776 b = LLVM::BitcastOp::create(rewriter, loc,
1777 bType.clone(rewriter.getI16Type()), b);
1778 Value destC = adaptor.getDestC();
1779 if (castDestCToI16)
1780 destC = LLVM::BitcastOp::create(
1781 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1782
1783 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
1784
1785 if (!maybeIntrinsic.has_value())
1786 return op.emitOpError("no intrinsic matching WMMA on the given chipset");
1787
1788 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1789 return op.emitOpError("subwordOffset not supported on gfx12+");
1790
1791 SmallVector<Value, 4> operands;
1792 SmallVector<NamedAttribute, 4> attrs;
1793 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
1794 op.getSourceA(), operands, attrs, "signA");
1795 wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
1796 op.getSourceB(), operands, attrs, "signB");
1797 wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
1798 op.getSubwordOffset(), op.getClamp(), operands,
1799 attrs);
1800
1801 OperationState loweredOp(loc, *maybeIntrinsic);
1802 loweredOp.addTypes(rawOutType);
1803 loweredOp.addOperands(operands);
1804 loweredOp.addAttributes(attrs);
1805 Operation *lowered = rewriter.create(loweredOp);
1806
1807 Operation *maybeCastBack = lowered;
1808 if (rawOutType != outType)
1809 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1810 lowered->getResult(0));
1811 rewriter.replaceOp(op, maybeCastBack->getResults());
1812
1813 return success();
1814 }
1815};
1816
/// Attribute shape of the ROCDL dot intrinsic selected for an `amdgpu.dot`.
enum class DotFamily {
  Clamp,   ///< ROCDL_Dot_IntrOp: a single `clamp` attribute.
  NoClamp, ///< ROCDL_Dot_NoClamp_IntrOp: no attributes at all.
  Sudot    ///< ROCDL_Sudot_IntrOp: `signA`, `signB` and `clamp` attributes.
};
1825
1826static std::optional<std::pair<StringRef, DotFamily>>
1827dotOpToIntrinsic(DotOp op, Chipset chipset) {
1828 Type aElem = cast<VectorType>(op.getSourceA().getType()).getElementType();
1829 Type bElem = cast<VectorType>(op.getSourceB().getType()).getElementType();
1830 Type dest = op.getDestC().getType();
1831 bool uA = op.getUnsignedA();
1832 bool uB = op.getUnsignedB();
1833
1834 // f16 x f16 -> f32 / f16.
1835 if (aElem.isF16() && bElem.isF16()) {
1836 if (dest.isF32() && hasDot10Insts(chipset))
1837 return {{ROCDL::fdot2::getOperationName(), DotFamily::Clamp}};
1838 if (dest.isF16() && hasDot9Insts(chipset))
1839 return {{ROCDL::fdot2_f16_f16::getOperationName(), DotFamily::NoClamp}};
1840 return std::nullopt;
1841 }
1842
1843 // bf16 x bf16 -> f32 / bf16.
1844 if (aElem.isBF16() && bElem.isBF16()) {
1845 if (dest.isF32() && hasDot12Insts(chipset))
1846 return {{ROCDL::fdot2_f32_bf16::getOperationName(), DotFamily::Clamp}};
1847 if (dest.isBF16() && hasDot9Insts(chipset))
1848 return {{ROCDL::fdot2_bf16_bf16::getOperationName(), DotFamily::NoClamp}};
1849 return std::nullopt;
1850 }
1851
1852 // Integer sources -> i32.
1853 if (isa<IntegerType>(aElem) && isa<IntegerType>(bElem) &&
1854 dest.isInteger(32)) {
1855 bool mixedSign = (uA != uB);
1856 unsigned elemWidth = aElem.getIntOrFloatBitWidth();
1857
1858 if (mixedSign) {
1859 if (!hasDot8Insts(chipset))
1860 return std::nullopt;
1861 StringRef name;
1862 switch (elemWidth) {
1863 case 8:
1864 name = ROCDL::sudot4::getOperationName();
1865 break;
1866 case 4:
1867 name = ROCDL::sudot8::getOperationName();
1868 break;
1869 default:
1870 return std::nullopt;
1871 }
1872 return {{name, DotFamily::Sudot}};
1873 }
1874
1875 StringRef name;
1876 bool supported = false;
1877 switch (elemWidth) {
1878 case 16:
1879 supported = hasDot2Insts(chipset);
1880 name = uA ? ROCDL::udot2::getOperationName()
1881 : ROCDL::sdot2::getOperationName();
1882 break;
1883 case 8:
1884 supported = uA ? hasDot7Insts(chipset)
1885 : hasDot1Insts(chipset) || hasDot8Insts(chipset);
1886 name = uA ? ROCDL::udot4::getOperationName()
1887 : ROCDL::sdot4::getOperationName();
1888 break;
1889 case 4:
1890 supported = uA ? hasDot7Insts(chipset)
1891 : hasDot1Insts(chipset) || hasDot8Insts(chipset);
1892 name = uA ? ROCDL::udot8::getOperationName()
1893 : ROCDL::sdot8::getOperationName();
1894 break;
1895 default:
1896 return std::nullopt;
1897 }
1898 if (!supported)
1899 return std::nullopt;
1900 return {{name, DotFamily::Clamp}};
1901 }
1902
1903 // fp8/bf8 x fp8/bf8 -> f32.
1904 bool aIsFp8 = isa<Float8E4M3FNType>(aElem);
1905 bool aIsBf8 = isa<Float8E5M2Type>(aElem);
1906 bool bIsFp8 = isa<Float8E4M3FNType>(bElem);
1907 bool bIsBf8 = isa<Float8E5M2Type>(bElem);
1908 if ((aIsFp8 || aIsBf8) && (bIsFp8 || bIsBf8) && dest.isF32()) {
1909 if (!hasDot11Insts(chipset))
1910 return std::nullopt;
1911 StringRef name;
1912 if (aIsFp8 && bIsFp8)
1913 name = ROCDL::dot4_f32_fp8_fp8::getOperationName();
1914 else if (aIsFp8 && bIsBf8)
1915 name = ROCDL::dot4_f32_fp8_bf8::getOperationName();
1916 else if (aIsBf8 && bIsFp8)
1917 name = ROCDL::dot4_f32_bf8_fp8::getOperationName();
1918 else
1919 name = ROCDL::dot4_f32_bf8_bf8::getOperationName();
1920 return {{name, DotFamily::NoClamp}};
1921 }
1922
1923 return std::nullopt;
1924}
1925
1926struct DotOpLowering : public ConvertOpToLLVMPattern<DotOp> {
1927 DotOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1928 : ConvertOpToLLVMPattern<DotOp>(converter), chipset(chipset) {}
1929
1930 Chipset chipset;
1931
1932 LogicalResult
1933 matchAndRewrite(DotOp op, DotOpAdaptor adaptor,
1934 ConversionPatternRewriter &rewriter) const override {
1935 Location loc = op.getLoc();
1936
1937 std::optional<std::pair<StringRef, DotFamily>> maybeIntrinsic =
1938 dotOpToIntrinsic(op, chipset);
1939 if (!maybeIntrinsic)
1940 return op.emitOpError("no intrinsic matching dot on the given chipset: ")
1941 << op.getSourceA().getType() << " * " << op.getSourceB().getType()
1942 << " + " << op.getDestC().getType();
1943
1944 auto [intrinsicName, family] = maybeIntrinsic.value();
1945
1946 Value a = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceA());
1947 Value b = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceB());
1948 Value c = adaptor.getDestC();
1949
1950 SmallVector<NamedAttribute, 3> attrs;
1951 if (family == DotFamily::Sudot) {
1952 attrs.push_back(rewriter.getNamedAttr(
1953 "signA", rewriter.getBoolAttr(!op.getUnsignedA())));
1954 attrs.push_back(rewriter.getNamedAttr(
1955 "signB", rewriter.getBoolAttr(!op.getUnsignedB())));
1956 }
1957
1958 if (family != DotFamily::NoClamp && op.getClamp())
1959 attrs.push_back(
1960 rewriter.getNamedAttr("clamp", rewriter.getBoolAttr(true)));
1961
1962 Type resultType = typeConverter->convertType(op.getDestD().getType());
1963
1964 OperationState loweredOp(loc, intrinsicName);
1965 loweredOp.addTypes(resultType);
1966 loweredOp.addOperands({a, b, c});
1967 loweredOp.addAttributes(attrs);
1968 Operation *lowered = rewriter.create(loweredOp);
1969 rewriter.replaceOp(op, lowered->getResults());
1970 return success();
1971 }
1972};
1973
1974struct SparseWMMAOpLowering : public ConvertOpToLLVMPattern<SparseWMMAOp> {
1975 SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1976 : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
1977
1978 Chipset chipset;
1979
1980 LogicalResult
1981 matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
1982 ConversionPatternRewriter &rewriter) const override {
1983 Location loc = op.getLoc();
1984 auto outType =
1985 typeConverter->convertType<VectorType>(op.getDestD().getType());
1986 if (!outType)
1987 return rewriter.notifyMatchFailure(op, "type conversion failed");
1988
1989 std::optional<SparseWMMAOpInfo> maybeIntrinsic =
1990 sparseWMMAOpToIntrinsic(op, chipset);
1991
1992 if (!maybeIntrinsic.has_value())
1993 return op.emitOpError(
1994 "no intrinsic matching Sparse WMMA on the given chipset");
1995 SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();
1996
1997 SmallVector<NamedAttribute> attrs;
1998
1999 if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.useSign)
2000 return op->emitOpError("intrinsic doesn't support unsign");
2001 if (intrinsic.useSign) {
2002 if (auto attr = op.getUnsignedAAttr())
2003 attrs.push_back({"signA", attr});
2004 if (auto attr = op.getUnsignedBAttr())
2005 attrs.push_back({"signB", attr});
2006 }
2007
2008 if ((op.getReuseA() || op.getReuseB()) && !intrinsic.useReuse)
2009 return op->emitOpError("intrinsic doesn't support reuse");
2010 if (intrinsic.useReuse) {
2011 if (auto attr = op.getReuseAAttr())
2012 attrs.push_back({"reuseA", attr});
2013 if (auto attr = op.getReuseBAttr())
2014 attrs.push_back({"reuseB", attr});
2015 }
2016
2017 if (op.getClamp() && !intrinsic.useClamp)
2018 return op->emitOpError("intrinsic doesn't support clamp");
2019 if (intrinsic.useClamp && op.getClampAttr())
2020 attrs.push_back({"clamp", op.getClampAttr()});
2021
2022 const bool isGFX1250orHigher =
2023 chipset.majorVersion == 12 && chipset.minorVersion >= 5;
2024 Value a = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceA(),
2025 isGFX1250orHigher);
2026 Value b = convertPackedVectorOperand(rewriter, loc, adaptor.getSourceB(),
2027 isGFX1250orHigher);
2028 Value c = adaptor.getDestC();
2029 VectorType rawOutType = outType;
2030 if (!isGFX1250orHigher) {
2031 c = convertPackedVectorOperand(rewriter, loc, adaptor.getDestC(), false);
2032 rawOutType = cast<VectorType>(c.getType());
2033 }
2034
2035 // Bitcast sparse indices from vector<4xi8> to i32.
2036 Value sparseIdx = LLVM::BitcastOp::create(
2037 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
2038
2039 OperationState loweredOp(loc, intrinsic.name);
2040 loweredOp.addTypes(rawOutType);
2041 loweredOp.addOperands({a, b, c, sparseIdx});
2042 loweredOp.addAttributes(attrs);
2043 Operation *lowered = rewriter.create(loweredOp);
2044
2045 Operation *maybeCastBack = lowered;
2046 if (rawOutType != outType)
2047 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
2048 lowered->getResult(0));
2049 rewriter.replaceOp(op, maybeCastBack->getResults());
2050
2051 return success();
2052 }
2053};
2054
2055struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
2056 ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2057 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
2058
2059 Chipset chipset;
2060
2061 LogicalResult
2062 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
2063 ConversionPatternRewriter &rewriter) const override {
2064 Location loc = op.getLoc();
2065 auto outType =
2066 typeConverter->convertType<VectorType>(op.getDestD().getType());
2067 if (!outType)
2068 return rewriter.notifyMatchFailure(op, "type conversion failed");
2069
2070 if (chipset < kGfx1250)
2071 return op->emitOpError("WMMA scale only supported on gfx1250+");
2072
2073 int64_t m = op.getM();
2074 int64_t n = op.getN();
2075 int64_t k = op.getK();
2076
2077 Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
2078 Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
2079
2080 std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
2081 std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
2082
2083 if (!aFmtCode || !bFmtCode)
2084 return op.emitOpError("unsupported element types for scaled_wmma");
2085
2086 // Get scale vector types and determine variant (scale vs scale16).
2087 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
2088 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
2089
2090 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
2091 return op.emitOpError("scaleA and scaleB must have equal vector length");
2092
2093 // Extract scale format from element types.
2094 Type scaleAElemType = scaleAVecType.getElementType();
2095 Type scaleBElemType = scaleBVecType.getElementType();
2096
2097 std::optional<uint32_t> scaleAFmt = getWmmaScaleFormat(scaleAElemType);
2098 std::optional<uint32_t> scaleBFmt = getWmmaScaleFormat(scaleBElemType);
2099
2100 if (!scaleAFmt || !scaleBFmt)
2101 return op.emitOpError("unsupported scale element types");
2102
2103 // Determine which intrinsic to use based on dimensions.
2104 bool isScale16 = (scaleAVecType.getNumElements() == 8);
2105 std::optional<StringRef> intrinsicName =
2106 getScaledWmmaIntrinsicName(m, n, k, isScale16);
2107 if (!intrinsicName)
2108 return op.emitOpError("unsupported scaled_wmma dimensions: ")
2109 << m << "x" << n << "x" << k;
2110
2111 SmallVector<NamedAttribute, 8> attrs;
2112
2113 // The f4 variant does not have fmtA and fmtB attributes.
2114 bool is32x16 = (m == 32 && n == 16 && k == 128);
2115 if (!is32x16) {
2116 attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
2117 attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
2118 }
2119
2120 // modC uses default value of 0.
2121 attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));
2122
2123 // Scale attributes. Convert user-facing firstScaleLane (0 or 16) to the
2124 // half of the wave that is being selected (0 or 1).
2125 attrs.emplace_back(
2126 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
2127 attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
2128 attrs.emplace_back(
2129 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
2130 attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
2131
2132 // Reuse flags use default value of false.
2133 attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
2134 attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
2135
2136 // Convert typed float vectors to packed format.
2137 Value sourceA =
2138 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
2139 Value sourceB =
2140 packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
2141
2142 // Pack scale vectors into i32/i64.
2143 Value packedScaleA = castScaleOperand(rewriter, loc, adaptor.getScaleA());
2144 Value packedScaleB = castScaleOperand(rewriter, loc, adaptor.getScaleB());
2145
2146 // Create the intrinsic call.
2147 OperationState loweredOp(loc, *intrinsicName);
2148 loweredOp.addTypes(outType);
2149 loweredOp.addOperands(
2150 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
2151 loweredOp.addAttributes(attrs);
2152
2153 Operation *lowered = rewriter.create(loweredOp);
2154 rewriter.replaceOp(op, lowered->getResults());
2155
2156 return success();
2157 }
2158};
2159
2160struct TransposeLoadOpLowering
2161 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
2162 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2163 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
2164
2165 Chipset chipset;
2166
2167 LogicalResult
2168 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
2169 ConversionPatternRewriter &rewriter) const override {
2170 if (chipset != kGfx950)
2171 return op.emitOpError("Non-gfx950 chipset not supported");
2172
2173 Location loc = op.getLoc();
2174 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2175
2176 // Elements in subbyte memrefs are stored non-contiguously,
2177 // reject if source is sub-byte memref. Use emulated memrefs instead.
2178 size_t srcElementSize =
2179 srcMemRefType.getElementType().getIntOrFloatBitWidth();
2180 if (srcElementSize < 8)
2181 return op.emitOpError("Expect source memref to have at least 8 bits "
2182 "element size, got ")
2183 << srcElementSize;
2184
2185 auto resultType = cast<VectorType>(op.getResult().getType());
2186 Value srcPtr =
2187 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
2188 (adaptor.getSrcIndices()));
2189
2190 size_t numElements = resultType.getNumElements();
2191 size_t elementTypeSize =
2192 resultType.getElementType().getIntOrFloatBitWidth();
2193
2194 // ROCDL transpose load intrinsics return vectors of 32-bit integers, if
2195 // the element size is smaller than 16 bits.
2196 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
2197 rewriter.getIntegerType(32));
2198 Type llvmResultType = typeConverter->convertType(resultType);
2199
2200 switch (elementTypeSize) {
2201 case 4: {
2202 assert(numElements == 16);
2203 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
2204 rocdlResultType, srcPtr);
2205 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2206 break;
2207 }
2208 case 6: {
2209 assert(numElements == 16);
2210 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
2211 rocdlResultType, srcPtr);
2212 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2213 break;
2214 }
2215 case 8: {
2216 assert(numElements == 8);
2217 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
2218 rocdlResultType, srcPtr);
2219 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2220 break;
2221 }
2222 case 16: {
2223 assert(numElements == 4);
2224 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
2225 srcPtr);
2226 break;
2227 }
2228 default:
2229 return op.emitOpError("Unsupported element size for transpose load");
2230 }
2231 return success();
2232 }
2233};
2234
2235struct GlobalTransposeLoadOpLowering
2236 : public ConvertOpToLLVMPattern<GlobalTransposeLoadOp> {
2237 GlobalTransposeLoadOpLowering(const LLVMTypeConverter &converter,
2238 Chipset chipset)
2239 : ConvertOpToLLVMPattern<GlobalTransposeLoadOp>(converter),
2240 chipset(chipset) {}
2241
2242 Chipset chipset;
2243
2244 LogicalResult
2245 matchAndRewrite(GlobalTransposeLoadOp op,
2246 GlobalTransposeLoadOpAdaptor adaptor,
2247 ConversionPatternRewriter &rewriter) const override {
2248 if (chipset < kGfx1200)
2249 return op.emitOpError(
2250 "global_transpose_load is only supported on gfx1200+");
2251
2252 Location loc = op.getLoc();
2253 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2254 auto resultType = cast<VectorType>(op.getResult().getType());
2255
2256 Value srcPtr = getStridedElementPtr(
2257 rewriter, loc, srcMemRefType, adaptor.getSrc(), adaptor.getSrcIndices(),
2258 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
2259
2260 size_t numElements = resultType.getNumElements();
2261 size_t elementTypeSize =
2262 resultType.getElementType().getIntOrFloatBitWidth();
2263
2264 // ROCDL global transpose load intrinsics return vectors of i32 for
2265 // sub-16-bit elements, matching the LDS lowering convention.
2266 Type rocdlResultType =
2267 elementTypeSize < 16
2268 ? VectorType::get((numElements * elementTypeSize) / 32,
2269 rewriter.getIntegerType(32))
2270 : typeConverter->convertType(resultType);
2271 Type llvmResultType = typeConverter->convertType(resultType);
2272
2273 switch (elementTypeSize) {
2274 case 4: {
2275 assert(numElements == 16);
2276 if (chipset < kGfx1250)
2277 return op.emitOpError("4-bit global_transpose_load requires gfx1250+");
2278 auto rocdlOp = ROCDL::GlobalLoadTr4_B64::create(rewriter, loc,
2279 rocdlResultType, srcPtr);
2280 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2281 break;
2282 }
2283 case 6: {
2284 assert(numElements == 16);
2285 if (chipset < kGfx1250)
2286 return op.emitOpError("6-bit global_transpose_load requires gfx1250+");
2287 auto rocdlOp = ROCDL::GlobalLoadTr6_B96::create(rewriter, loc,
2288 rocdlResultType, srcPtr);
2289 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2290 break;
2291 }
2292 case 8: {
2293 assert(numElements == 8);
2294 auto rocdlOp = ROCDL::GlobalLoadTr8_B64::create(rewriter, loc,
2295 rocdlResultType, srcPtr);
2296 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2297 break;
2298 }
2299 case 16: {
2300 assert(numElements == 8);
2301 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadTr8_B128>(op, llvmResultType,
2302 srcPtr);
2303 break;
2304 }
2305 default:
2306 return op.emitOpError(
2307 "unsupported element size for global transpose load");
2308 }
2309 return success();
2310 }
2311};
2312
2313struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
2314 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2315 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
2316
2317 Chipset chipset;
2318
2319 LogicalResult
2320 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
2321 ConversionPatternRewriter &rewriter) const override {
2322 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
2323 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
2324
2325 Location loc = op.getLoc();
2326
2327 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2328 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2329
2330 // TODO: instead of only transfering one element per thread, we could
2331 // augment it to transfer multiple elements per thread by issuing multiple
2332 // `global_load_lds` instructions.
2333 Type transferType = op.getTransferType();
2334 int loadWidth = [&]() -> int {
2335 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
2336 return (transferVectorType.getNumElements() *
2337 transferVectorType.getElementTypeBitWidth()) /
2338 8;
2339 }
2340 return transferType.getIntOrFloatBitWidth() / 8;
2341 }();
2342
2343 // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
2344 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
2345 return op.emitOpError("chipset unsupported element size");
2346
2347 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
2348 return op.emitOpError("Gather to LDS instructions with 12-byte and "
2349 "16-byte load widths are only supported on gfx950");
2350
2351 Value srcPtr =
2352 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
2353 (adaptor.getSrcIndices()));
2354 Value dstPtr =
2355 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
2356 (adaptor.getDstIndices()));
2357
2358 if (op.getAsync()) {
2359 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
2360 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2361 /*offset=*/rewriter.getI32IntegerAttr(0),
2362 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
2363 ArrayAttr{});
2364 } else {
2365 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
2366 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2367 /*offset=*/rewriter.getI32IntegerAttr(0),
2368 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
2369 ArrayAttr{});
2370 }
2371
2372 return success();
2373 }
2374};
2375
2376struct GlobalLoadAsyncToLDSOpLowering
2377 : public ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp> {
2378 GlobalLoadAsyncToLDSOpLowering(const LLVMTypeConverter &converter,
2379 Chipset chipset)
2380 : ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp>(converter),
2381 chipset(chipset) {}
2382
2383 Chipset chipset;
2384
2385 LogicalResult
2386 matchAndRewrite(GlobalLoadAsyncToLDSOp op,
2387 GlobalLoadAsyncToLDSOpAdaptor adaptor,
2388 ConversionPatternRewriter &rewriter) const override {
2389 if (chipset < kGfx1250)
2390 return op.emitOpError(
2391 "global_load_async_to_lds is only supported on gfx1250+");
2392
2393 Location loc = op.getLoc();
2394 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2395 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2396
2397 Type transferType = op.getTransferType();
2398 int transferBits =
2399 isa<VectorType>(transferType)
2400 ? cast<VectorType>(transferType).getNumElements() *
2401 cast<VectorType>(transferType).getElementTypeBitWidth()
2402 : transferType.getIntOrFloatBitWidth();
2403
2404 Value srcPtr =
2405 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
2406 adaptor.getSrcIndices());
2407 Value dstPtr =
2408 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
2409 adaptor.getDstIndices());
2410
2411 if (op.getMask()) {
2412 Value mask = adaptor.getMask();
2413 int64_t nullptrVal =
2414 llvm::AMDGPU::getNullPointerValue(llvm::AMDGPUAS::LOCAL_ADDRESS);
2415 Value nullInt =
2416 createI32Constant(rewriter, loc, static_cast<int32_t>(nullptrVal));
2417 Value nullPtr =
2418 LLVM::IntToPtrOp::create(rewriter, loc, dstPtr.getType(), nullInt);
2419 dstPtr = LLVM::SelectOp::create(rewriter, loc, mask, dstPtr, nullPtr);
2420 }
2421
2422 auto offset = rewriter.getI32IntegerAttr(0);
2423 auto aux = rewriter.getI32IntegerAttr(0);
2424
2425 switch (transferBits) {
2426 case 8:
2427 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB8Op>(
2428 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2429 ArrayAttr{});
2430 break;
2431 case 32:
2432 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB32Op>(
2433 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2434 ArrayAttr{});
2435 break;
2436 case 64:
2437 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB64Op>(
2438 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2439 ArrayAttr{});
2440 break;
2441 case 128:
2442 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB128Op>(
2443 op, srcPtr, dstPtr, offset, aux, ArrayAttr{}, ArrayAttr{},
2444 ArrayAttr{});
2445 break;
2446 default:
2447 return op.emitOpError("unsupported transfer width");
2448 }
2449 return success();
2450 }
2451};
2452
2453namespace {
2454struct ExtPackedFp8OpLowering final
2455 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
2456 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2457 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
2458 chipset(chipset) {}
2459 Chipset chipset;
2460
2461 LogicalResult
2462 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2463 ConversionPatternRewriter &rewriter) const override;
2464};
2465
2466struct ScaledExtPackedMatrixOpLowering final
2467 : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
2468 ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
2469 Chipset chipset)
2470 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
2471 chipset(chipset) {}
2472 Chipset chipset;
2473
2474 LogicalResult
2475 matchAndRewrite(ScaledExtPackedMatrixOp op,
2476 ScaledExtPackedMatrixOpAdaptor adaptor,
2477 ConversionPatternRewriter &rewriter) const override;
2478};
2479
2480struct PackedTrunc2xFp8OpLowering final
2481 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
2482 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
2483 Chipset chipset)
2484 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
2485 chipset(chipset) {}
2486 Chipset chipset;
2487
2488 LogicalResult
2489 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2490 ConversionPatternRewriter &rewriter) const override;
2491};
2492
2493struct PackedStochRoundFp8OpLowering final
2494 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
2495 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
2496 Chipset chipset)
2497 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
2498 chipset(chipset) {}
2499 Chipset chipset;
2500
2501 LogicalResult
2502 matchAndRewrite(PackedStochRoundFp8Op op,
2503 PackedStochRoundFp8OpAdaptor adaptor,
2504 ConversionPatternRewriter &rewriter) const override;
2505};
2506
2507struct ScaledExtPackedOpLowering final
2508 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
2509 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
2510 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
2511 chipset(chipset) {}
2512 Chipset chipset;
2513
2514 LogicalResult
2515 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2516 ConversionPatternRewriter &rewriter) const override;
2517};
2518
2519struct PackedScaledTruncOpLowering final
2520 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
2521 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
2522 Chipset chipset)
2523 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
2524 chipset(chipset) {}
2525 Chipset chipset;
2526
2527 LogicalResult
2528 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2529 ConversionPatternRewriter &rewriter) const override;
2530};
2531
2532} // end namespace
2533
2534LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
2535 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2536 ConversionPatternRewriter &rewriter) const {
2537 Location loc = op.getLoc();
2538 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2539 return rewriter.notifyMatchFailure(
2540 loc, "Fp8 conversion instructions are not available on target "
2541 "architecture and their emulation is not implemented");
2542 Type v4i8 =
2543 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
2544 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2545 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
2546
2547 Value source = adaptor.getSource();
2548 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
2549 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
2550 Type sourceElemType = getElementTypeOrSelf(op.getSource());
2551 // Extend to a v4i8
2552 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
2553 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
2554 if (!sourceVecType) {
2555 longVec = LLVM::InsertElementOp::create(
2556 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2557 } else {
2558 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2559 Value idx = createI32Constant(rewriter, loc, i);
2560 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2561 longVec =
2562 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2563 }
2564 }
2565 source = longVec;
2566 }
2567 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2568 if (resultVecType) {
2569 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
2570 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
2571 op.getIndex());
2572 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
2573 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
2574 op.getIndex());
2575 }
2576 } else {
2577 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
2578 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
2579 op.getIndex());
2580 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
2581 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
2582 op.getIndex());
2583 }
2584 }
2585 return success();
2586}
2587
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
                    int32_t firstScaleByte) {
  // When amdgpu.scaled_ext_packed_matrix is lowered to the
  // rocdl.cvt.scale.pk*.f*.f* operations, four values (block size, source bit
  // width, scaleWaveHalf and firstScaleByte) are folded into the single
  // scaleSel attribute. This function performs that folding. scaleWaveHalf is
  // not a high-level attribute; it is derived from firstScaleLane.
  //
  // 4- and 6-bit sources:
  //   bit 0 = block size is 16
  //   bit 1 = firstScaleByte is 2
  //   bit 2 = scaleWaveHalf
  // 8-bit sources:
  //   bit 0    = block size is 16
  //   bits 2:1 = firstScaleByte
  //   bit 3    = scaleWaveHalf
  assert((blockSize == 16 || blockSize == 32) && "unexpected block size");
  assert((bitWidth == 4u || bitWidth == 6u || bitWidth == 8u) &&
         "unexpected source bit width");
  assert((scaleWaveHalf == 0 || scaleWaveHalf == 1) &&
         "scaleWaveHalf must be 0 or 1");

  const int32_t blockBit = (blockSize == 16) ? 1 : 0;

  if (bitWidth != 8) {
    assert(firstScaleByte >= 0 && firstScaleByte <= 2 &&
           "firstScaleByte out of range for 4/6-bit sources");
    const int32_t byteBit = (firstScaleByte == 2) ? 1 : 0;
    return (scaleWaveHalf << 2) | (byteBit << 1) | blockBit;
  }

  // For 8-bit sources firstScaleByte occupies two bits.
  assert(firstScaleByte >= 0 && firstScaleByte <= 3 &&
         "firstScaleByte out of range for 8-bit sources");
  const int32_t sel = (scaleWaveHalf << 3) | (firstScaleByte << 1) | blockBit;

  // Encodings that are not valid: 0b0011, 0b0101, 0b0111, 0b1000, 0b1001,
  // 0b1011, 0b1111 (one bit per encoding in this mask).
  [[maybe_unused]] constexpr uint32_t kInvalidSelMask =
      (1u << 0b0011) | (1u << 0b0101) | (1u << 0b0111) | (1u << 0b1000) |
      (1u << 0b1001) | (1u << 0b1011) | (1u << 0b1111);
  assert(((kInvalidSelMask >> sel) & 1u) == 0u && "invalid scaleSel encoding");
  return sel;
}
2622
2623static std::optional<StringRef>
2624scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2625 using fp4 = Float4E2M1FNType;
2626 using fp8 = Float8E4M3FNType;
2627 using bf8 = Float8E5M2Type;
2628 using fp6 = Float6E2M3FNType;
2629 using bf6 = Float6E3M2FNType;
2630 if (isa<fp4>(srcElemType)) {
2631 if (destElemType.isF16())
2632 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2633 if (destElemType.isBF16())
2634 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2635 if (destElemType.isF32())
2636 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2637 return std::nullopt;
2638 }
2639 if (isa<fp8>(srcElemType)) {
2640 if (destElemType.isF16())
2641 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2642 if (destElemType.isBF16())
2643 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2644 if (destElemType.isF32())
2645 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2646 return std::nullopt;
2647 }
2648 if (isa<bf8>(srcElemType)) {
2649 if (destElemType.isF16())
2650 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2651 if (destElemType.isBF16())
2652 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2653 if (destElemType.isF32())
2654 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2655 return std::nullopt;
2656 }
2657 if (isa<fp6>(srcElemType)) {
2658 if (destElemType.isF16())
2659 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2660 if (destElemType.isBF16())
2661 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2662 if (destElemType.isF32())
2663 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2664 return std::nullopt;
2665 }
2666 if (isa<bf6>(srcElemType)) {
2667 if (destElemType.isF16())
2668 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2669 if (destElemType.isBF16())
2670 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2671 if (destElemType.isF32())
2672 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2673 return std::nullopt;
2674 }
2675 llvm_unreachable("invalid combination of element types for packed conversion "
2676 "instructions");
2677}
2678
2679LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2680 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2681 ConversionPatternRewriter &rewriter) const {
2682 using fp4 = Float4E2M1FNType;
2683 using fp8 = Float8E4M3FNType;
2684 using bf8 = Float8E5M2Type;
2685 using fp6 = Float6E2M3FNType;
2686 using bf6 = Float6E3M2FNType;
2687 Location loc = op.getLoc();
2688 if (chipset != kGfx1250) {
2689 return rewriter.notifyMatchFailure(
2690 loc,
2691 "Scaled fp packed conversion instructions are not available on target "
2692 "architecture and their emulation is not implemented");
2693 }
2694 // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
2695 // is being selected.
2696 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2697 int32_t firstScaleByte = op.getFirstScaleByte();
2698 int32_t blockSize = op.getBlockSize();
2699 auto sourceType = cast<VectorType>(op.getSource().getType());
2700 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2701 unsigned bitWidth = srcElemType.getWidth();
2702
2703 auto targetType = cast<VectorType>(op.getResult().getType());
2704 auto destElemType = cast<FloatType>(targetType.getElementType());
2705
2706 IntegerType i32 = rewriter.getI32Type();
2707 Value source = adaptor.getSource();
2708 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2709 Type packedType = nullptr;
2710 if (isa<fp4>(srcElemType)) {
2711 packedType = i32;
2712 packedType = getTypeConverter()->convertType(packedType);
2713 } else if (isa<fp8, bf8>(srcElemType)) {
2714 packedType = VectorType::get(2, i32);
2715 packedType = getTypeConverter()->convertType(packedType);
2716 } else if (isa<fp6, bf6>(srcElemType)) {
2717 packedType = VectorType::get(3, i32);
2718 packedType = getTypeConverter()->convertType(packedType);
2719 } else {
2720 llvm_unreachable("invalid element type for packed scaled ext");
2721 }
2722
2723 if (!packedType || !llvmResultType) {
2724 return rewriter.notifyMatchFailure(op, "type conversion failed");
2725 }
2726
2727 std::optional<StringRef> maybeIntrinsic =
2728 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2729 if (!maybeIntrinsic.has_value())
2730 return op.emitOpError(
2731 "no intrinsic matching packed scaled conversion on the given chipset");
2732
2733 int32_t scaleSel =
2734 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2735 Value castedScale =
2736 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2737 Value castedSource =
2738 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2739
2740 OperationState loweredOp(loc, *maybeIntrinsic);
2741 loweredOp.addTypes({llvmResultType});
2742 loweredOp.addOperands({castedSource, castedScale});
2743
2744 SmallVector<NamedAttribute, 1> attrs;
2745 attrs.push_back(
2746 NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2747
2748 loweredOp.addAttributes(attrs);
2749 Operation *lowered = rewriter.create(loweredOp);
2750 rewriter.replaceOp(op, lowered);
2751
2752 return success();
2753}
2754
2755LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2756 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2757 ConversionPatternRewriter &rewriter) const {
2758 Location loc = op.getLoc();
2759 if (chipset != kGfx950)
2760 return rewriter.notifyMatchFailure(
2761 loc, "Scaled fp conversion instructions are not available on target "
2762 "architecture and their emulation is not implemented");
2763 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2764
2765 Value source = adaptor.getSource();
2766 Value scale = adaptor.getScale();
2767
2768 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2769 Type sourceElemType = sourceVecType.getElementType();
2770 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2771 Type destElemType = destVecType.getElementType();
2772
2773 VectorType packedVecType;
2774 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2775 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2776 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2777 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
2778 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2779 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2780 } else {
2781 llvm_unreachable("invalid element type for scaled ext");
2782 }
2783
2784 // Extend to a packedVectorType
2785 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2786 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2787 if (!sourceVecType) {
2788 longVec = LLVM::InsertElementOp::create(
2789 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
2790 } else {
2791 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2792 Value idx = createI32Constant(rewriter, loc, i);
2793 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2794 longVec =
2795 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2796 }
2797 }
2798 source = longVec;
2799 }
2800 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2801
2802 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
2803 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2804 op, destVecType, i32Source, scale, op.getIndex());
2805 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
2806 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2807 op, destVecType, i32Source, scale, op.getIndex());
2808 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
2809 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2810 op, destVecType, i32Source, scale, op.getIndex());
2811 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
2812 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2813 op, destVecType, i32Source, scale, op.getIndex());
2814 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
2815 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2816 op, destVecType, i32Source, scale, op.getIndex());
2817 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
2818 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2819 op, destVecType, i32Source, scale, op.getIndex());
2820 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
2821 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2822 op, destVecType, i32Source, scale, op.getIndex());
2823 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
2824 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2825 op, destVecType, i32Source, scale, op.getIndex());
2826 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
2827 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2828 op, destVecType, i32Source, scale, op.getIndex());
2829 else
2830 return failure();
2831
2832 return success();
2833}
2834
2835LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2836 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2837 ConversionPatternRewriter &rewriter) const {
2838 Location loc = op.getLoc();
2839 if (chipset != kGfx950)
2840 return rewriter.notifyMatchFailure(
2841 loc, "Scaled fp conversion instructions are not available on target "
2842 "architecture and their emulation is not implemented");
2843 Type v2i16 = getTypeConverter()->convertType(
2844 VectorType::get(2, rewriter.getI16Type()));
2845 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2846
2847 Type resultType = op.getResult().getType();
2848 Type resultElemType = getElementTypeOrSelf(resultType);
2849 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2850 Type sourceElemType = sourceVecType.getElementType();
2851
2852 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2853
2854 Value source = adaptor.getSource();
2855 Value scale = adaptor.getScale();
2856 Value existing = adaptor.getExisting();
2857 if (existing)
2858 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2859 else
2860 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
2861
2862 if (sourceVecType.getNumElements() < 2) {
2863 Value c0 = createI32Constant(rewriter, loc, 0);
2864 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2865 VectorType v2 = VectorType::get(2, sourceElemType);
2866 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2867 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
2868 }
2869
2870 Value sourceA, sourceB;
2871 if (sourceElemType.isF32()) {
2872 Value c0 = createI32Constant(rewriter, loc, 0);
2873 Value c1 = createI32Constant(rewriter, loc, 1);
2874 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2875 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
2876 }
2877
2878 Value result;
2879 if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
2880 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2881 existing, sourceA, sourceB,
2882 scale, op.getIndex());
2883 else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
2884 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2885 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2886 else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
2887 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2888 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2889 else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
2890 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2891 existing, sourceA, sourceB,
2892 scale, op.getIndex());
2893 else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
2894 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2895 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2896 else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
2897 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2898 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2899 else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
2900 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2901 existing, sourceA, sourceB,
2902 scale, op.getIndex());
2903 else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
2904 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2905 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2906 else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
2907 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
2908 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2909 else
2910 return failure();
2911
2912 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2913 op, getTypeConverter()->convertType(resultType), result);
2914 return success();
2915}
2916
2917LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2918 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2919 ConversionPatternRewriter &rewriter) const {
2920 Location loc = op.getLoc();
2921 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2922 return rewriter.notifyMatchFailure(
2923 loc, "Fp8 conversion instructions are not available on target "
2924 "architecture and their emulation is not implemented");
2925 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2926
2927 Type resultType = op.getResult().getType();
2928 Type resultElemType = getElementTypeOrSelf(resultType);
2929
2930 Value sourceA = adaptor.getSourceA();
2931 Value sourceB = adaptor.getSourceB();
2932 if (!sourceB)
2933 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
2934 Value existing = adaptor.getExisting();
2935 if (existing)
2936 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2937 else
2938 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2939
2940 Value result;
2941 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2942 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2943 existing, op.getWordIndex());
2944 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2945 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2946 existing, op.getWordIndex());
2947
2948 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2949 op, getTypeConverter()->convertType(resultType), result);
2950 return success();
2951}
2952
2953LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2954 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2955 ConversionPatternRewriter &rewriter) const {
2956 Location loc = op.getLoc();
2957 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
2958 return rewriter.notifyMatchFailure(
2959 loc, "Fp8 conversion instructions are not available on target "
2960 "architecture and their emulation is not implemented");
2961 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2962
2963 Type resultType = op.getResult().getType();
2964 Type resultElemType = getElementTypeOrSelf(resultType);
2965
2966 Value source = adaptor.getSource();
2967 Value stoch = adaptor.getStochiasticParam();
2968 Value existing = adaptor.getExisting();
2969 if (existing)
2970 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2971 else
2972 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2973
2974 Value result;
2975 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
2976 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2977 existing, op.getStoreIndex());
2978 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
2979 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2980 existing, op.getStoreIndex());
2981
2982 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2983 op, getTypeConverter()->convertType(resultType), result);
2984 return success();
2985}
2986
2987// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
2988// operation into the corresponding ROCDL instructions.
2989struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
2990 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
2991 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2992 Chipset chipset;
2993
2994 LogicalResult
2995 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2996 ConversionPatternRewriter &rewriter) const override {
2997
2998 // Convert the source operand to the corresponding LLVM type
2999 Location loc = DppOp.getLoc();
3000 Value src = adaptor.getSrc();
3001 Value old = adaptor.getOld();
3002 Type srcType = src.getType();
3003 Type oldType = old.getType();
3004 Type llvmType = nullptr;
3005 if (srcType.getIntOrFloatBitWidth() < 32) {
3006 llvmType = rewriter.getI32Type();
3007 } else if (isa<FloatType>(srcType)) {
3008 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
3009 ? rewriter.getF32Type()
3010 : rewriter.getF64Type();
3011 } else if (isa<IntegerType>(srcType)) {
3012 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
3013 ? rewriter.getI32Type()
3014 : rewriter.getI64Type();
3015 }
3016 auto llvmSrcIntType = typeConverter->convertType(
3017 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
3018
3019 // If the source type is less of 32, use bitcast to convert it to i32.
3020 auto convertOperand = [&](Value operand, Type operandType) {
3021 if (operandType.getIntOrFloatBitWidth() <= 16) {
3022 if (llvm::isa<FloatType>(operandType)) {
3023 operand =
3024 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
3025 }
3026 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
3027 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
3028 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
3029 operand =
3030 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
3031 createI32Constant(rewriter, loc, 0));
3032 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
3033 }
3034 return operand;
3035 };
3036
3037 src = convertOperand(src, srcType);
3038 old = convertOperand(old, oldType);
3039
3040 // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
3041 enum DppCtrl : unsigned {
3042 ROW_SHL0 = 0x100,
3043 ROW_SHR0 = 0x110,
3044 ROW_ROR0 = 0x120,
3045 WAVE_SHL1 = 0x130,
3046 WAVE_ROL1 = 0x134,
3047 WAVE_SHR1 = 0x138,
3048 WAVE_ROR1 = 0x13C,
3049 ROW_MIRROR = 0x140,
3050 ROW_HALF_MIRROR = 0x141,
3051 BCAST15 = 0x142,
3052 BCAST31 = 0x143,
3053 };
3054
3055 auto kind = DppOp.getKind();
3056 auto permArgument = DppOp.getPermArgument();
3057 uint32_t DppCtrl = 0;
3058
3059 switch (kind) {
3060
3061 case DPPPerm::quad_perm: {
3062 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
3063 int32_t i = 0;
3064 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
3065 uint32_t num = elem.getInt();
3066 DppCtrl |= num << (i * 2);
3067 i++;
3068 }
3069 break;
3070 }
3071 case DPPPerm::row_shl: {
3072 auto intAttr = cast<IntegerAttr>(*permArgument);
3073 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
3074 break;
3075 }
3076 case DPPPerm::row_shr: {
3077 auto intAttr = cast<IntegerAttr>(*permArgument);
3078 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
3079 break;
3080 }
3081 case DPPPerm::row_ror: {
3082 auto intAttr = cast<IntegerAttr>(*permArgument);
3083 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
3084 break;
3085 }
3086 case DPPPerm::wave_shl:
3087 DppCtrl = DppCtrl::WAVE_SHL1;
3088 break;
3089 case DPPPerm::wave_shr:
3090 DppCtrl = DppCtrl::WAVE_SHR1;
3091 break;
3092 case DPPPerm::wave_rol:
3093 DppCtrl = DppCtrl::WAVE_ROL1;
3094 break;
3095 case DPPPerm::wave_ror:
3096 DppCtrl = DppCtrl::WAVE_ROR1;
3097 break;
3098 case DPPPerm::row_mirror:
3099 DppCtrl = DppCtrl::ROW_MIRROR;
3100 break;
3101 case DPPPerm::row_half_mirror:
3102 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
3103 break;
3104 case DPPPerm::row_bcast_15:
3105 DppCtrl = DppCtrl::BCAST15;
3106 break;
3107 case DPPPerm::row_bcast_31:
3108 DppCtrl = DppCtrl::BCAST31;
3109 break;
3110 }
3111
3112 // Check for row_mask, bank_mask, bound_ctrl if they exist and create
3113 // constants
3114 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
3115 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
3116 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
3117
3118 // create a ROCDL_DPPMovOp instruction with the appropriate attributes
3119 auto dppMovOp =
3120 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
3121 rowMask, bankMask, boundCtrl);
3122
3123 Value result = dppMovOp.getRes();
3124 if (srcType.getIntOrFloatBitWidth() < 32) {
3125 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
3126 if (!llvm::isa<IntegerType>(srcType)) {
3127 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
3128 }
3129 }
3130
3131 // We are replacing the AMDGPU_DPPOp instruction with the new
3132 // ROCDL_DPPMovOp instruction
3133 rewriter.replaceOp(DppOp, ValueRange(result));
3134 return success();
3135 }
3136};
3137
3138struct AMDGPUSwizzleBitModeLowering
3139 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
3141
3142 LogicalResult
3143 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
3144 ConversionPatternRewriter &rewriter) const override {
3145 Location loc = op.getLoc();
3146 Type i32 = rewriter.getI32Type();
3147 Value src = adaptor.getSrc();
3148 SmallVector<Value> decomposed;
3149 if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3150 return rewriter.notifyMatchFailure(op,
3151 "failed to decompose value to i32");
3152 unsigned andMask = op.getAndMask();
3153 unsigned orMask = op.getOrMask();
3154 unsigned xorMask = op.getXorMask();
3155
3156 // bit 15 is 0 for the BitMode swizzle.
3157 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
3158 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
3159 Value maskValue = createI32Constant(rewriter, loc, mask);
3160 SmallVector<Value> swizzled;
3161 for (Value v : decomposed) {
3162 Value res =
3163 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
3164 swizzled.emplace_back(res);
3165 }
3166
3167 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
3168 rewriter.replaceOp(op, result);
3169 return success();
3170 }
3171};
3172
3173struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
3175
3176 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
3177 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
3178 Chipset chipset;
3179
3180 LogicalResult
3181 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
3182 ConversionPatternRewriter &rewriter) const override {
3183 if (chipset < kGfx950)
3184 return op->emitOpError("permlane_swap is only supported on gfx950+");
3185
3186 Location loc = op.getLoc();
3187 Type i32 = rewriter.getI32Type();
3188 Value src = adaptor.getSrc();
3189 unsigned rowLength = op.getRowLength();
3190 bool fi = op.getFetchInactive();
3191 bool boundctrl = op.getBoundCtrl();
3192
3193 SmallVector<Value> decomposed;
3194 if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3195 return rewriter.notifyMatchFailure(op,
3196 "failed to decompose value to i32");
3197
3198 SmallVector<Value> permuted;
3199 for (Value v : decomposed) {
3200 Value res;
3201 Type i32pair = LLVM::LLVMStructType::getLiteral(
3202 rewriter.getContext(), {v.getType(), v.getType()});
3203
3204 if (rowLength == 16)
3205 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
3206 boundctrl);
3207 else if (rowLength == 32)
3208 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
3209 boundctrl);
3210 else
3211 llvm_unreachable("unsupported row length");
3212
3213 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
3214 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
3215
3216 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
3217 LLVM::ICmpPredicate::eq, vdst0, v);
3218
3219 // Per `permlane(16|32)` semantics: if the first extracted element equals
3220 // 'v', the result is the second element; otherwise it is the first.
3221 Value vdstNew =
3222 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
3223 permuted.emplace_back(vdstNew);
3224 }
3225
3226 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
3227 rewriter.replaceOp(op, result);
3228 return success();
3229 }
3230};
3231
3232//===----------------------------------------------------------------------===//
3233// In-LDS Barrier Operations
3234//===----------------------------------------------------------------------===//
3235
// Layout of the 64-bit ds_barrier_state word:
//   bits [63:32]  init count     (32 bits)
//   bits [31:29]  phase          (3 bits)
//   bits [28:0]   pending count  (29 bits)
constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
constexpr int32_t kDsBarrierPendingCountMask =
    (int32_t{1} << kDsBarrierPendingCountBitWidth) - 1;
// The phase field sits directly above the pending count.
constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
constexpr int32_t kDsBarrierInitCountPos = 32;

static_assert(kDsBarrierPhasePos + 3 == kDsBarrierInitCountPos,
              "3-bit phase field must end where the init count begins");
static_assert(kDsBarrierPendingCountMask == 0x1FFFFFFF,
              "pending count mask must cover bits [28:0]");
3245
3246struct DsBarrierInitOpLowering
3247 : public ConvertOpToLLVMPattern<DsBarrierInitOp> {
3248 Chipset chipset;
3249
3250 DsBarrierInitOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
3251 : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
3252
3253 LogicalResult
3254 matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
3255 ConversionPatternRewriter &rewriter) const override {
3256 if (chipset < kGfx1250)
3257 return op->emitOpError("only supported on gfx1250+");
3258
3259 Location loc = op.getLoc();
3260 Type i64 = rewriter.getI64Type();
3261
3262 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3263 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3264 adaptor.getBase(), adaptor.getIndices());
3265
3266 // Note: We give participants as the number of arrivals that have to occur
3267 // before the phase changes. Hardware changes the phase when updating the
3268 // pending count would underflow, so we subtract 1 to get the behavior we're
3269 // looking for.
3270 Value initCount =
3271 LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
3272 createI32Constant(rewriter, loc, 1));
3273
3274 // Just a bit of paranoia, but this also allows for configurable width if
3275 // that becomes a thing.
3276 Value countMask =
3277 createI32Constant(rewriter, loc, kDsBarrierPendingCountMask);
3278 Value maskedCount32 =
3279 LLVM::AndOp::create(rewriter, loc, initCount, countMask);
3280 Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);
3281
3282 Value initCountShifted = LLVM::ShlOp::create(
3283 rewriter, loc, maskedCount,
3284 createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
3285 Value barrierState =
3286 LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);
3287
3288 LLVM::StoreOp::create(
3289 rewriter, loc, barrierState, ptr, /*alignment=*/8, /*isVolatile=*/false,
3290 /*isNonTemporal=*/false,
3291 /*isInvariantGroup=*/false, LLVM::AtomicOrdering::release,
3292 /*syncscope=*/"workgroup");
3293
3294 rewriter.eraseOp(op);
3295 return success();
3296 }
3297};
3298
3299struct DsBarrierPollStateOpLowering
3300 : public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
3301 Chipset chipset;
3302
3303 DsBarrierPollStateOpLowering(const LLVMTypeConverter &converter,
3304 Chipset chipset)
3305 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
3306 chipset(chipset) {}
3307
3308 LogicalResult
3309 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
3310 ConversionPatternRewriter &rewriter) const override {
3311 if (chipset < kGfx1250)
3312 return op->emitOpError("only supported on gfx1250+");
3313
3314 Location loc = op.getLoc();
3315 Type i64 = rewriter.getI64Type();
3316
3317 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3318 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3319 adaptor.getBase(), adaptor.getIndices());
3320
3321 // Atomic load with workgroup scope and acquire ordering should be what
3322 // we're looking for.
3323 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
3324 op, i64, ptr, /*alignment=*/8, /*volatile_=*/false,
3325 /*nontemporal=*/false, /*invariant=*/false,
3326 /*invariantGroup=*/false, LLVM::AtomicOrdering::acquire,
3327 /*syncscope=*/"workgroup");
3328 return success();
3329 }
3330};
3331
3332struct DsAsyncBarrierArriveOpLowering
3333 : public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
3334 Chipset chipset;
3335
3336 DsAsyncBarrierArriveOpLowering(const LLVMTypeConverter &converter,
3337 Chipset chipset)
3338 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
3339 chipset(chipset) {}
3340
3341 LogicalResult
3342 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
3343 ConversionPatternRewriter &rewriter) const override {
3344 if (chipset < kGfx1250)
3345 return op->emitOpError("only supported on gfx1250+");
3346
3347 Location loc = op.getLoc();
3348
3349 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3350 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3351 adaptor.getBase(), adaptor.getIndices());
3352
3353 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
3354 op, ptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
3355 /*tbaa=*/nullptr);
3356 return success();
3357 }
3358};
3359
3360struct DsBarrierArriveOpLowering
3361 : public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
3362 Chipset chipset;
3363
3364 DsBarrierArriveOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
3365 : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
3366 }
3367
3368 LogicalResult
3369 matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
3370 ConversionPatternRewriter &rewriter) const override {
3371 if (chipset < kGfx1250)
3372 return op->emitOpError("only supported on gfx1250+");
3373
3374 Location loc = op.getLoc();
3375 Type i64 = rewriter.getI64Type();
3376
3377 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3378 Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
3379 adaptor.getBase(), adaptor.getIndices());
3380
3381 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
3382 op, i64, ptr, adaptor.getCount(), /*alias_scopes=*/nullptr,
3383 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3384 return success();
3385 }
3386};
3387
3388struct DsBarrierStatePhaseOpLowering
3389 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
3391
3392 LogicalResult
3393 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
3394 ConversionPatternRewriter &rewriter) const override {
3395 Location loc = op.getLoc();
3396 Type i32 = rewriter.getI32Type();
3397
3398 Value state = adaptor.getState();
3399
3400 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3401 Value phase = LLVM::LShrOp::create(
3402 rewriter, loc, noInitCount,
3403 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3404
3405 rewriter.replaceOp(op, phase);
3406 return success();
3407 }
3408};
3409
3410struct DsBarrierStatePendingCountOpLowering
3411 : public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
3413
3414 LogicalResult
3415 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
3416 ConversionPatternRewriter &rewriter) const override {
3417 Location loc = op.getLoc();
3418 Type i32 = rewriter.getI32Type();
3419
3420 Value state = adaptor.getState();
3421
3422 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3423 Value pendingCount = LLVM::AndOp::create(
3424 rewriter, loc, noInitCount,
3425 createI32Constant(rewriter, loc,
3426 static_cast<uint32_t>(kDsBarrierPendingCountMask)));
3427
3428 rewriter.replaceOp(op, pendingCount);
3429 return success();
3430 }
3431};
3432
3433struct DsBarrierStateInitCountOpLowering
3434 : public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
3436
3437 LogicalResult
3438 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
3439 ConversionPatternRewriter &rewriter) const override {
3440 Location loc = op.getLoc();
3441 Type i32 = rewriter.getI32Type();
3442
3443 Value state = adaptor.getState();
3444
3445 Value initCountI64 = LLVM::LShrOp::create(
3446 rewriter, loc, state,
3447 createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
3448 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
3449
3450 rewriter.replaceOp(op, initCount);
3451 return success();
3452 }
3453};
3454
3455struct DsBarrierStatePhaseParityLowering
3456 : public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
3458
3459 LogicalResult
3460 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
3461 ConversionPatternRewriter &rewriter) const override {
3462 Location loc = op.getLoc();
3463 Type i1 = rewriter.getI1Type();
3464
3465 Value state = adaptor.getState();
3466
3467 Value noInitCount =
3468 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
3469 Value phase = LLVM::LShrOp::create(
3470 rewriter, loc, noInitCount,
3471 createI32Constant(rewriter, loc, kDsBarrierPhasePos));
3472 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
3473
3474 rewriter.replaceOp(op, parity);
3475 return success();
3476 }
3477};
3478
3479//===----------------------------------------------------------------------===//
3480// Tensor Data Mover (TDM)
3481//===----------------------------------------------------------------------===//
3482
3483static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
3484 Value accumulator, Value value, int64_t shift) {
3485 shift = shift % 32;
3486 Value shiftAmount;
3487 if (shift != 0) {
3488 shiftAmount = createI32Constant(rewriter, loc, shift % 32);
3489 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
3490 }
3491
3492 if (matchPattern(accumulator, mlir::m_Zero()))
3493 return value;
3494
3495 constexpr bool isDisjoint = true;
3496 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
3497}
3498
3499template <typename BaseOp>
3500struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
3501 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3502 using Adaptor = typename ConvertOpToLLVMPattern<BaseOp>::OpAdaptor;
3503
3504 AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
3505 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3506 Chipset chipset;
3507
3508 LogicalResult
3509 matchAndRewrite(BaseOp op, Adaptor adaptor,
3510 ConversionPatternRewriter &rewriter) const override {
3511 if (chipset < kGfx1250)
3512 return op->emitOpError("make_dma_base is only supported on gfx1250");
3513
3514 Location loc = op.getLoc();
3515
3516 constexpr int32_t constlen = 4;
3517 Value consts[constlen];
3518 for (int64_t i = 0; i < constlen; ++i)
3519 consts[i] = createI32Constant(rewriter, loc, i);
3520
3521 constexpr int32_t sgprslen = constlen;
3522 Value sgprs[sgprslen];
3523 for (int64_t i = 0; i < sgprslen; ++i) {
3524 sgprs[i] = consts[0];
3525 }
3526
3527 sgprs[0] = consts[1];
3528
3529 if constexpr (BaseOp::isGather()) {
3530 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3531
3532 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3533 Type indexType = type.getIndexType();
3534 unsigned indexSize = indexType.getIntOrFloatBitWidth();
3535 assert(llvm::is_contained({16u, 32u}, indexSize) &&
3536 "expected index_size to be 16 or 32");
3537 unsigned idx = (indexSize / 16) - 1;
3538
3539 if (idx)
3540 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
3541 }
3542
3543 ValueRange ldsIndices = adaptor.getLdsIndices();
3544 Value lds = adaptor.getLds();
3545 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3546
3548 rewriter, loc, ldsMemRefType, lds, ldsIndices);
3549
3550 ValueRange globalIndices = adaptor.getGlobalIndices();
3551 Value global = adaptor.getGlobal();
3552 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3553
3555 rewriter, loc, globalMemRefType, global, globalIndices);
3556
3557 Type i32 = rewriter.getI32Type();
3558 Type i64 = rewriter.getI64Type();
3559
3560 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3561 Value castForGlobalAddr =
3562 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
3563
3564 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
3565
3566 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3567 createI64Constant(rewriter, loc, 32));
3568
3569 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
3570
3571 Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
3572 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3573
3574 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
3575
3576 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3577 assert(v4i32 && "expected type conversion to succeed");
3578 Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3579
3580 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3581 result =
3582 LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);
3583
3584 rewriter.replaceOp(op, result);
3585 return success();
3586 }
3587};
3588
3589template <typename DescriptorOp>
3590struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
3591 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
3592 using OpAdaptor = typename ConvertOpToLLVMPattern<DescriptorOp>::OpAdaptor;
3593
3594 AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
3595 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
3596 Chipset chipset;
3597
3598 Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
3599
3600 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
3601 ConversionPatternRewriter &rewriter, Location loc,
3602 Value sgpr0) const {
3603 Value mask = op.getWorkgroupMask();
3604 if (!mask)
3605 return sgpr0;
3606
3607 Type i16 = rewriter.getI16Type();
3608 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
3609 Type i32 = rewriter.getI32Type();
3610 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
3611 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
3612 }
3613
3614 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
3615 ConversionPatternRewriter &rewriter, Location loc,
3616 Value sgpr0, ArrayRef<Value> consts) const {
3617 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
3618 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3619 "expected type width to be 8, 16, 32, or 64.");
3620 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3621 Value size = consts[idx];
3622 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
3623 }
3624
3625 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3626 ConversionPatternRewriter &rewriter, Location loc,
3627 Value sgpr0, ArrayRef<Value> consts) const {
3628 if (!adaptor.getAtomicBarrierAddress())
3629 return sgpr0;
3630
3631 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
3632 }
3633
3634 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3635 ConversionPatternRewriter &rewriter, Location loc,
3636 Value sgpr0, ArrayRef<Value> consts) const {
3637 if (!adaptor.getGlobalIncrement())
3638 return sgpr0;
3639
3640 // Value is ignored when in gather mode.
3641 // TODO: emit error earlier?
3642 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
3643 }
3644
3645 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3646 ConversionPatternRewriter &rewriter, Location loc,
3647 Value sgpr0, ArrayRef<Value> consts) const {
3648 if (!op.getPadAmount())
3649 return sgpr0;
3650
3651 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
3652 }
3653
3654 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3655 ConversionPatternRewriter &rewriter, Location loc,
3656 Value sgpr0, ArrayRef<Value> consts) const {
3657 if (!op.getWorkgroupMask())
3658 return sgpr0;
3659
3660 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
3661 }
3662
3663 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3664 ConversionPatternRewriter &rewriter, Location loc,
3665 Value sgpr0, ArrayRef<Value> consts) const {
3666 if (!op.getPadAmount())
3667 return sgpr0;
3668
3669 // pre-condition: padInterval can be a power of two between 2 and 256.
3670 // TODO: Validation if the value breaks the pre-condition.
3671 // If the pre-condition fails, there is a possibility of
3672 // affecting the higher bits. In a following PR implement
3673 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3674 // checked at runtime.
3675 IntegerType i32 = rewriter.getI32Type();
3676 Value padInterval = adaptor.getPadInterval();
3677 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3678 padInterval, false);
3679 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3680 // post-condition: padInterval can be a value between 0 and 7.
3681 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
3682 }
3683
3684 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3685 ConversionPatternRewriter &rewriter, Location loc,
3686 Value sgpr0, ArrayRef<Value> consts) const {
3687 if (!op.getPadAmount())
3688 return sgpr0;
3689
3690 // pre-condition: padAmount is a value between 1-128.
3691 // TODO: Validation if the value breaks the pre-condition.
3692 // If the pre-condition fails, there is a possibility of
3693 // affecting the higher bits. In a following PR implement
3694 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3695 // checked at runtime.
3696 Value padAmount = adaptor.getPadAmount();
3697 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3698 // post-condition: padAmount is a value between 0-127.
3699 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
3700 }
3701
3702 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
3703 ConversionPatternRewriter &rewriter,
3704 Location loc, Value sgpr1,
3705 ArrayRef<Value> consts) const {
3706 if (!adaptor.getAtomicBarrierAddress())
3707 return sgpr1;
3708
3709 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
3710 auto barrierAddressTy =
3711 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
3712 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
3713 atomicBarrierAddress = ConvertToLLVMPattern::getStridedElementPtr(
3714 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
3715 atomicBarrierIndices);
3716 IntegerType i32 = rewriter.getI32Type();
3717 // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
3718 // that the 3 LSBs are zero.
3719 // TODO: Validation if the value breaks the pre-condition.
3720 // In a following PR implement RuntimeVerifiableOpInterface
3721 // that instruments conditions that need to be checked at runtime.
3722 atomicBarrierAddress =
3723 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
3724 atomicBarrierAddress =
3725 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
3726 Value mask = createI32Constant(rewriter, loc, 0xFFFF);
3727 atomicBarrierAddress =
3728 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
3729 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
3730 }
3731
3732 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3733 ConversionPatternRewriter &rewriter,
3734 Location loc, Value sgpr1, Value sgpr2,
3735 ArrayRef<Value> consts, uint64_t dimX,
3736 uint32_t offset) const {
3737 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3738 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3739 SmallVector<OpFoldResult> mixedGlobalSizes =
3740 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
3741 if (mixedGlobalSizes.size() <= dimX)
3742 return {sgpr1, sgpr2};
3743
3744 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3745 // pre-condition: tensorDimX is less than 2^32-1
3746 // TODO: Validation if the value breaks the pre-condition.
3747 // In a following PR implement RuntimeVerifiableOpInterface that instruments
3748 // conditions that need to be checked at runtime. This could also be fixed
3749 // by saying that mixedGlobalSizes is a DynamicI32List.
3750 Value tensorDimX;
3751 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3752 tensorDimX =
3753 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3754 } else {
3755 IntegerType i32 = rewriter.getI32Type();
3756 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3757 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3758 }
3759
3760 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
3761
3762 Value c16 = createI32Constant(rewriter, loc, 16);
3763 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
3764 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
3765 return {sgpr1, sgpr2};
3766 }
3767
3768 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3769 ConversionPatternRewriter &rewriter,
3770 Location loc, Value sgpr1, Value sgpr2,
3771 ArrayRef<Value> consts) const {
3772 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
3773 48);
3774 }
3775
3776 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3777 ConversionPatternRewriter &rewriter,
3778 Location loc, Value sgpr2, Value sgpr3,
3779 ArrayRef<Value> consts) const {
3780 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
3781 80);
3782 }
3783
3784 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
3785 ConversionPatternRewriter &rewriter, Location loc,
3786 Value sgpr, ArrayRef<Value> consts, size_t dimX,
3787 int64_t offset) const {
3788 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
3789 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
3790 SmallVector<OpFoldResult> mixedSharedSizes =
3791 getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
3792 if (mixedSharedSizes.size() <= dimX)
3793 return sgpr;
3794
3795 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
3796 // pre-condition: tileDimX is less than 2^16-1
3797 // TODO: Validation if the value breaks the pre-condition.
3798 // If the pre-condition fails, there is a possibility of
3799 // affecting the higher bits. In a following PR implement
3800 // RuntimeVerifiableOpInterface that instruments conditions that need to be
3801 // checked at runtime. This could also be fixed by saying that
3802 // mixedSharedSizes is a DynamicI16List.
3803 Value tileDimX;
3804 if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
3805 tileDimX =
3806 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3807 } else {
3808 IntegerType i32 = rewriter.getI32Type();
3809 tileDimX = cast<Value>(tileDimXOpFoldResult);
3810 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
3811 }
3812
3813 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
3814 }
3815
3816 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3817 ConversionPatternRewriter &rewriter, Location loc,
3818 Value sgpr3, ArrayRef<Value> consts) const {
3819 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
3820 }
3821
3822 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3823 ConversionPatternRewriter &rewriter, Location loc,
3824 Value sgpr4, ArrayRef<Value> consts) const {
3825 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
3826 }
3827
3828 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3829 ConversionPatternRewriter &rewriter, Location loc,
3830 Value sgpr4, ArrayRef<Value> consts) const {
3831 auto type = cast<VectorType>(op.getIndices().getType());
3832 ArrayRef<int64_t> shape = type.getShape();
3833 assert(shape.size() == 1 && "expected shape to be of rank 1.");
3834 unsigned length = shape.back();
3835 assert(0 < length && length <= 16 && "expected length to be at most 16.");
3836 Value value = createI32Constant(rewriter, loc, length);
3837 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
3838 }
3839
3840 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3841 ConversionPatternRewriter &rewriter,
3842 Location loc, Value sgpr4,
3843 ArrayRef<Value> consts) const {
3844 if constexpr (DescriptorOp::isGather())
3845 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3846 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
3847 }
3848
3849 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3850 ConversionPatternRewriter &rewriter, Location loc,
3851 Value sgpr4, ArrayRef<Value> consts) const {
3852 // Value is ignored when in gather mode.
3853 if constexpr (DescriptorOp::isGather())
3854 return sgpr4;
3855 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
3856 }
3857
3858 std::pair<Value, Value>
3859 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
3860 ConversionPatternRewriter &rewriter, Location loc,
3861 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
3862 size_t dimX, int64_t offset) const {
3863 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
3864 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
3865 SmallVector<OpFoldResult> mixedGlobalStrides =
3866 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
3867
3868 if (mixedGlobalStrides.size() <= (dimX + 1))
3869 return {sgprY, sgprZ};
3870
3871 OpFoldResult tensorDimXStrideOpFoldResult =
3872 *(mixedGlobalStrides.rbegin() + dimX + 1);
3873 // pre-condition: tensorDimXStride is less than 2^48-1
3874 // TODO: Validation if the value breaks the pre-condition.
3875 // In a following PR implement RuntimeVerifiableOpInterface that instruments
3876 // conditions that need to be checked at runtime.
3877 Value tensorDimXStride;
3878 if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
3879 tensorDimXStride =
3880 createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3881 else
3882 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
3883
3884 constexpr int64_t first48bits = (1ll << 48) - 1;
3885 Value mask = createI64Constant(rewriter, loc, first48bits);
3886 tensorDimXStride =
3887 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
3888 IntegerType i32 = rewriter.getI32Type();
3889 Value tensorDimXStrideLow =
3890 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
3891 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
3892
3893 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
3894 Value shiftVal = createI64Constant(rewriter, loc, shift);
3895 Value tensorDimXStrideHigh =
3896 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
3897 tensorDimXStrideHigh =
3898 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3899 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3900 offset + shift);
3901 return {sgprY, sgprZ};
3902 }
3903
3904 std::pair<Value, Value>
3905 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3906 ConversionPatternRewriter &rewriter, Location loc,
3907 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3908 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3909 0, 160);
3910 }
3911
3912 std::pair<Value, Value>
3913 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3914 ConversionPatternRewriter &rewriter, Location loc,
3915 Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
3916 // Value is ignored when in gather mode.
3917 if constexpr (DescriptorOp::isGather())
3918 return {sgpr5, sgpr6};
3919 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3920 1, 208);
3921 }
3922
3923 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3924 ConversionPatternRewriter &rewriter, Location loc,
3925 ArrayRef<Value> consts) const {
3926 Value sgprs[8];
3927 for (int64_t i = 0; i < 8; ++i) {
3928 sgprs[i] = consts[0];
3929 }
3930
3931 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3932 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3933 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3934 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3935 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3936 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3937 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3938 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
3939
3940 sgprs[1] =
3941 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
3942 std::tie(sgprs[1], sgprs[2]) =
3943 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3944 std::tie(sgprs[2], sgprs[3]) =
3945 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3946
3947 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3948 sgprs[4] =
3949 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3950 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3951 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3952 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3953 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3954 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
3955
3956 IntegerType i32 = rewriter.getI32Type();
3957 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3958 assert(v8i32 && "expected type conversion to succeed");
3959 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
3960
3961 for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3962 dgroup1 =
3963 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
3964 }
3965
3966 return dgroup1;
3967 }
3968
3969 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3970 ConversionPatternRewriter &rewriter, Location loc,
3971 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3972 int64_t offset) const {
3973 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3974 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3975 SmallVector<OpFoldResult> mixedGlobalSizes =
3976 getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
3977 if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
3978 return sgpr0;
3979
3980 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3981 Value tensorDimX;
3982 if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3983 tensorDimX =
3984 createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
3985 } else {
3986 IntegerType i32 = rewriter.getI32Type();
3987 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3988 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3989 }
3990
3991 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3992 }
3993
3994 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3995 ConversionPatternRewriter &rewriter, Location loc,
3996 Value sgpr0, ArrayRef<Value> consts) const {
3997 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3998 }
3999
4000 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
4001 Location loc, Value accumulator,
4002 Value value, int64_t shift) const {
4003
4004 IntegerType i32 = rewriter.getI32Type();
4005 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
4006 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
4007 }
4008
4009 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4010 ConversionPatternRewriter &rewriter, Location loc,
4011 Value sgpr1, ArrayRef<Value> consts,
4012 int64_t offset) const {
4013 Value ldsAddrIncrement = adaptor.getLdsIncrement();
4014 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
4015 }
4016
4017 std::pair<Value, Value>
4018 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4019 ConversionPatternRewriter &rewriter, Location loc,
4020 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
4021 int64_t offset) const {
4022 Value globalAddrIncrement = adaptor.getGlobalIncrement();
4023 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
4024 globalAddrIncrement, offset);
4025 Value shift = createI64Constant(rewriter, loc, 32);
4026 globalAddrIncrement =
4027 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
4028 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
4029 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
4030 globalAddrIncrement, offset + 32);
4031 Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
4032 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
4033 return {sgpr2, sgpr3};
4034 }
4035
4036 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4037 ConversionPatternRewriter &rewriter,
4038 Location loc, Value sgpr1,
4039 ArrayRef<Value> consts) const {
4040 Value ldsIncrement = op.getLdsIncrement();
4041 constexpr int64_t dim = 3;
4042 constexpr int64_t offset = 32;
4043 if (!ldsIncrement)
4044 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
4045 offset);
4046 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
4047 offset);
4048 }
4049
4050 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
4051 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
4052 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
4053 Value globalIncrement = op.getGlobalIncrement();
4054 constexpr int32_t dim = 2;
4055 constexpr int32_t offset = 64;
4056 if (!globalIncrement)
4057 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
4058 consts, dim, offset);
4059 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
4060 consts, offset);
4061 }
4062
4063 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
4064 ConversionPatternRewriter &rewriter, Location loc,
4065 Value sgpr3, ArrayRef<Value> consts,
4066 int32_t offset) const {
4067 Value iterationCount = adaptor.getIterationCount();
4068 IntegerType i32 = rewriter.getI32Type();
4069 // pre-condition: iterationCount is in the inclusive interval [1, 256].
4070 // TODO: validation if the value breaks the pre-condition.
4071 // If the pre-condition fails, there is a possibility of
4072 // affecting the higher bits. In a following PR implement
4073 // RuntimeVerifiableOpInterface that instruments conditions that need to be
4074 // checked at runtime.
4075 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
4076 iterationCount =
4077 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
4078 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
4079 }
4080
4081 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
4082 ConversionPatternRewriter &rewriter,
4083 Location loc, Value sgpr3,
4084 ArrayRef<Value> consts) const {
4085 Value iterateCount = op.getIterationCount();
4086 constexpr int32_t dim = 2;
4087 constexpr int32_t offset = 112;
4088 if (!iterateCount)
4089 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
4090 offset);
4091
4092 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
4093 }
4094
4095 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
4096 ConversionPatternRewriter &rewriter, Location loc,
4097 ArrayRef<Value> consts) const {
4098 if constexpr (DescriptorOp::isGather())
4099 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
4100 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
4101 }
4102
4103 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
4104 ConversionPatternRewriter &rewriter, Location loc,
4105 ArrayRef<Value> consts) const {
4106 IntegerType i32 = rewriter.getI32Type();
4107 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4108 assert(v4i32 && "expected type conversion to succeed.");
4109
4110 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4111 if (onlyNeedsTwoDescriptors)
4112 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4113
4114 constexpr int64_t sgprlen = 4;
4115 Value sgprs[sgprlen];
4116 for (int i = 0; i < sgprlen; ++i)
4117 sgprs[i] = consts[0];
4118
4119 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
4120 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
4121 sgprs[1], consts);
4122 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
4123 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
4124 sgprs[3] =
4125 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
4126
4127 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4128 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
4129 dgroup2 =
4130 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
4131
4132 return dgroup2;
4133 }
4134
4135 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
4136 ConversionPatternRewriter &rewriter, Location loc,
4137 ArrayRef<Value> consts, bool firstHalf) const {
4138 IntegerType i32 = rewriter.getI32Type();
4139 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4140 assert(v4i32 && "expected type conversion to succeed.");
4141
4142 Value indices = adaptor.getIndices();
4143 auto vectorType = cast<VectorType>(indices.getType());
4144 unsigned length = vectorType.getShape().back();
4145 Type elementType = vectorType.getElementType();
4146 unsigned maxLength = elementType == i32 ? 4 : 8;
4147 int32_t offset = firstHalf ? 0 : maxLength;
4148 unsigned discountedLength =
4149 std::max(static_cast<int32_t>(length - offset), 0);
4150
4151 unsigned targetSize = std::min(maxLength, discountedLength);
4152
4153 SmallVector<Value> indicesVector;
4154 for (unsigned i = offset; i < targetSize + offset; ++i) {
4155 Value idx;
4156 if (i < consts.size())
4157 idx = consts[i];
4158 else
4159 idx = createI32Constant(rewriter, loc, i);
4160 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
4161 indicesVector.push_back(elem);
4162 }
4163
4164 SmallVector<Value> indicesI32Vector;
4165 if (elementType == i32) {
4166 indicesI32Vector = indicesVector;
4167 } else {
4168 for (unsigned i = 0; i < targetSize; ++i) {
4169 Value index = indicesVector[i];
4170 indicesI32Vector.push_back(
4171 LLVM::ZExtOp::create(rewriter, loc, i32, index));
4172 }
4173 if ((targetSize % 2) != 0)
4174 // Add padding when not divisible by two.
4175 indicesI32Vector.push_back(consts[0]);
4176 }
4177
4178 SmallVector<Value> indicesToInsert;
4179 if (elementType == i32) {
4180 indicesToInsert = indicesI32Vector;
4181 } else {
4182 unsigned size = indicesI32Vector.size() / 2;
4183 for (unsigned i = 0; i < size; ++i) {
4184 Value first = indicesI32Vector[2 * i];
4185 Value second = indicesI32Vector[2 * i + 1];
4186 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
4187 indicesToInsert.push_back(joined);
4188 }
4189 }
4190
4191 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4192 for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
4193 dgroup =
4194 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
4195
4196 return dgroup;
4197 }
4198
4199 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
4200 ConversionPatternRewriter &rewriter, Location loc,
4201 ArrayRef<Value> consts) const {
4202 return getGatherIndices(op, adaptor, rewriter, loc, consts, true);
4203 }
4204
4205 std::pair<Value, Value>
4206 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
4207 ConversionPatternRewriter &rewriter, Location loc,
4208 Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
4209 constexpr int32_t dim = 3;
4210 constexpr int32_t offset = 0;
4211 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
4212 dim, offset);
4213 }
4214
4215 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
4216 ConversionPatternRewriter &rewriter,
4217 Location loc, Value sgpr1, Value sgpr2,
4218 ArrayRef<Value> consts) const {
4219 constexpr int32_t dim = 4;
4220 constexpr int32_t offset = 48;
4221 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
4222 offset);
4223 }
4224
4225 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
4226 ConversionPatternRewriter &rewriter, Location loc,
4227 Value sgpr2, ArrayRef<Value> consts) const {
4228 constexpr int32_t dim = 4;
4229 constexpr int32_t offset = 80;
4230 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
4231 }
4232
4233 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
4234 ConversionPatternRewriter &rewriter, Location loc,
4235 ArrayRef<Value> consts) const {
4236 if constexpr (DescriptorOp::isGather())
4237 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
4238 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
4239 }
4240
4241 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
4242 ConversionPatternRewriter &rewriter, Location loc,
4243 ArrayRef<Value> consts) const {
4244 IntegerType i32 = rewriter.getI32Type();
4245 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4246 assert(v4i32 && "expected type conversion to succeed.");
4247 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4248 if (onlyNeedsTwoDescriptors)
4249 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4250
4251 constexpr int32_t sgprlen = 4;
4252 Value sgprs[sgprlen];
4253 for (int i = 0; i < sgprlen; ++i)
4254 sgprs[i] = consts[0];
4255
4256 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
4257 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
4258 std::tie(sgprs[1], sgprs[2]) =
4259 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
4260 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
4261
4262 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4263 for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
4264 dgroup3 =
4265 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
4266
4267 return dgroup3;
4268 }
4269
4270 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
4271 ConversionPatternRewriter &rewriter, Location loc,
4272 ArrayRef<Value> consts) const {
4273 return getGatherIndices(op, adaptor, rewriter, loc, consts, false);
4274 }
4275
4276 LogicalResult
4277 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
4278 ConversionPatternRewriter &rewriter) const override {
4279 if (chipset < kGfx1250)
4280 return op->emitOpError(
4281 "make_dma_descriptor is only supported on gfx1250");
4282
4283 Location loc = op.getLoc();
4284
4285 SmallVector<Value> consts;
4286 for (int64_t i = 0; i < 8; ++i)
4287 consts.push_back(createI32Constant(rewriter, loc, i));
4288
4289 Value dgroup0 = this->getDGroup0(adaptor);
4290 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
4291 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
4292 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
4293 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
4294 rewriter.replaceOpWithMultiple(op, {results});
4295 return success();
4296 }
4297};
4298
4299template <typename SourceOp, typename TargetOp>
4300struct AMDGPUTensorLoadStoreOpLowering
4301 : public ConvertOpToLLVMPattern<SourceOp> {
4302 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
4304 AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
4305 Chipset chipset)
4306 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
4307 Chipset chipset;
4308
4309 LogicalResult
4310 matchAndRewrite(SourceOp op, Adaptor adaptor,
4311 ConversionPatternRewriter &rewriter) const override {
4312 if (chipset < kGfx1250)
4313 return op->emitOpError("is only supported on gfx1250");
4314
4315 ValueRange desc = adaptor.getDesc();
4316 // Create a <v8 x i32> 0 as the fifth argument to match llvm intrinsic. It
4317 // will move into the TDM descriptor once it becomes relevant for future use
4318 auto v8i32 = VectorType::get(8, rewriter.getI32Type());
4319 Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
4320 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
4321 desc[3], dgroup4, /*cachePolicy=*/0,
4322 /*alias_scopes=*/nullptr,
4323 /*noalias_scopes=*/nullptr,
4324 /*tbaa=*/nullptr);
4325 return success();
4326 }
4327};
4328
4329struct GlobalPrefetchOpLowering
4330 : public ConvertOpToLLVMPattern<GlobalPrefetchOp> {
4331 GlobalPrefetchOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
4332 : ConvertOpToLLVMPattern<GlobalPrefetchOp>(converter), chipset(chipset) {}
4333
4334 LogicalResult
4335 matchAndRewrite(GlobalPrefetchOp op, GlobalPrefetchOpAdaptor adaptor,
4336 ConversionPatternRewriter &rewriter) const override {
4337 if (chipset < kGfx1250)
4338 return op->emitOpError("is only supported on gfx1250+");
4339
4340 const bool isSpeculative = op.getSpeculative();
4341 const int32_t immArgValue = getGlobalPrefetchLLVMEncoding(
4342 op.getTemporalHint(), op.getCacheScope(), isSpeculative);
4343 IntegerAttr immArgAttr = rewriter.getI32IntegerAttr(immArgValue);
4344
4345 ValueRange indices = adaptor.getIndices();
4346 Value memRef = adaptor.getSrc();
4347 MemRefDescriptor descriptor(memRef);
4348 MemRefType memRefType = op.getSrc().getType();
4349 Location loc = op->getLoc();
4350 auto inboundsFlags = isSpeculative ? LLVM::GEPNoWrapFlags::none
4351 : LLVM::GEPNoWrapFlags::inbounds |
4352 LLVM::GEPNoWrapFlags::nuw;
4353 Value prefetchPtr = getStridedElementPtr(
4354 rewriter, loc, memRefType, descriptor, indices, inboundsFlags);
4355
4356 rewriter.replaceOpWithNewOp<ROCDL::GlobalPrefetchOp>(
4357 op, prefetchPtr, immArgAttr, mlir::ArrayAttr{}, mlir::ArrayAttr{},
4358 mlir::ArrayAttr{});
4359 return success();
4360 }
4361
4362private:
4363 Chipset chipset;
4364};
4365
4366struct ConvertAMDGPUToROCDLPass
4367 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
4368 using Base::Base;
4369
4370 void runOnOperation() override {
4371 MLIRContext *ctx = &getContext();
4372 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
4373 if (failed(maybeChipset)) {
4374 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
4375 return signalPassFailure();
4376 }
4377
4378 RewritePatternSet patterns(ctx);
4379 LLVMTypeConverter converter(ctx);
4380
4381 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
4382 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
4383 LLVMConversionTarget target(getContext());
4384 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
4385 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
4386 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
4387 if (failed(applyPartialConversion(getOperation(), target,
4388 std::move(patterns))))
4389 signalPassFailure();
4390 }
4391};
4392} // namespace
4393
4395 TypeConverter &typeConverter) {
4397 typeConverter, [](gpu::AddressSpace space) {
4398 switch (space) {
4399 case gpu::AddressSpace::Global:
4400 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
4401 case gpu::AddressSpace::Workgroup:
4402 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
4403 case gpu::AddressSpace::Private:
4404 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
4405 case gpu::AddressSpace::Constant:
4406 return ROCDL::ROCDLDialect::kConstantMemoryAddressSpace;
4407 }
4408 llvm_unreachable("unknown address space enum value");
4409 });
4410 typeConverter.addConversion([](gpu::NamedBarrierType type) {
4411 return LLVM::LLVMPointerType::get(
4412 type.getContext(), ROCDL::ROCDLDialect::kSharedMemoryAddressSpace);
4413 });
4414}
4415
4417 TypeConverter &typeConverter) {
4418 typeConverter.addTypeAttributeConversion(
4419 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
4420 -> TypeConverter::AttributeConversionResult {
4421 MLIRContext *ctx = as.getContext();
4422 Type i64 = IntegerType::get(ctx, 64);
4423 switch (as.getValue()) {
4424 case amdgpu::AddressSpace::FatRawBuffer:
4425 return IntegerAttr::get(i64, 7);
4426 case amdgpu::AddressSpace::BufferRsrc:
4427 return IntegerAttr::get(i64, 8);
4428 case amdgpu::AddressSpace::FatStructuredBuffer:
4429 return IntegerAttr::get(i64, 9);
4430 }
4431 return TypeConverter::AttributeConversionResult::abort();
4432 });
4433 typeConverter.addConversion([&](DsBarrierStateType type) -> Type {
4434 return IntegerType::get(type.getContext(), 64);
4435 });
4436 typeConverter.addConversion([&](TDMBaseType type) -> Type {
4437 Type i32 = IntegerType::get(type.getContext(), 32);
4438 return typeConverter.convertType(VectorType::get(4, i32));
4439 });
4440 typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
4441 Type i32 = IntegerType::get(type.getContext(), 32);
4442 return typeConverter.convertType(VectorType::get(4, i32));
4443 });
4444 typeConverter.addConversion(
4445 [&](TDMDescriptorType type,
4446 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
4447 Type i32 = IntegerType::get(type.getContext(), 32);
4448 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
4449 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
4450 llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
4451 return success();
4452 });
4453
4454 auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
4455 ValueRange inputs,
4457 // Only create unrealized_conversion_cast for TDMDescriptorType.
4458 // All other types which are not expected, should be
4459 // materialized by other target materialization functions.
4460 if (inputs.size() != 1)
4461 return {};
4462
4463 if (!isa<TDMDescriptorType>(inputs[0].getType()))
4464 return {};
4465
4466 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
4467 return cast.getResults();
4468 };
4469
4470 typeConverter.addTargetMaterialization(addUnrealizedCast);
4471}
4472
4474 RewritePatternSet &patterns,
4475 Chipset chipset) {
4477 patterns
4478 .add<FatRawBufferCastLowering,
4479 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
4480 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
4481 RawBufferOpLowering<RawBufferAtomicFaddOp,
4482 ROCDL::RawPtrBufferAtomicFaddOp>,
4483 RawBufferOpLowering<RawBufferAtomicFmaxOp,
4484 ROCDL::RawPtrBufferAtomicFmaxOp>,
4485 RawBufferOpLowering<RawBufferAtomicSmaxOp,
4486 ROCDL::RawPtrBufferAtomicSmaxOp>,
4487 RawBufferOpLowering<RawBufferAtomicUminOp,
4488 ROCDL::RawPtrBufferAtomicUminOp>,
4489 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
4490 ROCDL::RawPtrBufferAtomicCmpSwap>,
4491 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
4492 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
4493 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
4494 SparseWMMAOpLowering, DotOpLowering, ExtPackedFp8OpLowering,
4495 ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
4496 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
4497 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
4498 GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
4499 GlobalTransposeLoadOpLowering, AMDGPUPermlaneLowering,
4500 AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
4501 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
4502 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
4503 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
4504 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
4505 ROCDL::TensorLoadToLDSOp>,
4506 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
4507 ROCDL::TensorStoreFromLDSOp>,
4508 DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
4509 DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering,
4510 GlobalPrefetchOpLowering>(converter, chipset);
4511 patterns.add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
4512 DsBarrierStatePendingCountOpLowering,
4513 DsBarrierStateInitCountOpLowering,
4514 DsBarrierStatePhaseParityLowering>(converter);
4515}
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static bool hasDot10Insts(const Chipset &chipset)
static bool hasDot7Insts(const Chipset &chipset)
static std::optional< SparseWMMAOpInfo > sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static bool hasDot11Insts(const Chipset &chipset)
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static bool hasDot9Insts(const Chipset &chipset)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
constexpr Chipset kGfx1200
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value convertPackedVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts packed vector operands to the expected ROCDL types.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static bool hasDot8Insts(const Chipset &chipset)
static bool hasDot2Insts(const Chipset &chipset)
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static bool hasDot12Insts(const Chipset &chipset)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static bool hasDot1Insts(const Chipset &chipset)
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
typename SourceOp::Adaptor OpAdaptor
Definition Pattern.h:229
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:66
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:209
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
result_range getResults()
Definition Operation.h:440
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF8E5M2() const
Definition Types.cpp:45
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isF8E4M3FN() const
Definition Types.cpp:44
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:603
int32_t getGlobalPrefetchLLVMEncoding(amdgpu::LoadTemporalHint hint, amdgpu::Scope scope, bool isSpeculative)
Definition AMDGPUEnums.h:18
bool hasOcpFp8(const Chipset &chipset)
Definition Chipset.h:52
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Returns the rocdl intrinsic corresponding to a SparseWMMA operation swmmac if one exists.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
unsigned minorVersion
Definition Chipset.h:24
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14
unsigned steppingVersion
Definition Chipset.h:25