//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}
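
// Illustrative example (not part of the lowering): an i64 index is narrowed
// with llvm.trunc, an i16 index is widened with llvm.zext, and an existing
// i32 value passes through untouched.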

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
}

/// Convert an unsigned number `val` to i64.
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i64 = rewriter.getI64Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i64 == valTy)
    return val;
  return valTy.getWidth() > 64
             ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
}

static Value createI64Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int64_t value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
      increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
    }
    index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
                  : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}
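
// Illustrative example: for a `memref<4x8xf32>` with static strides [8, 1]
// and indices (%i, %j), the loop above emits %i * 8 + %j, skipping the
// multiply for the unit-stride dimension.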

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           int64_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    return createI64Constant(rewriter, loc, size);
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
    maxIndex = maxIndex
                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
  Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
}
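
// Worked example (illustrative): for a static `memref<4x8xf32>` with strides
// [8, 1], the static path computes max(4 * 8, 8 * 1) * 4 bytes = 128, i.e.
// the buffer covers element indices [0, 32) of 4-byte elements.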

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
                                /*isDisjoint=*/true);
  } else {
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));
  }
  // Get the number of elements.
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: data format (ignored, must be nonzero, 7=float)
  // bits 15-18: num format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  //  none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}
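
// Worked example (illustrative): on a gfx11 (RDNA) target with bounds
// checking enabled, the flag word above is
//   (7 << 12) | (4 << 15) | (1 << 24) | (3 << 28) = 0x31027000.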

namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                     kSizePosInMemRefDescriptor)
                      : Value{};
    Value strides =
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                kStridePosInMemRefDescriptor)
                 : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         kAlignedPtrPosInMemRefDescriptor);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
                                         kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
                                           kSizePosInMemRefDescriptor);
      result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
                                           kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
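
// Rough before/after sketch (types abbreviated, default attributes assumed):
//   %fat = amdgpu.fat_raw_buffer_cast %src
//       : memref<8xi32> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
// becomes a rocdl.make.buffer.rsrc of the aligned base pointer (address
// space 7), re-wrapped together with the offset, sizes, and strides into a
// new memref descriptor.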

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of
    // (N * bitsize(T)) / 32 x i32 and bitcast it. Also, the CAS intrinsic
    // requires integer operands, so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
                                              extraOffsetConst)
                        : extraOffsetConst;
    }
    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = LLVM::BitcastOp::create(rewriter, loc,
                                              llvmWantedDataType, replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

// TODO: the AMDGPU backend already has all this bitpacking logic; we should
// move it to some common place.
/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
/// Vmcnt = Waitcnt[3:0] (pre-gfx9)
/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
/// Vmcnt = Waitcnt[15:10] (gfx11)
/// Expcnt = Waitcnt[6:4] (pre-gfx11)
/// Expcnt = Waitcnt[2:0] (gfx11)
/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
/// Lgkmcnt = Waitcnt[13:8] (gfx10)
/// Lgkmcnt = Waitcnt[9:4] (gfx11)
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  if (chipset.majorVersion < 9) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }
  if (chipset.majorVersion == 9) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 10) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 11) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }
  return failure();
}
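
// Worked example (illustrative): on gfx9, vmcnt = 18 (0b010010) is split into
// low bits 0b0010 and high bit 0b01 placed at bit 14, so
// encodeWaitcnt(kGfx90a, 18, 7, 15) yields
// (1 << 14) | (15 << 8) | (7 << 4) | 2 = 0x4F72.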

struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      rewriter.eraseOp(op);
      return success();
    }

    auto getVal = [](Attribute attr) -> unsigned {
      if (attr)
        return cast<IntegerAttr>(attr).getInt();

      // This value will be clamped to the maximum value for the chipset.
      return 1024;
    };
    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    unsigned vmcnt = 1024;
    Attribute load = adaptor.getLoadAttr();
    Attribute store = adaptor.getStoreAttr();
    if (load && store) {
      vmcnt = getVal(load) + getVal(store);
    } else if (load) {
      vmcnt = getVal(load);
    } else if (store) {
      vmcnt = getVal(store);
    }

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
    if (failed(waitcnt))
      return op.emitOpError("unsupported chipset");

    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // This ensures that waits on global memory aren't introduced on
    // chips that don't have the BackOffBarrier feature enabled in LLVM.
    bool requiresInlineAsm = chipset < kGfx90a;

    Attribute mmra =
        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");
    // Note: while there *is* a workgroup-one-as scope, using it here, combined
    // with the MMRA, would lead to the fence having no effect. This is because
    // the codepaths for an atomic load or store will observe that a
    // one-address-space atomic to LDS requires no synchronization because
    // operations on LDS are totally ordered with respect to each other, and so
    // will not emit the correct waitcnt operations that these fences are
    // intended to produce. Therefore, we use a broader type of fence and rely
    // on the MMRA to relax it to the semantics we want.
    StringRef scope = "workgroup";

    auto relFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::release, scope);
    relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          rewriter, loc,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }

    auto acqFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::acquire, scope);
    acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    rewriter.replaceOp(op, acqFence);
    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts an MFMA vector operand from the MLIR AMDGPU dialect convention to
/// the ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless the rocdl
/// intrinsic allows bf16. Newer MFMAs support bf16 as an operand type; check
/// the IntrinsicsAMDGPU.td file for reference.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input,
                                      bool allowBf16 = true) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return LLVM::BitcastOp::create(
          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return LLVM::BitcastOp::create(
          rewriter, loc,
          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
          input);
    }
  }
  return input;
}
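
// Illustrative examples: with allowBf16 = false, vector<4xbf16> is bitcast to
// vector<4xi16>; vector<8xi8> (e.g. LLVM-converted 8-bit floats) is bitcast
// to i64; and vector<32xi8> becomes vector<8xi32> for the f8f6f4 intrinsics.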

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from the MLIR
/// AMDGPU dialect convention to the ROCDL and LLVM AMDGPU intrinsics
/// convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}
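
// Illustrative example: a single i8 scale becomes `llvm.zext ... to i32`,
// while a vector<4xi8> of scales is reinterpreted in place with
// `llvm.bitcast ... to i32`.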

/// Push an input operand. If it is a float type, there is nothing to do. If
/// it is an integer type, then we need to also push its signedness (1 for
/// signed, 0 for unsigned) and we need to pack the input 16xi8 vector into a
/// 4xi32 vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, the fp8 type is converted to int8, so
  // the fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // If the element type is explicitly signed or unsigned, it overrides the
    // isUnsigned flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
}
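
// Illustrative example: a vector<16xi8> A operand pushes an i1 "1" (signed)
// followed by the input repacked as vector<4xi32>, while a vector<16xf16>
// operand is pushed through unchanged.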

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part. The subwordOffset must not be set for gfx12,
/// as the instructions have been changed to return fewer registers instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}
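
// Illustrative example: an f16 -> f16 WMMA accumulator is pushed followed by
// an i1 encoding subwordOffset, while an i32 accumulator is pushed followed
// by an i1 encoding clamp.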

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
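
// Illustrative example: an `amdgpu.mfma` with f16 sources, an f32 accumulator,
// and m = n = 32, k = 8, blocks = 1 maps to ROCDL::mfma_f32_32x32x8f16 on
// gfx908 and later.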

static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default(std::nullopt);
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}
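
// Illustrative example: aType = vector<32xf8E4M3FN>, bType = vector<32xf8E5M2>,
// destType = vector<16xf32>, m = n = 32, k = 64, b = 1 on gfx950 selects the
// mfma_scale_f32_32x32x64_f8f6f4 intrinsic with type codes 0 and 1.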

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for RDNA3/4 architectures.
static std::optional<StringRef>
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
                      Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  // Handle k == 16 for RDNA3/4.
  if (k == 16) {
    // Common patterns for RDNA3 and RDNA4.
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();

    // RDNA3 specific patterns.
    if (isRDNA3) {
      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
      return std::nullopt;
    }

    // RDNA4 specific patterns (fp8/bf8).
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

    return std::nullopt;
  }

  // Handle k == 32 for RDNA4.
  if (k == 32 && !isRDNA3) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
  }

  return std::nullopt;
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for the gfx1250 architecture.
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
                                                         Type elemBSourceType,
                                                         Type elemDestType,
                                                         uint32_t k) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  if (k == 4) {
    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();

    return std::nullopt;
  }

  if (k == 32) {
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();

    return std::nullopt;
  }

  if (k == 64) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
    }
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();

    return std::nullopt;
  }

  if (k == 128) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
    }

    return std::nullopt;
  }

  return std::nullopt;
}

/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemBSourceType = sourceBVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  const uint32_t k = wmma.getK();
  const bool isRDNA3 = chipset.majorVersion == 11;
  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;

  // Handle RDNA3 and RDNA4.
  if (isRDNA3 || isRDNA4)
    return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
                                 k, isRDNA3);

  // Handle gfx1250.
  if (chipset == Chipset{12, 5, 0})
    return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
                                    elemDestType, k);

  return std::nullopt;
}
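
// Illustrative example: on gfx11 (RDNA3), a WMMA with f16 sources, an f32
// accumulator, and k = 16 selects ROCDL::wmma_f32_16x16x16_f16; the same
// element types on gfx1250 with k = 32 select ROCDL::wmma_f32_16x16x32_f16.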

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
    // allow bf16 as the input; see the IntrinsicsAMDGPU.td file for reference.
    bool allowBf16 = [&]() {
      if (chipset < kGfx950)
        return false;
      if (isScaled)
        return true;
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands({convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceA(), allowBf16),
                           convertMFMAVectorOperand(
                               rewriter, loc, adaptor.getSourceB(), allowBf16),
                           adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
1351
1352struct TransposeLoadOpLowering
1353 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1354 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1355 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1356
1357 Chipset chipset;
1358
1359 LogicalResult
1360 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1361 ConversionPatternRewriter &rewriter) const override {
1362 if (chipset != kGfx950)
1363 return op.emitOpError("Non-gfx950 chipset not supported");
1364
1365 Location loc = op.getLoc();
1366 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1367
1368 // Elements in subbyte memrefs are stored non-contiguously,
1369 // reject if source is sub-byte memref. Use emulated memrefs instead.
1370 size_t srcElementSize =
1371 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1372 if (srcElementSize < 8)
1373 return op.emitOpError("Expect source memref to have at least 8 bits "
1374 "element size, got ")
1375 << srcElementSize;
1376
1377 auto resultType = cast<VectorType>(op.getResult().getType());
1378 Value srcPtr =
1379 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1380 (adaptor.getSrcIndices()));
1381
1382 size_t numElements = resultType.getNumElements();
1383 size_t elementTypeSize =
1384 resultType.getElementType().getIntOrFloatBitWidth();
1385
1386 // ROCDL transpose load intrinsics return vectors of 32-bit integers, if
1387 // the element size is smaller than 16 bits.
1388 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1389 rewriter.getIntegerType(32));
1390 Type llvmResultType = typeConverter->convertType(resultType);
1391
1392 switch (elementTypeSize) {
1393 case 4: {
1394 assert(numElements == 16);
1395 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1396 rocdlResultType, srcPtr);
1397 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1398 break;
1399 }
1400 case 6: {
1401 assert(numElements == 16);
1402 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1403 rocdlResultType, srcPtr);
1404 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1405 break;
1406 }
1407 case 8: {
1408 assert(numElements == 8);
1409 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1410 rocdlResultType, srcPtr);
1411 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1412 break;
1413 }
1414 case 16: {
1415 assert(numElements == 4);
1416 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1417 srcPtr);
1418 break;
1419 }
1420 default:
1421 return op.emitOpError("Unsupported element size for transpose load");
1422 }
1423 return success();
1424 }
1425};
1426
1427struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
1428 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1429 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1430
1431 Chipset chipset;
1432
1433 LogicalResult
1434 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1435 ConversionPatternRewriter &rewriter) const override {
1436 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1437 return op.emitOpError("pre-gfx9 and post-gfx10 not supported");
1438
1439 Location loc = op.getLoc();
1440
1441 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1442 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1443
1444 // TODO: instead of only transferring one element per thread, we could
1445 // augment this to transfer multiple elements per thread by issuing
1446 // multiple `global_load_lds` instructions.
1447 Type transferType = op.getTransferType();
1448 int loadWidth = [&]() -> int {
1449 if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1450 return (transferVectorType.getNumElements() *
1451 transferVectorType.getElementTypeBitWidth()) /
1452 8;
1453 }
1454 return transferType.getIntOrFloatBitWidth() / 8;
1455 }();
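// For example, transferType = vector<2xf16> yields
// loadWidth = 2 * 16 / 8 = 4 bytes.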
1456
1457 // Currently, only 1-, 2-, 4-, 12-, and 16-byte loads are supported.
1458 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1459 return op.emitOpError("unsupported load width for this chipset");
1460
1461 if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1462 return op.emitOpError("gather to LDS instructions with 12-byte and "
1463 "16-byte load widths are only supported on gfx950");
1464
1465 Value srcPtr =
1466 getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1467 (adaptor.getSrcIndices()));
1468 Value dstPtr =
1469 getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
1470 (adaptor.getDstIndices()));
1471
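// The offset and aux attributes are left at zero: the element addresses are
// already fully materialized in srcPtr and dstPtr above.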
1472 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1473 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1474 /*offset=*/rewriter.getI32IntegerAttr(0),
1475 /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
1476 ArrayAttr{});
1477
1478 return success();
1479 }
1480};
1481
1482namespace {
1483struct ExtPackedFp8OpLowering final
1484 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
1485 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1486 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1487 chipset(chipset) {}
1488 Chipset chipset;
1489
1490 LogicalResult
1491 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1492 ConversionPatternRewriter &rewriter) const override;
1493};
1494
1495struct PackedTrunc2xFp8OpLowering final
1496 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
1497 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
1498 Chipset chipset)
1499 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1500 chipset(chipset) {}
1501 Chipset chipset;
1502
1503 LogicalResult
1504 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1505 ConversionPatternRewriter &rewriter) const override;
1506};
1507
1508struct PackedStochRoundFp8OpLowering final
1509 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
1510 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
1511 Chipset chipset)
1512 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1513 chipset(chipset) {}
1514 Chipset chipset;
1515
1516 LogicalResult
1517 matchAndRewrite(PackedStochRoundFp8Op op,
1518 PackedStochRoundFp8OpAdaptor adaptor,
1519 ConversionPatternRewriter &rewriter) const override;
1520};
1521
1522struct ScaledExtPackedOpLowering final
1523 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
1524 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1525 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1526 chipset(chipset) {}
1527 Chipset chipset;
1528
1529 LogicalResult
1530 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1531 ConversionPatternRewriter &rewriter) const override;
1532};
1533
1534struct PackedScaledTruncOpLowering final
1535 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
1536 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
1537 Chipset chipset)
1538 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1539 chipset(chipset) {}
1540 Chipset chipset;
1541
1542 LogicalResult
1543 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1544 ConversionPatternRewriter &rewriter) const override;
1545};
1546
1547} // end namespace
1548
1549LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1550 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1551 ConversionPatternRewriter &rewriter) const {
1552 Location loc = op.getLoc();
1553 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1554 return rewriter.notifyMatchFailure(
1555 loc, "Fp8 conversion instructions are not available on target "
1556 "architecture and their emulation is not implemented");
1557 Type v4i8 =
1558 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1559 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1560 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1561
1562 Value source = adaptor.getSource();
1563 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1564 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1565 Type sourceElemType = getElementTypeOrSelf(op.getSource());
1566 // Extend to a v4i8
1567 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1568 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1569 if (!sourceVecType) {
1570 longVec = LLVM::InsertElementOp::create(
1571 rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
1572 } else {
1573 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1574 Value idx = createI32Constant(rewriter, loc, i);
1575 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1576 longVec =
1577 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1578 }
1579 }
1580 source = longVec;
1581 }
1582 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
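// The scalar conversion intrinsics select a single source byte with the
// index operand; the packed (Pk) variants convert the pair of bytes in the
// selected half-word.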
1583 if (resultVecType) {
1584 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1585 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1586 op.getIndex());
1587 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1588 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1589 op.getIndex());
1590 }
1591 } else {
1592 if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
1593 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1594 op.getIndex());
1595 } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
1596 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1597 op.getIndex());
1598 }
1599 }
1600 return success();
1601}
1602
1603LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1604 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1605 ConversionPatternRewriter &rewriter) const {
1606 Location loc = op.getLoc();
1607 if (chipset != kGfx950)
1608 return rewriter.notifyMatchFailure(
1609 loc, "Scaled fp conversion instructions are not available on target "
1610 "architecture and their emulation is not implemented");
1611 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1612
1613 Value source = adaptor.getSource();
1614 Value scale = adaptor.getScale();
1615
1616 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1617 Type sourceElemType = sourceVecType.getElementType();
1618 VectorType destVecType = cast<VectorType>(op.getResult().getType());
1619 Type destElemType = destVecType.getElementType();
1620
1621 VectorType packedVecType;
1622 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1623 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
1624 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1625 } else if (isa<Float4E2M1FNType>(sourceElemType)) {
1626 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
1627 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1628 } else {
1629 llvm_unreachable("invalid element type for scaled ext");
1630 }
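// Either way, the packed source occupies 32 bits (4 x i8 for fp8/bf8,
// 8 x i4 for fp4), matching the i32 operand the scaled-conversion
// intrinsics expect.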
1631
1632 // Extend to a packedVectorType
1633 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1634 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
// Note: `sourceVecType` comes from a checked cast and is never null, so
// only the element-copy loop is needed to widen the source.
1639 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1640 Value idx = createI32Constant(rewriter, loc, i);
1641 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1642 longVec =
1643 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1644 }
1646 source = longVec;
1647 }
1648 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1649
1650 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
1651 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1652 op, destVecType, i32Source, scale, op.getIndex());
1653 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
1654 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1655 op, destVecType, i32Source, scale, op.getIndex());
1656 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
1657 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1658 op, destVecType, i32Source, scale, op.getIndex());
1659 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
1660 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1661 op, destVecType, i32Source, scale, op.getIndex());
1662 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
1663 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1664 op, destVecType, i32Source, scale, op.getIndex());
1665 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
1666 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1667 op, destVecType, i32Source, scale, op.getIndex());
1668 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
1669 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1670 op, destVecType, i32Source, scale, op.getIndex());
1671 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
1672 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1673 op, destVecType, i32Source, scale, op.getIndex());
1674 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
1675 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1676 op, destVecType, i32Source, scale, op.getIndex());
1677 else
1678 return failure();
1679
1680 return success();
1681}
1682
1683LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1684 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1685 ConversionPatternRewriter &rewriter) const {
1686 Location loc = op.getLoc();
1687 if (chipset != kGfx950)
1688 return rewriter.notifyMatchFailure(
1689 loc, "Scaled fp conversion instructions are not available on target "
1690 "architecture and their emulation is not implemented");
1691 Type v2i16 = getTypeConverter()->convertType(
1692 VectorType::get(2, rewriter.getI16Type()));
1693 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1694
1695 Type resultType = op.getResult().getType();
1696 Type resultElemType = getElementTypeOrSelf(resultType);
1697 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1698 Type sourceElemType = sourceVecType.getElementType();
1699
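// The packed scaled-trunc intrinsics below return their result as an i32
// for fp4 outputs and as a <2 x i16> for fp8/bf8 outputs; any `existing`
// data is bitcast to that representation first.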
1700 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1701
1702 Value source = adaptor.getSource();
1703 Value scale = adaptor.getScale();
1704 Value existing = adaptor.getExisting();
1705 if (existing)
1706 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
1707 else
1708 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
1709
1710 if (sourceVecType.getNumElements() < 2) {
1711 Value c0 = createI32Constant(rewriter, loc, 0);
1712 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1713 VectorType v2 = VectorType::get(2, sourceElemType);
1714 source = LLVM::ZeroOp::create(rewriter, loc, v2);
1715 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
1716 }
1717
1718 Value sourceA, sourceB;
1719 if (sourceElemType.isF32()) {
1720 Value c0 = createI32Constant(rewriter, loc, 0);
1721 Value c1 = createI32Constant(rewriter, loc, 1);
1722 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1723 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
1724 }
1725
1726 Value result;
1727 if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
1728 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
1729 existing, sourceA, sourceB,
1730 scale, op.getIndex());
1731 else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
1732 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
1733 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1734 else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
1735 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
1736 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1737 else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
1738 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
1739 existing, sourceA, sourceB,
1740 scale, op.getIndex());
1741 else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
1742 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
1743 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1744 else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
1745 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
1746 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1747 else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
1748 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
1749 existing, sourceA, sourceB,
1750 scale, op.getIndex());
1751 else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
1752 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
1753 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1754 else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
1755 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
1756 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1757 else
1758 return failure();
1759
1760 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1761 op, getTypeConverter()->convertType(resultType), result);
1762 return success();
1763}
1764
1765LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1766 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1767 ConversionPatternRewriter &rewriter) const {
1768 Location loc = op.getLoc();
1769 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1770 return rewriter.notifyMatchFailure(
1771 loc, "Fp8 conversion instructions are not available on target "
1772 "architecture and their emulation is not implemented");
1773 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1774
1775 Type resultType = op.getResult().getType();
1776 Type resultElemType = getElementTypeOrSelf(resultType);
1777
1778 Value sourceA = adaptor.getSourceA();
1779 Value sourceB = adaptor.getSourceB();
1780 if (!sourceB)
1781 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
1782 Value existing = adaptor.getExisting();
1783 if (existing)
1784 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1785 else
1786 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1787
1788 Value result;
1789 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1790 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1791 existing, op.getWordIndex());
1792 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1793 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1794 existing, op.getWordIndex());
1795
1796 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1797 op, getTypeConverter()->convertType(resultType), result);
1798 return success();
1799}
1800
1801LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1802 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1803 ConversionPatternRewriter &rewriter) const {
1804 Location loc = op.getLoc();
1805 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1806 return rewriter.notifyMatchFailure(
1807 loc, "Fp8 conversion instructions are not available on target "
1808 "architecture and their emulation is not implemented");
1809 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1810
1811 Type resultType = op.getResult().getType();
1812 Type resultElemType = getElementTypeOrSelf(resultType);
1813
1814 Value source = adaptor.getSource();
1815 Value stoch = adaptor.getStochiasticParam();
1816 Value existing = adaptor.getExisting();
1817 if (existing)
1818 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1819 else
1820 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1821
1822 Value result;
1823 if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
1824 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
1825 existing, op.getStoreIndex());
1826 else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
1827 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
1828 existing, op.getStoreIndex());
1829
1830 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1831 op, getTypeConverter()->convertType(resultType), result);
1832 return success();
1833}
1834
1835// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
1836// operation into the corresponding ROCDL instructions.
1837struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
1838 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
1839 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
1840 Chipset chipset;
1841
1842 LogicalResult
1843 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1844 ConversionPatternRewriter &rewriter) const override {
1845
1846 // Convert the source operand to the corresponding LLVM type
1847 Location loc = DppOp.getLoc();
1848 Value src = adaptor.getSrc();
1849 Value old = adaptor.getOld();
1850 Type srcType = src.getType();
1851 Type oldType = old.getType();
1852 Type llvmType = nullptr;
1853 if (srcType.getIntOrFloatBitWidth() < 32) {
1854 llvmType = rewriter.getI32Type();
1855 } else if (isa<FloatType>(srcType)) {
1856 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1857 ? rewriter.getF32Type()
1858 : rewriter.getF64Type();
1859 } else if (isa<IntegerType>(srcType)) {
1860 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1861 ? rewriter.getI32Type()
1862 : rewriter.getI64Type();
1863 }
1864 auto llvmSrcIntType = typeConverter->convertType(
1865 rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
1866
1867 // If the source type is narrower than 32 bits, pack it into an i32 via
// an insert-element and bitcast sequence.
1868 auto convertOperand = [&](Value operand, Type operandType) {
1869 if (operandType.getIntOrFloatBitWidth() <= 16) {
1870 if (llvm::isa<FloatType>(operandType)) {
1871 operand =
1872 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
1873 }
1874 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
1875 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1876 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
1877 operand =
1878 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
1879 createI32Constant(rewriter, loc, 0));
1880 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
1881 }
1882 return operand;
1883 };
1884
1885 src = convertOperand(src, srcType);
1886 old = convertOperand(old, oldType);
1887
1888 // These DPP control values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
1889 enum DppCtrl : unsigned {
1890 ROW_SHL0 = 0x100,
1891 ROW_SHR0 = 0x110,
1892 ROW_ROR0 = 0x120,
1893 WAVE_SHL1 = 0x130,
1894 WAVE_ROL1 = 0x134,
1895 WAVE_SHR1 = 0x138,
1896 WAVE_ROR1 = 0x13C,
1897 ROW_MIRROR = 0x140,
1898 ROW_HALF_MIRROR = 0x141,
1899 BCAST15 = 0x142,
1900 BCAST31 = 0x143,
1901 };
1902
1903 auto kind = DppOp.getKind();
1904 auto permArgument = DppOp.getPermArgument();
1905 uint32_t DppCtrl = 0;
1906
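// For quad_perm, each of the four 2-bit lane selectors is packed in order;
// e.g. the permutation [1, 0, 3, 2] encodes to
// (1<<0) | (0<<2) | (3<<4) | (2<<6) = 0xB1.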
1907 switch (kind) {
1908
1909 case DPPPerm::quad_perm:
1910 if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1911 int32_t i = 0;
1912 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1913 uint32_t num = elem.getInt();
1914 DppCtrl |= num << (i * 2);
1915 i++;
1916 }
1917 }
1918 break;
1919 case DPPPerm::row_shl:
1920 if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1921 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1922 }
1923 break;
1924 case DPPPerm::row_shr:
1925 if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1926 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1927 }
1928 break;
1929 case DPPPerm::row_ror:
1930 if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
1931 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1932 }
1933 break;
1934 case DPPPerm::wave_shl:
1935 DppCtrl = DppCtrl::WAVE_SHL1;
1936 break;
1937 case DPPPerm::wave_shr:
1938 DppCtrl = DppCtrl::WAVE_SHR1;
1939 break;
1940 case DPPPerm::wave_rol:
1941 DppCtrl = DppCtrl::WAVE_ROL1;
1942 break;
1943 case DPPPerm::wave_ror:
1944 DppCtrl = DppCtrl::WAVE_ROR1;
1945 break;
1946 case DPPPerm::row_mirror:
1947 DppCtrl = DppCtrl::ROW_MIRROR;
1948 break;
1949 case DPPPerm::row_half_mirror:
1950 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1951 break;
1952 case DPPPerm::row_bcast_15:
1953 DppCtrl = DppCtrl::BCAST15;
1954 break;
1955 case DPPPerm::row_bcast_31:
1956 DppCtrl = DppCtrl::BCAST31;
1957 break;
1958 }
1959
1960 // Extract the row_mask, bank_mask, and bound_ctrl attributes.
1962 auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
1963 auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
1964 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
1965
1966 // Create a ROCDL DPPUpdateOp with the computed control value and masks.
1967 auto dppMovOp =
1968 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
1969 rowMask, bankMask, boundCtrl);
1970
1971 Value result = dppMovOp.getRes();
1972 if (srcType.getIntOrFloatBitWidth() < 32) {
1973 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
1974 if (!llvm::isa<IntegerType>(srcType)) {
1975 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
1976 }
1977 }
1978
1979 // Replace the amdgpu.dpp op with the lowered ROCDL DPP op.
1981 rewriter.replaceOp(DppOp, ValueRange(result));
1982 return success();
1983 }
1984};
1985
1986struct AMDGPUSwizzleBitModeLowering
1987 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
1988 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1989
1990 LogicalResult
1991 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1992 ConversionPatternRewriter &rewriter) const override {
1993 Location loc = op.getLoc();
1994 Type i32 = rewriter.getI32Type();
1995 Value src = adaptor.getSrc();
1996 SmallVector<Value> decomposed =
1997 LLVM::decomposeValue(rewriter, loc, src, i32);
1998 unsigned andMask = op.getAndMask();
1999 unsigned orMask = op.getOrMask();
2000 unsigned xorMask = op.getXorMask();
2001
2002 // bit 15 is 0 for the BitMode swizzle.
2003 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
2004 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
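// For example, andMask = 0x1F with orMask = 0 and xorMask = 1 swaps
// adjacent lanes: mask = 0x1F | (0 << 5) | (1 << 10) = 0x41F.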
2005 Value maskValue = createI32Constant(rewriter, loc, mask);
2006 SmallVector<Value> swizzled;
2007 for (Value v : decomposed) {
2008 Value res =
2009 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2010 swizzled.emplace_back(res);
2011 }
2012
2013 Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
2014 rewriter.replaceOp(op, result);
2015 return success();
2016 }
2017};
2018
2019struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2021
2022 AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
2023 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2024 Chipset chipset;
2025
2026 LogicalResult
2027 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2028 ConversionPatternRewriter &rewriter) const override {
2029 if (chipset < kGfx950)
2030 return op->emitOpError("permlane_swap is only supported on gfx950+");
2031
2032 Location loc = op.getLoc();
2033 Type i32 = rewriter.getI32Type();
2034 Value src = adaptor.getSrc();
2035 unsigned rowLength = op.getRowLength();
2036 bool fi = op.getFetchInactive();
2037 bool boundctrl = op.getBoundCtrl();
2038
2039 SmallVector<Value> decomposed =
2040 LLVM::decomposeValue(rewriter, loc, src, i32);
2041
2042 SmallVector<Value> permuted;
2043 for (Value v : decomposed) {
2044 Value res;
2045 Type i32pair = LLVM::LLVMStructType::getLiteral(
2046 rewriter.getContext(), {v.getType(), v.getType()});
2047
2048 if (rowLength == 16)
2049 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2050 boundctrl);
2051 else if (rowLength == 32)
2052 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2053 boundctrl);
2054 else
2055 llvm_unreachable("unsupported row length");
2056
2057 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2058 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
2059
2060 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2061 LLVM::ICmpPredicate::eq, vdst0, v);
2062
2063 // Per `permlane(16|32)` semantics: if the first extracted element equals
2064 // 'v', the result is the second element; otherwise it is the first.
2065 Value vdstNew =
2066 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2067 permuted.emplace_back(vdstNew);
2068 }
2069
2070 Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
2071 rewriter.replaceOp(op, result);
2072 return success();
2073 }
2074};
2075
2076struct ConvertAMDGPUToROCDLPass
2077 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
2078 using Base::Base;
2079
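// Typical command-line usage (the chipset name is an example):
//   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx942 input.mlir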
2080 void runOnOperation() override {
2081 MLIRContext *ctx = &getContext();
2082 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
2083 if (failed(maybeChipset)) {
2084 emitError(UnknownLoc::get(ctx), "invalid chipset name: " + chipset);
2085 return signalPassFailure();
2086 }
2087
2088 RewritePatternSet patterns(ctx);
2089 LLVMTypeConverter converter(ctx);
2090 populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
2091 LLVMConversionTarget target(getContext());
2092 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
2093 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
2094 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
2095 if (failed(applyPartialConversion(getOperation(), target,
2096 std::move(patterns))))
2097 signalPassFailure();
2098 }
2099};
2100} // namespace
2101
2102 void mlir::populateAMDGPUMemorySpaceAttributeConversions(
2103 TypeConverter &typeConverter) {
2104 typeConverter.addTypeAttributeConversion(
2105 [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
2106 -> TypeConverter::AttributeConversionResult {
2107 MLIRContext *ctx = as.getContext();
2108 Type i64 = IntegerType::get(ctx, 64);
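// The integer attributes below are the LLVM AMDGPU address-space numbers:
// 7 = buffer fat pointer, 8 = buffer resource, 9 = buffer strided pointer.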
2109 switch (as.getValue()) {
2110 case amdgpu::AddressSpace::FatRawBuffer:
2111 return IntegerAttr::get(i64, 7);
2112 case amdgpu::AddressSpace::BufferRsrc:
2113 return IntegerAttr::get(i64, 8);
2114 case amdgpu::AddressSpace::FatStructuredBuffer:
2115 return IntegerAttr::get(i64, 9);
2116 }
2117 return TypeConverter::AttributeConversionResult::abort();
2118 });
2119}
2120
2121 void mlir::populateAMDGPUToROCDLConversionPatterns(
2122 LLVMTypeConverter &converter, RewritePatternSet &patterns,
2123 Chipset chipset) {
2124 populateAMDGPUMemorySpaceAttributeConversions(converter);
2125 patterns
2126 .add<FatRawBufferCastLowering,
2127 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2128 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2129 RawBufferOpLowering<RawBufferAtomicFaddOp,
2130 ROCDL::RawPtrBufferAtomicFaddOp>,
2131 RawBufferOpLowering<RawBufferAtomicFmaxOp,
2132 ROCDL::RawPtrBufferAtomicFmaxOp>,
2133 RawBufferOpLowering<RawBufferAtomicSmaxOp,
2134 ROCDL::RawPtrBufferAtomicSmaxOp>,
2135 RawBufferOpLowering<RawBufferAtomicUminOp,
2136 ROCDL::RawPtrBufferAtomicUminOp>,
2137 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2138 ROCDL::RawPtrBufferAtomicCmpSwap>,
2139 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2140 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2141 WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
2142 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
2143 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
2144 TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
2145 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
2146}